Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 104 additions & 6 deletions src/models/FeatureBank.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,15 @@ def __init__(

@property
def features(self):
    """Return the leading ``self.num_pos`` entries of ``self.memory``."""
    return self.memory[: self.num_pos]

@property
def clutter(self):
    """Return every entry of ``self.memory`` from index ``self.num_pos`` on."""
    return self.memory[self.num_pos :]

def features_of(self, class_, n_list_set):
    """Return the stored entries that belong to one class.

    The slice starts at ``class_ * self.single_feature_dim`` and spans
    ``n_list_set[class_]`` entries of ``self.memory``.

    Args:
        class_: Index of the class whose entries are requested.
        n_list_set: Indexable giving, per class index, how many entries
            to return.
    """
    start = class_ * self.single_feature_dim
    return self.memory[start : start + n_list_set[class_]]

def class_id_from_vertex_id(self, vertex_id):
    """Map a vertex index to its class index.

    Computed as ``vertex_id // self.single_feature_dim``.
    """
    return vertex_id // self.single_feature_dim
Expand Down Expand Up @@ -276,6 +276,104 @@ def forward(self, x, y, visible, img_label):
label_weight_onehot,
)

def forward_siglip(self, x, y, visible, img_label):
    """Update ``self.memory`` in place from one batch; returns ``None``.

    The first ``self.num_pos`` rows of the memory are blended with
    class-aggregated, visibility-masked batch features using the momentum
    read from ``self.params[3]``. When ``self.num_noise > 0`` the rows
    after ``self.num_pos`` are refreshed with the batch entries that lie
    beyond ``self.single_feature_dim`` along axis 1 of ``x``. Every row of
    the result is L2-normalised, and ``self.lru`` is advanced modulo
    ``self.max_lru``.

    Args:
        x: Batch features, indexed as ``x[:, j, :]``.
            # assumes shape (n, single_feature_dim + num_noise, k) with
            # n = batch size and k = feature size -- TODO confirm
        y: Not used by this method.
        visible: Mask multiplied onto the first ``single_feature_dim``
            entries of ``x`` after a trailing axis of size 1 is added.
        img_label: Integer class label per batch element (fed to
            ``torch.bincount``).
    """
    n_pos = self.num_pos
    n_neg = self.num_noise
    count_label = torch.bincount(img_label, minlength=self.nb_classes)
    label_weight_onehot = fun_label_onehot(img_label, count_label, self.nb_classes)

    # Initialise max_lru on first use (-1 marks "not set yet"), but only
    # when one batch worth of clutter fits into the clutter part of memory.
    if (
        self.max_lru == -1
        and n_neg > 0
        and x.shape[0] <= (self.nlem - n_pos) / n_neg
    ):
        self.max_lru = (self.memory.shape[0] - n_pos) // (n_neg * x.shape[0])

    momentum = self.params[3].item()

    with torch.set_grad_enabled(False):
        # Zero out entries that are not visible, then aggregate the
        # flattened per-sample features per class with the label weights.
        masked = x[:, 0 : self.single_feature_dim, :] * visible.type(x.dtype).view(
            *visible.shape, 1
        )
        get = torch.matmul(
            label_weight_onehot.transpose(0, 1),
            masked.view(x.shape[0], -1),
        )
        get = get.view(get.shape[0], -1, x.shape[-1])

        # A class with no image in this batch would contribute zeros;
        # reuse its currently stored entries so the blend leaves it as is.
        absent_classes = (count_label == 0).nonzero(as_tuple=True)[0]
        for i in absent_classes:
            get[i] = self.memory[
                i * self.single_feature_dim : (i + 1) * self.single_feature_dim
            ]
        get = get.view(-1, x.shape[-1])

        # Momentum blend of the stored positives with the batch aggregate.
        updated_pos = self.memory[0:n_pos, :] * momentum + get * (1 - momentum)

        if n_neg > 0:
            batch_clutter = (
                x[:, self.single_feature_dim :, :].contiguous().view(-1, x.shape[2])
            )
            if x.shape[0] > (self.nlem - n_pos) / n_neg:
                # The batch carries more clutter than the bank can hold:
                # overwrite the whole clutter part with a truncated copy.
                neg_parts = batch_clutter[0 : self.memory.shape[0] - n_pos]
            else:
                # Replace only the slot selected by self.lru and keep the
                # clutter stored before and after that slot.
                slot = n_neg * x.shape[0]
                neg_parts = torch.cat(
                    [
                        self.memory[n_pos : n_pos + self.lru * slot, :],
                        batch_clutter,
                        self.memory[n_pos + (self.lru + 1) * slot :, :],
                    ],
                    dim=0,
                )
            new_memory = torch.cat([updated_pos, neg_parts], dim=0)
        else:
            new_memory = updated_pos

        # Referenced through ``torch`` so the block does not depend on a
        # module-level alias for torch.nn.functional.
        self.memory = torch.nn.functional.normalize(new_memory, dim=1, p=2)

        self.lru += 1
        self.lru = self.lru % self.max_lru

def set_zero(self, n_pos):
self.accumulate_num = torch.zeros(
n_pos,
Expand Down Expand Up @@ -313,6 +411,6 @@ def cuda(self, device=None):
super().cuda(device)
self.memory = self.memory.cuda(device)
return self

def load_memory(self, memory):
    """Replace ``self.memory`` with ``memory`` (stored by reference, not copied)."""
    self.memory = memory
Loading