Thanks for sharing your work!
|
def spatial_pool(self, x):
    """Gate the channels of ``x`` with a spatially pooled context vector.

    A value projection of ``x`` (``self.conv_v_right``) is collapsed over
    all spatial positions using weights from a single-channel query
    projection (``self.conv_q_right``) normalized by ``self.softmax_right``.
    The pooled context is mapped by ``self.conv_up`` and passed through
    ``self.sigmoid`` to give a per-channel mask that multiplies ``x``.

    Args:
        x: Input feature map, presumably ``[N, C, H, W]``. The value
            projection's output is unpacked into four dimensions.

    Returns:
        ``x * mask_ch``, where ``mask_ch`` has shape ``[N, OC, 1, 1]`` and
        broadcasts over ``H`` and ``W``. ``OC`` (the output channels of
        ``conv_up``) must therefore be broadcast-compatible with ``x``'s
        channel dimension.
    """
    # Value projection: [N, IC, H, W] -> [N, IC, H*W].
    value = self.conv_v_right(x)
    batch, channel, height, width = value.size()
    value = value.view(batch, channel, height * width)

    # Query projection: [N, 1, H, W] -> [N, 1, H*W]. The view requires
    # conv_q_right to produce exactly one channel.
    weights = self.conv_q_right(x).view(batch, 1, height * width)
    # NOTE(review): assumes softmax_right normalizes over the last (H*W) axis,
    # e.g. nn.Softmax(dim=2), so the weights are a distribution over
    # positions -- confirm against the module's __init__.
    weights = self.softmax_right(weights)

    # Weighted sum over positions: [N, IC, H*W] @ [N, H*W, 1] -> [N, IC, 1],
    # then add a trailing spatial dim -> [N, IC, 1, 1] for the 2-D conv.
    context = torch.matmul(value, weights.transpose(1, 2)).unsqueeze(-1)

    # [N, OC, 1, 1] channel mask.
    mask_ch = self.sigmoid(self.conv_up(context))
    return x * mask_ch

It seems that the `spatial_pool` function is the same as the channel-only self-attention module.
Thanks for sharing your work!
Code reference: `PSA/semantic-segmentation/network/PSA.py`, lines 64–95 at commit 588b370.
It seems that the `spatial_pool` function is the same as the channel-only self-attention module.