# Patch summary: add an optional `device` parameter to load_sam_model,
# load_groundingdino_model, groundingdino_predict and sam_segment, falling back
# to comfy.model_management.get_torch_device() when it is not given, and pass
# `device` through from the SegmentAnythingUltra V2 node.
#
# Review fix: in sam_segment the result of transformed_boxes.to(...) was
# discarded. Tensor.to() returns the converted tensor instead of modifying the
# tensor in place, so the result is now rebound before predict_torch is called.
#
# NOTE(review): this patch comments out sam.eval() and dino.eval(); confirm the
# models are switched to inference mode somewhere else.
# NOTE(review): line breaks, indentation and blank context lines were
# reconstructed from the hunk line counts; the post-image blob ids on the
# `index` lines predate the review fix.
diff --git a/py/segment_anything_func.py b/py/segment_anything_func.py
index 2be0d35..e6e47bf 100644
--- a/py/segment_anything_func.py
+++ b/py/segment_anything_func.py
@@ -15,8 +15,8 @@ from urllib.parse import urlparse

 import folder_paths
 import comfy.model_management
-from sam_hq.predictor import SamPredictorHQ
-from sam_hq.build_sam_hq import sam_model_registry
+from .sam_hq.predictor import SamPredictorHQ
+from .sam_hq.build_sam_hq import sam_model_registry
 import glob
 import folder_paths

@@ -70,7 +70,7 @@ def get_bert_base_uncased_model_path():
 def list_sam_model():
     return list(sam_model_list.keys())

-def load_sam_model(model_name):
+def load_sam_model(model_name, device=None):
     sam_checkpoint_path = get_local_filepath(
         sam_model_list[model_name]["model_url"], sam_model_dir_name)
     model_file_name = os.path.basename(sam_checkpoint_path)
@@ -78,9 +78,13 @@ def load_sam_model(model_name):
     if 'hq' not in model_type and 'mobile' not in model_type:
         model_type = '_'.join(model_type.split('_')[:-1])
     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path)
-    sam_device = comfy.model_management.get_torch_device()
-    sam.to(device=sam_device)
-    sam.eval()
+
+    if device:
+        sam.to(device=device)
+    else:
+        sam.to(device=comfy.model_management.get_torch_device())
+
+    # sam.eval()
     sam.model_name = model_file_name
     return sam

@@ -104,7 +108,7 @@ def get_local_filepath(url, dirname, local_file_name=None):
         download_url_to_file(url, destination)
     return destination

-def load_groundingdino_model(model_name):
+def load_groundingdino_model(model_name, device=None):
     from local_groundingdino.util.utils import clean_state_dict as local_groundingdino_clean_state_dict
     from local_groundingdino.util.slconfig import SLConfig as local_groundingdino_SLConfig
     from local_groundingdino.models import build_model as local_groundingdino_build_model
@@ -127,9 +131,13 @@ def load_groundingdino_model(model_name):
     )
     dino.load_state_dict(local_groundingdino_clean_state_dict(
         checkpoint['model']), strict=False)
-    device = comfy.model_management.get_torch_device()
-    dino.to(device=device)
-    dino.eval()
+
+    if device:
+        dino.to(device=device)
+    else:
+        dino.to(device=comfy.model_management.get_torch_device())
+
+    # dino.eval()
     return dino

 def list_groundingdino_model():
@@ -139,7 +147,8 @@ def groundingdino_predict(
     dino_model,
     image,
     prompt,
-    threshold
+    threshold,
+    device = None
 ):
     from local_groundingdino.datasets import transforms as T
     def load_dino_image(image_pil):
@@ -158,8 +167,12 @@ def get_grounding_output(model, image, caption, box_threshold):
         caption = caption.strip()
         if not caption.endswith("."):
             caption = caption + "."
-        device = comfy.model_management.get_torch_device()
-        image = image.to(device)
+
+        if device:
+            image = image.to(device=device)
+        else:
+            image = image.to(device=comfy.model_management.get_torch_device())
+
         with torch.no_grad():
             outputs = model(image[None], captions=[caption])
         logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
@@ -209,7 +222,8 @@ def split_image_mask(image):
 def sam_segment(
     sam_model,
     image,
-    boxes
+    boxes,
+    device=None
 ):
     if boxes.shape[0] == 0:
         return None
@@ -222,12 +236,18 @@ def sam_segment(
     image_np_rgb = image_np[..., :3]
     predictor.set_image(image_np_rgb)
     transformed_boxes = predictor.transform.apply_boxes_torch(
-        boxes, image_np.shape[:2])
-    sam_device = comfy.model_management.get_torch_device()
+        boxes, image_np.shape[:2]
+    )
+
+    if device:
+        transformed_boxes = transformed_boxes.to(device=device)
+    else:
+        transformed_boxes = transformed_boxes.to(device=comfy.model_management.get_torch_device())
+
     masks, _, _ = predictor.predict_torch(
         point_coords=None,
         point_labels=None,
-        boxes=transformed_boxes.to(sam_device),
+        boxes=transformed_boxes,
         multimask_output=False)
     masks = masks.permute(1, 0, 2, 3).cpu().numpy()
     return create_tensor_output(image_np, masks, boxes)
diff --git a/py/segment_anything_ultra_v2.py b/py/segment_anything_ultra_v2.py
index 9494f72..462f878 100644
--- a/py/segment_anything_ultra_v2.py
+++ b/py/segment_anything_ultra_v2.py
@@ -59,12 +59,14 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
             local_files_only = False

         if self.previous_sam_model != sam_model or self.SAM_MODEL is None:
-            self.SAM_MODEL = load_sam_model(sam_model)
+            self.SAM_MODEL = load_sam_model(sam_model, device)
             self.previous_sam_model = sam_model
         if self.previous_dino_model != grounding_dino_model or self.DINO_MODEL is None:
-            self.DINO_MODEL = load_groundingdino_model(grounding_dino_model)
+            self.DINO_MODEL = load_groundingdino_model(grounding_dino_model, device)
             self.previous_dino_model = grounding_dino_model

+        self.SAM_MODEL.to(device)
+        self.DINO_MODEL.to(device)
         # SAM_MODEL = load_sam_model(sam_model)
         # DINO_MODEL = load_groundingdino_model(grounding_dino_model)
         ret_images = []
@@ -74,10 +76,10 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
             i = torch.unsqueeze(i, 0)
             i = pil2tensor(tensor2pil(i).convert('RGB'))
             _image = tensor2pil(i).convert('RGBA')
-            boxes = groundingdino_predict(self.DINO_MODEL, _image, prompt, threshold)
+            boxes = groundingdino_predict(self.DINO_MODEL, _image, prompt, threshold, device)
             if boxes.shape[0] == 0:
                 break
-            (_, _mask) = sam_segment(self.SAM_MODEL, _image, boxes)
+            (_, _mask) = sam_segment(self.SAM_MODEL, _image, boxes, device)
             _mask = _mask[0]
             detail_range = detail_erode + detail_dilate
             if process_detail:
@@ -116,5 +118,5 @@ def segment_anything_ultra_v2(self, image, sam_model, grounding_dino_model, thre
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
-    "LayerMask: SegmentAnythingUltra V2": "LayerMask: SegmentAnythingUltra V2(Advance)",
+    "LayerMask: SegmentAnythingUltra V2": "LayerMask: SegmentAnythingUltra V2",
 }