-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathprocess_caption.py
More file actions
150 lines (119 loc) · 5.12 KB
/
process_caption.py
File metadata and controls
150 lines (119 loc) · 5.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import nltk
import string
import pickle
from transformers import BertTokenizer, BertModel
import torch
from tqdm import tqdm
from scipy.spatial.distance import cosine
from datasets.dataset_vars import ADE20K_SEM_SEG_FULL_CATEGORIES as ADE20K_CATEGORIES
from itertools import chain
from utils.utilsSAM import read_pickle
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk import pos_tag
# Download stopwords and tokenizer models if not already done
# nltk.download('stopwords')
# nltk.download('punkt_tab')
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger_eng')
PATH_CAPTION = "datasets/captions_val/coco_captions.pkl"
def extract_noun_phrases(text):
    """Collect the distinct words that occur inside noun-phrase chunks of *text*.

    The text is tokenized with NLTK, tokens that are found inside
    ``string.punctuation`` are discarded, and the rest is POS-tagged and
    chunked with the pattern "optional determiner, any adjectives, one or
    more nouns".

    Note that the result holds the individual words of every matched chunk
    (determiners and adjectives included), not the joined phrases.

    Args:
        text: The raw text to analyse.

    Returns:
        A set with every word that is a leaf of an ``NP`` chunk.
    """
    # `in string.punctuation` is a substring test against the punctuation string.
    words = [tok for tok in word_tokenize(text) if tok not in string.punctuation]
    tagged_words = pos_tag(words)

    chunker = nltk.RegexpParser('NP: {<DT>?<JJ.*>*<NN.*>+}')
    tree = chunker.parse(tagged_words)

    return {
        word
        for chunk in tree.subtrees()
        if chunk.label() == 'NP'
        for word, _tag in chunk.leaves()
    }
def load_bert_model():
    """Load the pretrained ``bert-base-uncased`` tokenizer and model.

    Returns:
        A ``(tokenizer, model)`` tuple, both created with ``from_pretrained``.
    """
    checkpoint = 'bert-base-uncased'
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    bert_model = BertModel.from_pretrained(checkpoint)
    return bert_tokenizer, bert_model
# Compute cosine similarity
def compute_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embeddings.

    Both arguments are converted with ``.numpy()`` and handed to SciPy's
    cosine *distance*; the similarity is one minus that distance.

    Args:
        embedding1: First embedding; must expose ``.numpy()``.
        embedding2: Second embedding; must expose ``.numpy()``.

    Returns:
        ``1 - cosine(embedding1, embedding2)``.
    """
    first_vec = embedding1.numpy()
    second_vec = embedding2.numpy()
    distance = cosine(first_vec, second_vec)
    return 1 - distance
def get_ade_dict(ade_gt, tokenizer, model):
    """Encode the category names and build name <-> trainId lookup tables.

    The previous implementation called ``encode_text``, which is neither
    defined nor imported in this module, so every call raised ``NameError``.
    The names are now encoded with this module's ``encode_word_list``.
    NOTE(review): confirm that ``encode_word_list`` is the intended encoder.

    Args:
        ade_gt: Iterable of dicts providing at least ``name`` and ``trainId``.
        tokenizer: Tokenizer forwarded to ``encode_word_list``.
        model: Model forwarded to ``encode_word_list``.

    Returns:
        A tuple ``(ade_encoding_tensor, ade_name_to_id, ade_id_to_name)``:
        the embeddings of the distinct names ordered by trainId (one row per
        name), the name -> trainId dict and the trainId -> name dict.
    """
    ade_name_to_id = {}
    ade_id_to_name = {}
    for sample in ade_gt:
        name, train_id = sample['name'], sample['trainId']
        ade_name_to_id[name] = train_id
        ade_id_to_name[train_id] = name

    # Distinct names ordered by trainId, so row i matches the i-th smallest id.
    ordered_names = sorted(ade_name_to_id, key=ade_name_to_id.get)

    # One batched forward pass instead of one model call per name.
    ade_encoding_tensor = encode_word_list(ordered_names, tokenizer, model)
    return ade_encoding_tensor, ade_name_to_id, ade_id_to_name
def update_old_vocab(ade_gt, new_voc):
    """Map every original category to its most similar word in *new_voc*.

    Each new word is assigned to the single original category whose embedding
    it is closest to. For every category, the assigned word with the highest
    similarity replaces the original name. Categories that received no word,
    or whose best word equals the original name, keep the original name with
    a score of 1.

    Fixes over the previous version:
      * untouched entries were stored as ``(name, (name, 1))`` instead of
        ``(name, 1)``;
      * unchanged words were counted twice;
      * the original name was looked up with ``ade_gt[trainId]``, which is
        only correct when ``ade_gt`` is sorted by a contiguous 0-based
        trainId;
      * the row index of the similarity matrix was used directly as a
        trainId key.

    Args:
        ade_gt: list of dicts with the keys ``name``, ``id`` and ``trainId``.
        new_voc: Iterable of candidate words.

    Returns:
        dict mapping ``int(trainId)`` to a ``(word, score)`` tuple.
    """
    tokenizer, model = load_bert_model()

    # Sort by trainId: row i of `encoded_ade` belongs to `ade_vocab[i]`.
    ade_vocab = sorted(ade_gt, key=lambda x: x['trainId'])
    encoded_ade = encode_word_list([sample['name'] for sample in ade_vocab], tokenizer, model)
    row_to_train_id = [int(sample['trainId']) for sample in ade_vocab]
    original_names = {int(sample['trainId']): sample['name'] for sample in ade_vocab}

    BATCH_SIZE = 64
    new_voc = list(new_voc)

    # -2 marks "still the original word"; it is below any cosine similarity,
    # so the first candidate assigned to a category always replaces it.
    ret_voc = {train_id: (name, -2) for train_id, name in original_names.items()}

    new_voc_batches = [new_voc[i:i + BATCH_SIZE] for i in range(0, len(new_voc), BATCH_SIZE)]
    for batch in tqdm(new_voc_batches, desc="Processing new vocab", total=len(new_voc_batches)):
        new_voc_embeddings = encode_word_list(batch, tokenizer, model)
        # Embeddings are L2-normalised, so the dot product is the cosine similarity.
        similarity = new_voc_embeddings @ encoded_ade.T
        # Best-matching original category for every new word.
        max_similarity, max_indices = similarity.max(dim=1)
        for word, sim, row in zip(batch, max_similarity.tolist(), max_indices.tolist()):
            train_id = row_to_train_id[row]
            if ret_voc[train_id][1] < sim:
                ret_voc[train_id] = (word, sim)

    # Entries that still hold the original name get score 1, counted once each.
    # Only existing keys are reassigned, so the dict size is stable while iterating.
    unchanged_words = 0
    for train_id, (word, _score) in ret_voc.items():
        if word == original_names[train_id]:
            ret_voc[train_id] = (word, 1)
            unchanged_words += 1
    print(f"Unchanged words: {unchanged_words}")
    return ret_voc
def encode_word_list(word_list, tokenizer, model):
    """Embed every entry of *word_list* as a unit-length vector.

    The entries are tokenized as one padded batch, passed through *model*
    with gradients disabled, and the hidden state at sequence position 0
    (the [CLS] token) is taken as the embedding of each entry.

    Args:
        word_list: The words/phrases to encode.
        tokenizer: Callable producing the keyword inputs for *model*.
        model: Callable whose output exposes ``last_hidden_state``.

    Returns:
        Tensor of shape [batch_size, hidden_size] whose rows are divided by
        their L2 norm.
    """
    encoded = tokenizer(
        word_list,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
        add_special_tokens=True,
    )
    with torch.no_grad():
        model_output = model(**encoded)

    # First position of every sequence -> [batch_size, hidden_size].
    first_token = model_output.last_hidden_state[:, 0, :]
    row_norms = torch.norm(first_token, dim=1, keepdim=True)
    return first_token / row_norms
def strip_noun(noun):
    """Remove a single leading article ("a ", "an " or "the ") from *noun*.

    The previous version tested ``'a ' in noun`` and then replaced every
    occurrence, which also mangled words that merely end in "a" before a
    space (e.g. "pizza box" became "pizzbox"). Only a leading article is
    stripped now.

    Args:
        noun: The word or phrase to clean.

    Returns:
        *noun* without its leading article, or *noun* unchanged when it does
        not start with one.
    """
    for article in ('a ', 'an ', 'the '):
        if noun.startswith(article):
            return noun[len(article):]
    return noun
if __name__ == '__main__':
    # Merge all captions into one string and collect the words found in its
    # noun-phrase chunks.
    # NOTE(review): assumes the pickle holds an iterable of caption strings.
    captions = read_pickle(PATH_CAPTION)
    nouns = extract_noun_phrases(" ".join(captions))
    noun_list = list(nouns)
    print(noun_list, len(nouns))

    # Persist the extracted words (not the captions) as a pickled list.
    with open("datasets/captions_val/nouns_ade.pkl", "wb") as out_file:
        pickle.dump(noun_list, out_file)

    # Strip articles, show the distinct result, then remap the ADE20K vocabulary.
    nouns = [strip_noun(word) for word in nouns]
    print(sorted(set(nouns)))
    new_vocab = update_old_vocab(ADE20K_CATEGORIES, nouns)
    print(new_vocab)