diff --git a/detic/modeling/text/text_encoder.py b/detic/modeling/text/text_encoder.py index 3ec5090..6e80f8d 100644 --- a/detic/modeling/text/text_encoder.py +++ b/detic/modeling/text/text_encoder.py @@ -171,11 +171,12 @@ def forward(self, captions): return features -def build_text_encoder(pretrain=True): +def build_text_encoder(pretrain=True, clip_download_root=None): text_encoder = CLIPTEXT() if pretrain: import clip - pretrained_model, _ = clip.load("ViT-B/32", device='cpu') + pretrained_model, _ = clip.load("ViT-B/32", device='cpu', + download_root=clip_download_root) state_dict = pretrained_model.state_dict() to_delete_keys = ["logit_scale", "input_resolution", \ "context_length", "vocab_size"] + \ @@ -186,4 +187,4 @@ def build_text_encoder(pretrain=True): print('Loading pretrained CLIP') text_encoder.load_state_dict(state_dict) # import pdb; pdb.set_trace() - return text_encoder \ No newline at end of file + return text_encoder diff --git a/detic/predictor.py b/detic/predictor.py index 047ed80..d3c43a6 100644 --- a/detic/predictor.py +++ b/detic/predictor.py @@ -11,28 +11,18 @@ from detectron2.engine.defaults import DefaultPredictor from detectron2.utils.video_visualizer import VideoVisualizer from detectron2.utils.visualizer import ColorMode, Visualizer -from pathlib import Path -from hashlib import md5 from .modeling.utils import reset_cls_test -def get_clip_embeddings(vocabulary, prompt='a '): - # NOTE: need hashing due to filename length limit - hash_value = md5("-".join(sorted(vocabulary)).encode()).hexdigest() - cache_file_path = f"/tmp/detic-clip-embeddings-{hash_value}.pt" - if Path(cache_file_path).exists(): - print(f"loading embeddings for {vocabulary} from {cache_file_path}") - return torch.load(cache_file_path) - else: - from detic.modeling.text.text_encoder import build_text_encoder - text_encoder = build_text_encoder(pretrain=True) - text_encoder.eval() - texts = [prompt + x for x in vocabulary] - emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() - print(f"saved embeddings for {vocabulary} to {cache_file_path}") - torch.save(emb, cache_file_path) - return emb +def get_clip_embeddings(vocabulary, prompt='a ', clip_download_root=None): + from detic.modeling.text.text_encoder import build_text_encoder + text_encoder = build_text_encoder(pretrain=True, + clip_download_root=clip_download_root) + text_encoder.eval() + texts = [prompt + x for x in vocabulary] + emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() + return emb BUILDIN_CLASSIFIER = { 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy', @@ -50,18 +40,22 @@ def get_clip_embeddings(vocabulary, prompt='a '): class VisualizationDemo(object): def __init__(self, cfg, args, - instance_mode=ColorMode.IMAGE, parallel=False): + instance_mode=ColorMode.IMAGE, parallel=False, + clip_download_root=None): """ Args: cfg (CfgNode): instance_mode (ColorMode): parallel (bool): whether to run the model in different processes from visualization. Useful since the visualization logic can be slow. + clip_download_root (str): Custom clip download root path """ + self.clip_download_root = clip_download_root if args.vocabulary == 'custom': self.metadata = MetadataCatalog.get("__unused") self.metadata.thing_classes = args.custom_vocabulary.split(',') - classifier = get_clip_embeddings(self.metadata.thing_classes) + classifier = get_clip_embeddings(self.metadata.thing_classes, + clip_download_root=self.clip_download_root) self._default_vocabulary = None else: self.metadata = MetadataCatalog.get( @@ -88,7 +82,8 @@ def change_vocabulary(self, vocab): """ self.metadata = MetadataCatalog.get("__unused+"+str(random.random())) self.metadata.thing_classes = vocab.split(',') - classifier = get_clip_embeddings(self.metadata.thing_classes) + classifier = get_clip_embeddings(self.metadata.thing_classes, + clip_download_root=self.clip_download_root) num_classes = len(self.metadata.thing_classes) reset_cls_test(self.predictor.model, classifier, num_classes)