I have the following network:
import torch.nn as nn
import torch.nn.functional as F
# ConvOffset2d is provided by a third-party deformable-convolution extension;
# its import path is not given here.

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, 3, padding=1, bias=False)
        self.batch1 = nn.BatchNorm2d(8, affine=False)
        #self.conv2 = nn.Conv2d(8, 16, 3, padding=0, bias=False)
        # Predicts one (dy, dx) offset pair for each of the 3*3 kernel taps
        self.conv2offset = nn.Conv2d(8, 2*3*3, 3, padding=0, bias=False)
        self.deform_conv2 = ConvOffset2d(8, 16, 3, padding=0, num_deformable_groups=1)
        self.batch2 = nn.BatchNorm2d(16, affine=False)
        self.pooling = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(6*6*16, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.batch1(x)
        x = self.activation(x)
        #x = self.conv2(x)
        offset = self.conv2offset(x)
        x = self.deform_conv2(x, offset)
        x = self.pooling(x)
        x = self.batch2(x)
        x = self.activation(x)
        logits = self.fc1(x.view(-1, 6*6*16))
        probas = F.softmax(logits, dim=1)
        return logits, probas
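As a quick sanity check on the `6*6*16` input size of `fc1`, the spatial sizes can be traced by hand for a 28x28 MNIST input. The helpers `conv_out` and `pool_out` are hypothetical, written only for this check, and assume the usual floor-based output-size formulas for convolution and for `nn.MaxPool2d(2)` with its default stride. Note that `conv2offset` also maps 14 to 12, so the offset map has the same 12x12 size as the output of `deform_conv2`.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial output size of a 2-D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2):
    """Output size of a max-pool whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1

size = 28                            # MNIST images are 28x28
size = conv_out(size, 3, padding=1)  # conv1: 28 -> 28
size = pool_out(size)                # pooling: 28 -> 14
size = conv_out(size, 3, padding=0)  # conv2offset and deform_conv2: 14 -> 12
size = pool_out(size)                # pooling: 12 -> 6
fc1_in = size * size * 16            # 16 output channels from deform_conv2
print(size, fc1_in)                  # 6 576
```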
I train it on MNIST for 2 batches. This takes 327 seconds and reaches 97.64% accuracy on the test set.
If I replace the deformable convolution with a normal convolution (the commented-out lines in the code above), the same 2 batches take 19 seconds and reach 97.54% accuracy on the test set.
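To illustrate what `deform_conv2` does beyond a normal convolution, here is a minimal single-channel sketch in plain Python. It is not the `ConvOffset2d` implementation: each of the 9 kernel taps is shifted by a fractional offset and read back with bilinear interpolation, so every tap touches up to 4 pixels instead of 1. The function names and the `offsets[a][b] = (dy, dx)` layout are hypothetical; with all offsets set to zero, the deformable version reduces to the plain convolution.

```python
import math

def bilinear(img, y, x):
    """Sample img (a list of rows) at fractional (y, x); points outside read as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    val = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= yy < h and 0 <= xx < w:
                # Weight falls off linearly with distance from the neighbour
                val += (1.0 - abs(y - yy)) * (1.0 - abs(x - xx)) * img[yy][xx]
    return val

def regular_conv_at(img, weight, i, j):
    """Plain 3x3 correlation (no padding) at output (i, j): 9 direct reads."""
    return sum(weight[a][b] * img[i + a][j + b] for a in range(3) for b in range(3))

def deform_conv_at(img, weight, offsets, i, j):
    """Deformable 3x3 correlation at output (i, j).

    offsets[a][b] = (dy, dx) shifts kernel tap (a, b); each shifted tap is
    bilinearly interpolated, i.e. up to 4 reads and weights per tap.
    """
    total = 0.0
    for a in range(3):
        for b in range(3):
            dy, dx = offsets[a][b]
            total += weight[a][b] * bilinear(img, i + a + dy, j + b + dx)
    return total

img = [[float(r * 5 + c) for c in range(5)] for r in range(5)]  # toy 5x5 input
weight = [[1.0, 0.0, -1.0] for _ in range(3)]                     # toy 3x3 kernel
zero = [[(0.0, 0.0)] * 3 for _ in range(3)]                       # no deformation
print(regular_conv_at(img, weight, 1, 1), deform_conv_at(img, weight, zero, 1, 1))  # -6.0 -6.0
```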
What could be causing the deformable version to be about 17 times slower?
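For a rough sense of scale, the multiply-accumulates (MACs) of the second stage can be counted for a single image. This sketch assumes a 12x12 output map (28 to 14 after the first pooling, 14 to 12 after the unpadded 3x3 convolution), 8 input and 16 output channels, one bilinear sample per (location, input channel, tap) shared across output channels because `num_deformable_groups=1`, and about 4 multiply-adds per bilinear sample. All of these are illustrative assumptions, not measurements of `ConvOffset2d`.

```python
def conv_macs(out_h, out_w, c_out, c_in, k):
    """Multiply-accumulates for a dense k x k convolution (bias ignored)."""
    return out_h * out_w * c_out * c_in * k * k

H = W = 12                 # spatial size after the second 3x3 conv (14 -> 12)
C_IN, C_OUT, K = 8, 16, 3  # channels and kernel size of conv2 / deform_conv2

plain = conv_macs(H, W, C_OUT, C_IN, K)       # nn.Conv2d(8, 16, 3)
offset = conv_macs(H, W, 2 * K * K, C_IN, K)  # conv2offset: 18 offset channels
# Assumption: one bilinear sample per (location, input channel, tap),
# shared across output channels, at roughly 4 multiply-adds each.
sampling = H * W * C_IN * K * K * 4
deform = plain + offset + sampling            # offset conv + sampling + weighted sum

print(plain, deform)   # 165888 393984
print(deform / plain)  # 2.375
```

Under these assumptions the deformable stage does roughly 2.4 times the arithmetic of the plain convolution, which is far less than the observed wall-clock ratio. A raw count like this ignores memory access patterns and how well each kernel is optimized, so it can only bound how much of the gap arithmetic would explain.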
PyTorch v0.3.0
Python v3.6.1