-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathCustomizedLinear.py
More file actions
144 lines (118 loc) · 5.27 KB
/
CustomizedLinear.py
File metadata and controls
144 lines (118 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Extended torch.nn module that customizes connections (a linear layer with masked connectivity).
This code is based on https://pytorch.org/docs/stable/notes/extending.html
"""
import math
import torch
import torch.nn as nn
#################################
# Define a custom autograd function for masked connections.
class CustomizedLinearFunction(torch.autograd.Function):
    """
    Autograd function computing ``input @ (weight * mask).T + bias``.

    Where ``mask`` is zero, the corresponding entries of ``weight`` take no
    part in the forward pass and receive a zero gradient in the backward
    pass, so they never influence training.
    """

    # Note that both forward and backward are @staticmethods.
    @staticmethod
    def forward(ctx, input, weight, bias=None, mask=None):
        """
        Args:
            input: 2-D tensor (``Tensor.mm`` requires it); presumably
                (batch, in_features).
            weight: tensor of shape (out_features, in_features).
            bias: optional tensor of shape (out_features,).
            mask: optional tensor multiplied element-wise into ``weight``
                (typically 0/1); zero entries disconnect the matching weights.

        Returns:
            Tensor of shape (input.size(0), weight.size(0)).
        """
        if mask is not None:
            # Out-of-place product: the caller's weight tensor is untouched.
            weight = weight * mask
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # ``weight`` is the masked weight here, which is exactly what the
        # backward pass needs for grad_input.
        ctx.save_for_backward(input, weight, bias, mask)
        return output

    # This function has only a single output, so it gets only one gradient.
    @staticmethod
    def backward(ctx, grad_output):
        """
        Return gradients w.r.t. (input, weight, bias, mask).

        The mask is treated as a constant, so its gradient is always None.
        Trailing Nones beyond the number of inputs actually passed to
        ``apply`` are ignored by autograd.
        """
        input, weight, bias, mask = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = grad_mask = None

        # The needs_input_grad checks only skip work that nobody asked for.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
            if mask is not None:
                # d(weight * mask)/d(weight) == mask: disconnected weights get
                # a zero gradient.
                grad_weight = grad_weight * mask
        # Test ``bias`` first: when apply() was called without a bias,
        # needs_input_grad may have fewer than three entries.
        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over the batch dimension only. No squeeze: it would turn
            # a (1,) gradient into a 0-d tensor when out_features == 1.
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias, grad_mask
class CustomizedLinear(nn.Module):
    """Linear layer whose input→output connections are restricted by a fixed mask."""

    def __init__(self, mask, bias=True):
        """
        Extended torch.nn module with masked connections.

        Arguments
        ------------------
        mask [torch.Tensor or array-like]:
            the shape is (n_input_feature, n_output_feature).
            The elements are 0 or 1, declaring un-connected or
            connected pairs.
        bias [bool]:
            whether to add a learnable bias.
        """
        super().__init__()
        # Convert before reading the shape: array-likes such as nested lists
        # have no ``.shape`` attribute.
        if isinstance(mask, torch.Tensor):
            mask = mask.type(torch.float)
        else:
            mask = torch.tensor(mask, dtype=torch.float)
        self.input_features = mask.shape[0]
        self.output_features = mask.shape[1]

        # Stored transposed, (out, in), to line up with ``weight``. Registered
        # as a non-trainable Parameter so it appears in .parameters() and
        # state_dict() and follows .to()/.cuda() together with the weights.
        self.mask = nn.Parameter(mask.t(), requires_grad=False)

        self.weight = nn.Parameter(torch.empty(self.output_features, self.input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_features))
        else:
            # Register the optional parameter as None so the attribute exists.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Draw weight and bias from U(-1/sqrt(in), 1/sqrt(in)).

        Then zero the weights of disconnected edges.
        """
        stdv = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)
            # Re-applied on every reset so masked weights stay exactly zero.
            self.weight.mul_(self.mask)

    def forward(self, input):
        # See CustomizedLinearFunction for the masked forward/backward.
        return CustomizedLinearFunction.apply(input, self.weight, self.bias, self.mask)

    def extra_repr(self):
        # Extra information shown when printing the module.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )
if __name__ in ('__main__', 'check grad'):
    # Running this file as a script numerically verifies the hand-written
    # backward of CustomizedLinearFunction. The 'check grad' run name is
    # still accepted for callers that execute the module under that name.
    from torch.autograd import gradcheck

    # gradcheck compares analytical gradients against finite differences and
    # returns True when they agree. Inputs that require grad must be double.
    customlinear = CustomizedLinearFunction.apply

    # Plain linear map: no bias, no mask.
    plain_inputs = (
        torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True),
        None,
        None,
    )
    print(gradcheck(customlinear, plain_inputs, eps=1e-6, atol=1e-4))

    # With a bias and a random 0/1 mask. The mask is a constant here, so it
    # does not require grad.
    masked_inputs = (
        torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, dtype=torch.double, requires_grad=True),
        (torch.rand(30, 20) > 0.5).double(),
    )
    print(gradcheck(customlinear, masked_inputs, eps=1e-6, atol=1e-4))