-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathConv2dSOT.py
More file actions
27 lines (22 loc) · 1.3 KB
/
Conv2dSOT.py
File metadata and controls
27 lines (22 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from math import log2, ceil
import torch
from SOT import SOT
class Conv2dSOT(SOT):
    """SOT variant that is fed sliding-window patches cut from an image.

    The base class is initialised with ``kernel_size * kernel_size`` as its
    second argument, i.e. the number of values in one flattened square patch.
    NOTE(review): the base ``SOT`` class is defined elsewhere; it is presumed
    to provide ``self.nodes``, ``self.learning_rates``, ``self.device`` and
    ``_propagate_through_tree`` -- confirm against SOT.py.
    """

    def __init__(self, number_of_leaves: int, kernel_size: int, stride: int = 2, padding: int = 0, lr: float = 0.3, device = torch.device("cpu")):
        """Set up the underlying SOT and remember the patch geometry.

        Args:
            number_of_leaves: Forwarded unchanged to ``SOT.__init__``.
            kernel_size: Side length of the square window cut from the input.
            stride: Step between consecutive windows along each unfolded dim.
            padding: Number of zeros added on every side before windowing.
            lr: Forwarded unchanged to ``SOT.__init__``.
            device: Forwarded unchanged to ``SOT.__init__``.
        """
        flattened_patch_size = kernel_size * kernel_size
        super().__init__(number_of_leaves, flattened_patch_size, lr, device)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def img2patches(self, x):
        """Zero-pad ``x`` and slice it into flattened sliding windows.

        Windows are taken along dims 0 and 1, so ``x`` is presumably a single
        2-D (height, width) image tensor -- TODO confirm with callers.

        Returns:
            A ``(patches, rows, cols)`` tuple. ``rows`` and ``cols`` are the
            window counts along dims 0 and 1; for a 2-D input ``patches`` has
            shape ``(rows * cols, kernel_size * kernel_size)``.
        """
        pad = self.padding
        # A 4-element pad spec applies to the last two dims of the input.
        framed = torch.nn.functional.pad(x, [pad, pad, pad, pad], mode="constant", value=0)

        size, step = self.kernel_size, self.stride
        along_dim0 = framed.unfold(0, size, step)
        windows = along_dim0.unfold(1, size, step)

        rows, cols, patch_h, patch_w = windows.shape[:4]
        flat_patches = windows.reshape(rows * cols, patch_h * patch_w)
        return flat_patches, rows, cols

    def forward(self, X):
        """Run one pass over the patches of ``X`` and update the nodes in place.

        The input is cut into patches, the patches are sent through
        ``_propagate_through_tree``, and every node except index 0 is moved by
        the learning-rate-weighted difference to the patches, averaged over
        all patches.

        Returns:
            ``bmu_indices`` reshaped to ``(rows, cols)``, the patch grid.
        """
        patches, rows, cols = self.img2patches(X)
        indices, bmu_indices, bmu_dists = self._propagate_through_tree(
            patches, patch_number=patches.shape[0]
        )

        # Look up one learning rate per entry of ``indices``.
        picked_lrs = self.learning_rates.gather(0, indices.flatten())
        neighborhood_lrs = picked_lrs.reshape(indices.shape)

        # Broadcast rates, patches and nodes against each other; the mean over
        # dim 0 then collapses the patch axis.
        lr_weights = neighborhood_lrs.unsqueeze(2).to(self.device)
        patch_batch = patches.unsqueeze(1).to(self.device)
        node_batch = self.nodes[1:, :].unsqueeze(0).to(self.device)
        neighborhood_updates = (lr_weights * (patch_batch - node_batch)).mean(0)

        self.nodes[1:, :] += neighborhood_updates
        return bmu_indices.reshape(rows, cols)