Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions py/segment_anything_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from urllib.parse import urlparse
import folder_paths
import comfy.model_management
from sam_hq.predictor import SamPredictorHQ
from sam_hq.build_sam_hq import sam_model_registry
from .sam_hq.predictor import SamPredictorHQ
from .sam_hq.build_sam_hq import sam_model_registry

import glob
import folder_paths
Expand Down Expand Up @@ -70,17 +70,21 @@ def get_bert_base_uncased_model_path():
def list_sam_model():
return list(sam_model_list.keys())

def load_sam_model(model_name):
def load_sam_model(model_name, device=None):
sam_checkpoint_path = get_local_filepath(
sam_model_list[model_name]["model_url"], sam_model_dir_name)
model_file_name = os.path.basename(sam_checkpoint_path)
model_type = model_file_name.split('.')[0]
if 'hq' not in model_type and 'mobile' not in model_type:
model_type = '_'.join(model_type.split('_')[:-1])
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path)
sam_device = comfy.model_management.get_torch_device()
sam.to(device=sam_device)
sam.eval()

if device:
sam.to(device=device)
else:
sam.to(device=comfy.model_management.get_torch_device())

# sam.eval()
sam.model_name = model_file_name
return sam

Expand All @@ -104,7 +108,7 @@ def get_local_filepath(url, dirname, local_file_name=None):
download_url_to_file(url, destination)
return destination

def load_groundingdino_model(model_name):
def load_groundingdino_model(model_name, device=None):
from local_groundingdino.util.utils import clean_state_dict as local_groundingdino_clean_state_dict
from local_groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig
from local_groundingdino.models import build_model as local_groundingdino_build_model
Expand All @@ -127,9 +131,13 @@ def load_groundingdino_model(model_name):
)
dino.load_state_dict(local_groundingdino_clean_state_dict(
checkpoint['model']), strict=False)
device = comfy.model_management.get_torch_device()
dino.to(device=device)
dino.eval()

if device:
dino.to(device=device)
else:
dino.to(device=comfy.model_management.get_torch_device())

# dino.eval()
return dino

def list_groundingdino_model():
Expand All @@ -139,7 +147,8 @@ def groundingdino_predict(
dino_model,
image,
prompt,
threshold
threshold,
device = None
):
from local_groundingdino.datasets import transforms as T
def load_dino_image(image_pil):
Expand All @@ -158,8 +167,12 @@ def get_grounding_output(model, image, caption, box_threshold):
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
device = comfy.model_management.get_torch_device()
image = image.to(device)

if device:
image = image.to(device=device)
else:
image = image.to(device=comfy.model_management.get_torch_device())

with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
Expand Down Expand Up @@ -209,7 +222,8 @@ def split_image_mask(image):
def sam_segment(
sam_model,
image,
boxes
boxes,
device=None
):
if boxes.shape[0] == 0:
return None
Expand All @@ -222,12 +236,18 @@ def sam_segment(
image_np_rgb = image_np[..., :3]
predictor.set_image(image_np_rgb)
transformed_boxes = predictor.transform.apply_boxes_torch(
boxes, image_np.shape[:2])
sam_device = comfy.model_management.get_torch_device()
boxes, image_np.shape[:2]
)

if device:
transformed_boxes.to(device=device)
else:
transformed_boxes.to(device=comfy.model_management.get_torch_device())

masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes.to(sam_device),
boxes=transformed_boxes,
multimask_output=False)
masks = masks.permute(1, 0, 2, 3).cpu().numpy()
return create_tensor_output(image_np, masks, boxes)
12 changes: 7 additions & 5 deletions py/segment_anything_ultra_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
local_files_only = False

if self.previous_sam_model != sam_model or self.SAM_MODEL is None:
self.SAM_MODEL = load_sam_model(sam_model)
self.SAM_MODEL = load_sam_model(sam_model, device)
self.previous_sam_model = sam_model
if self.previous_dino_model != grounding_dino_model or self.DINO_MODEL is None:
self.DINO_MODEL = load_groundingdino_model(grounding_dino_model)
self.DINO_MODEL = load_groundingdino_model(grounding_dino_model, device)
self.previous_dino_model = grounding_dino_model

self.SAM_MODEL.to(device)
self.DINO_MODEL.to(device)
# SAM_MODEL = load_sam_model(sam_model)
# DINO_MODEL = load_groundingdino_model(grounding_dino_model)
ret_images = []
Expand All @@ -74,10 +76,10 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
i = torch.unsqueeze(i, 0)
i = pil2tensor(tensor2pil(i).convert('RGB'))
_image = tensor2pil(i).convert('RGBA')
boxes = groundingdino_predict(self.DINO_MODEL, _image, prompt, threshold)
boxes = groundingdino_predict(self.DINO_MODEL, _image, prompt, threshold, device)
if boxes.shape[0] == 0:
break
(_, _mask) = sam_segment(self.SAM_MODEL, _image, boxes)
(_, _mask) = sam_segment(self.SAM_MODEL, _image, boxes, device)
_mask = _mask[0]
detail_range = detail_erode + detail_dilate
if process_detail:
Expand Down Expand Up @@ -116,5 +118,5 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
}

NODE_DISPLAY_NAME_MAPPINGS = {
"LayerMask: SegmentAnythingUltra V2": "LayerMask: SegmentAnythingUltra V2(Advance)",
"LayerMask: SegmentAnythingUltra V2": "LayerMask: SegmentAnythingUltra V2",
}