diff --git a/vanish_nodes.py b/vanish_nodes.py index 2baa6ee..7de6a9e 100644 --- a/vanish_nodes.py +++ b/vanish_nodes.py @@ -59,19 +59,23 @@ def execute( ) -> io.NodeOutput: if mask is None and image_as_mask is None: raise ValueError("Either 'mask' or 'image_as_mask' must be provided.") - + if mask is None: mask = image_as_mask.mean(dim=-1) - + if mask.ndim == 4: mask = mask[:, :, :, 0] + from comfy import model_management + device = model_management.get_torch_device() + mask = mask.to(device) + s_kernel = spatial_radius * 2 + 1 t_kernel = temporal_radius * 2 + 1 # Separable dilation: 2D spatial + 1D temporal (much faster than 3D pooling) if s_kernel > 1: - mask = mask.unsqueeze(1) # (B, 1, H, W) + mask = mask.unsqueeze(1) # (B, 1, H, W) mask = F.max_pool2d( mask, kernel_size=s_kernel, stride=1, padding=spatial_radius ) @@ -86,6 +90,7 @@ def execute( mask = mask.reshape(H, W, B).permute(2, 0, 1) mask = (mask > 0.5).float() + mask = mask.cpu() return io.NodeOutput(mask)