-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
378 lines (291 loc) · 13.5 KB
/
utils.py
File metadata and controls
378 lines (291 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import torch, os, sys
from torch import Tensor
from torch_geometric.datasets import MD17
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from torch_geometric.nn.conv import MessagePassing
from torch.optim import Optimizer, SGD
import matplotlib.pyplot as plt
from typing import Callable
import numpy as np
from scipy.spatial.transform import Rotation
# module-level side effect: make float64 the default floating-point dtype for tensors created after this import
torch.set_default_dtype(torch.double)
def sanity_check(model: MessagePassing, rho:float=1-1e-2, num_items:int=1024, batch_size:int=32, num_epochs:int=10) -> None:
    """puts the model through a very brief training run to check for elementary bugs, making a matplotlib plot of loss.

    parameters
    ----------
    model : MessagePassing
        the model to be checked. it is called as `model(data)` and must return an `(E_hat, F_hat)` pair.
    rho : float, optional
        default : 1-1e-2.
        loss = (1-rho)*E_loss + rho*F_loss.
    num_items : int, optional
        default : 1024
        the number of data items in each epoch.
    batch_size : int, optional
        default : 32
        self-explanatory.
    num_epochs : int, optional
        default : 10
        self-explanatory.

    returns
    -------
    None; shows a matplotlib plot of the per-batch loss.
    """
    # make dataloader (get_mini_dataloader accepts only molecule, num_items and batch_size)
    dataloader = get_mini_dataloader(molecule='benzene', num_items=num_items, batch_size=batch_size)
    # SGD for maximal simplicity
    optimizer = SGD(model.parameters(), lr=0.001)
    # MSE for maximal simplicity
    loss_fn = torch.nn.MSELoss()
    # track losses, one entry per batch
    losses = []
    # training loop
    for _ in range(num_epochs):
        for data in dataloader:
            # clear gradients
            optimizer.zero_grad()
            # target values
            E = data.energy
            F = data.force
            # predictions from the model
            E_hat, F_hat = model(data)
            # drop dim 1 of the energy prediction in place so it lines up with the target
            E_hat.squeeze_(dim=1)
            # squared error for energy loss
            E_loss = loss_fn(E_hat, E)
            # a version of squared error for force loss (target first, prediction second)
            F_loss = F_loss_fn(F, F_hat, loss_fn)
            # canonical loss
            loss = (1-rho)*E_loss + rho*F_loss
            # calculate gradients
            loss.backward()
            # update
            optimizer.step()
            # track losses
            losses.append(loss.item())
    # make plot of losses to check for convergence
    plt.plot(range(len(losses)), losses)
    plt.show()
def F_loss_fn(F: Tensor, F_hat: Tensor, loss_fn: Callable) -> Tensor:
    """calculates the atomwise Euclidean distances between the predicted and actual force vectors and returns a loss via loss_fn.

    parameters
    ----------
    F : Tensor
        target atomwise force vector.
        the norm is taken over dim 1, so the expected layout is [num_atoms, 3].
    F_hat : Tensor
        predicted atomwise force vector.
        must have the same size as `F`.
    loss_fn : Callable
        takes in the [num_atoms] distance tensor and a zero tensor of the same shape, and returns 1-item loss tensor.

    returns:
    --------
    1-item Tensor containing loss.

    raises
    ------
    AssertionError
        if `F` and `F_hat` differ in size.
    """
    # avoid bugs when the parameters do not make sense
    assert F.size() == F_hat.size(), f'expected F and F_hat to be the same size. got F.size()={F.size()} and F_hat.size()={F_hat.size()}'
    # Euclidean distance between the target and predicted force vectors.
    # vector_norm is used instead of sqrt(sum(square(.))) because the latter
    # back-propagates NaN whenever a predicted row equals its target exactly.
    F_error = torch.linalg.vector_norm(F - F_hat, dim=1)
    # the loss measures how far the distances are from zero
    return loss_fn(F_error, torch.zeros_like(F_error))
def bessel_rbf(x: Tensor, n: int, r_cut: float) -> Tensor:
    r"""takes in a tensor representing distance and expands it into a vector (tensor) in a Bessel radial basis.

    formula for a Bessel radial basis function:

    ..math::
        \sin(\frac{(n\pi)}{r_{\mathrm{cut}}} \Vert \vec{r}_{ij} \Vert) / \Vert \vec{r}_{ij} \Vert.

    notation consistent with page 4 of https://arxiv.org/pdf/2102.03150, which follows the lead of page 5 of https://arxiv.org/pdf/2003.03123.
    this method creates `n` Bessel radial basis functions and returns a vector (tensor) whose i-th element is the value of `x` written in the i-th basis element.

    parameters
    ----------
    x : Tensor
        distance(s) to be expanded in the Bessel radial basis.
        it is matrix-multiplied with a [1, n] frequency row, so its last dimension must have size 1
        (e.g. a 1-element tensor or a [num_distances, 1] column).
    n : int
        cardinality of Bessel basis.
    r_cut : float
        cutoff distance, representing the maximum distance between two connected nodes.

    returns
    -------
    double-precision Tensor representing input distance in a Bessel radial basis, as specified in function call.
    a distance of exactly 0 yields NaN, because the formula divides by the distance.
    """
    # work in double precision throughout
    x = x.double()
    # frequency row [1, 2, ..., n], created on the same device as x so the matmul below is valid off-CPU
    ns = torch.arange(1, n + 1, dtype=torch.double, device=x.device).view(1, -1)
    # output as defined in Bessel radial basis function equation
    out = torch.sin(torch.matmul(x, ns) * torch.pi / r_cut) / x
    # return
    return out
def cosine_cutoff(x: Tensor, r_cut: float) -> Tensor:
    r"""takes in a tensor representing distance and returns its coefficient under a cosine cutoff.

    formula for a cosine cutoff function:

    ..math::
        0.5 (\cos(\frac{\pi x}{r_{\mathrm{cut}}}) + 1)

    for :math:`x \leq r_{\mathrm{cut}}`, and :math:`0` for :math:`x > r_{\mathrm{cut}}`.

    it is desirable to set the value of the basis for all values greater than `r_cut` to 0 without introducing a discontinuity at `r_cut`.
    cosine cutoff maps 0 to 1, leaving distances near 0 minimally affected, and maps `r_cut` to 0, giving distances slightly smaller than `r_cut` values near 0. it maps values greater than `r_cut` uniformly to 0.
    the function and its first derivative are continuous at `r_cut` (the second derivative is not), so values and gradients both go to 0 smoothly at the cutoff.

    parameters
    ----------
    x : Tensor
        distance(s) whose coefficient under a cosine cutoff is to be calculated; handled elementwise.
    r_cut : float
        maximum distance between connected nodes.

    returns
    -------
    double-precision Tensor, same shape as `x`, representing coefficient of input distance under cosine cutoff with `r_cut` as specified in function call.
    """
    # f(0) = 1 and f(r_cut) = 0 smoothly
    inside_cutoff = 0.5 * (torch.cos(torch.pi * x / r_cut) + 1).double()
    # truncate everything beyond r_cut; torch.where builds a new tensor instead of writing in place
    return torch.where(x > r_cut, torch.zeros_like(inside_cutoff), inside_cutoff)
def get_random_roto_reflection_matrix() -> Tensor:
    """draws a random 3x3 orthogonal matrix, prints it, and returns it.

    a random rotation is drawn with scipy; with probability one half it is negated,
    which turns it into an improper rotation (determinant -1).

    returns
    -------
    [3, 3] double-precision Tensor holding the rotation or its negation.
    """
    # random proper rotation from scipy's Rotation module, converted to a double tensor
    proper = torch.tensor(Rotation.random().as_matrix()).double()
    # coin flip: negate the rotation half of the time
    negate = np.random.rand() > 0.5
    matrix = -proper if negate else proper
    # show the matrix one row per line
    print('Random roto-reflection:')
    for a, b, c in matrix:
        print(f'\t{a.item(): 6.3f} {b.item(): 6.3f} {c.item(): 6.3f}')
    return matrix
def get_random_translation_vector(max_translation=1.0) -> Tensor:
    """draws a random 3-component translation vector, prints it, and returns it.

    parameters
    ----------
    max_translation : float, optional
        default : 1.0
        each coordinate is drawn uniformly from [-max_translation, max_translation).

    returns
    -------
    [3] double-precision Tensor holding the translation.
    """
    # honour the requested bound instead of a fixed [-1, 1] range
    translation = torch.tensor(np.random.uniform(-max_translation, max_translation, size=3)).double()
    # show the vector one coordinate per line
    print('Random translation:')
    for coordinate in translation:
        print(f'\t{coordinate.item(): 6.3}')
    return translation
def get_random_roto_reflection_translation() -> [Tensor, Tensor]:
    """draws a random orthogonal matrix and a random translation vector.

    returns
    -------
    two-item list: the matrix from `get_random_roto_reflection_matrix` followed by
    the vector from `get_random_translation_vector` (called with its default bound).
    """
    # the matrix is drawn (and printed) before the translation
    matrix = get_random_roto_reflection_matrix()
    shift = get_random_translation_vector()
    return [matrix, shift]
def E3_transform_molecule(molecule: Data, roto_reflection_translation: [Tensor, Tensor]) -> Data:
    """returns a copy of `molecule` whose positions have been rotated/reflected and then translated.

    parameters
    ----------
    molecule : Data
        molecule to transform; it is cloned, so the argument itself is left untouched.
    roto_reflection_translation : [Tensor, Tensor]
        item 0 is the matrix applied to every position, item 1 is the vector added afterwards.

    returns
    -------
    clone of `molecule` with `pos` replaced by the transformed, double-precision positions.
    """
    matrix = roto_reflection_translation[0]
    shift = roto_reflection_translation[1]
    transformed = molecule.clone()
    # positions are stored one per row, so transpose to columns, apply the matrix, and transpose back
    columns = torch.transpose(transformed.pos.double(), 0, 1)
    rotated = torch.transpose(torch.matmul(matrix, columns), 0, 1)
    # the translation is broadcast over all rows
    transformed.pos = rotated + shift
    return transformed
def E3_transform_force(force_tensor: Tensor, roto_reflection_translation: [Tensor, Tensor]) -> Tensor:
    """applies the rotation/reflection part of an E(3) transformation to atomwise force vectors.

    only the matrix is used; the translation is ignored.

    parameters
    ----------
    force_tensor : Tensor
        force vectors stored one per row.
    roto_reflection_translation : [Tensor, Tensor]
        item 0 is the matrix applied to every force vector; item 1 (the translation) is not used.

    returns
    -------
    double-precision Tensor of transformed force vectors, one per row.
    """
    # cast to double, as E3_transform_molecule does for positions, so a float32 input
    # can be multiplied with a double-precision matrix
    columns = force_tensor.double().transpose(0, 1)
    # apply the matrix to the column vectors and go back to one-vector-per-row layout
    return torch.matmul(roto_reflection_translation[0], columns).transpose(0, 1)
def plot_molecules(molecules: [Data], colors: [str], labels: [str], title='Benzene Molecule with Bonds') -> None:
    """draws one or more molecules (atoms as markers, edges as line segments) in a single 3D matplotlib figure.

    parameters
    ----------
    molecules : [Data]
        molecules to draw; each needs `pos` and `edge_index`.
    colors : [str]
        one matplotlib color per molecule, used for both its atoms and its edges.
    labels : [str]
        one legend label per molecule. passing `['']` suppresses the legend.
    title : str, optional
        default : 'Benzene Molecule with Bonds'
        title of the plot.

    returns
    -------
    None; shows the figure.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    for mol, colour, tag in zip(molecules, colors, labels):
        # detach so molecules that went through the model can be plotted
        mol = mol.detach()
        # atoms
        xs, ys, zs = zip(*mol.pos)
        axes.scatter(xs, ys, zs, c=colour, marker='o', label=tag)
        # edges: one line segment per (source, target) column of edge_index
        node_pairs = [[src.item(), dst.item()] for src, dst in zip(mol.edge_index[0], mol.edge_index[1])]
        for pair in node_pairs:
            # [2, 3] tensor holding the two endpoint positions
            ends = mol.pos[pair]
            seg_x = [ends[0][0].item(), ends[1][0].item()]
            seg_y = [ends[0][1].item(), ends[1][1].item()]
            seg_z = [ends[0][2].item(), ends[1][2].item()]
            axes.plot(seg_x, seg_y, seg_z, c=colour)
    # a single empty label means "no legend"
    if labels != ['']:
        axes.legend()
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    axes.set_title(title)
    plt.show()
def plot_molecules_with_forces(molecules, forces, colors, labels, title='Benzene Molecule with Bonds and Atomwise Forces'):
    """draws one or more molecules together with an arrow per atom for its force vector, in a single 3D matplotlib figure.

    parameters
    ----------
    molecules
        molecules to draw; each needs `pos` and `edge_index`.
    forces
        one collection of atomwise force vectors per molecule, one 3-component vector per atom.
    colors
        one indexable pair per molecule: item 0 colors atoms and edges, item 1 colors the force arrows.
    labels
        one legend label per molecule. passing `['']` suppresses the legend.
    title : str, optional
        default : 'Benzene Molecule with Bonds and Atomwise Forces'
        title of the plot.

    returns
    -------
    None; shows the figure.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    for mol, mol_forces, colour, tag in zip(molecules, forces, colors, labels):
        # detach so molecules that went through the model can be plotted
        mol = mol.detach()
        # atom coordinates and force components, split per axis
        xs, ys, zs = zip(*mol.pos)
        us, vs, ws = zip(*mol_forces)
        axes.scatter(xs, ys, zs, c=colour[0], marker='o', label=tag)
        # edges: one line segment per (source, target) column of edge_index
        node_pairs = [[src.item(), dst.item()] for src, dst in zip(mol.edge_index[0], mol.edge_index[1])]
        for pair in node_pairs:
            # [2, 3] tensor holding the two endpoint positions
            ends = mol.pos[pair]
            seg_x = [ends[0][0].item(), ends[1][0].item()]
            seg_y = [ends[0][1].item(), ends[1][1].item()]
            seg_z = [ends[0][2].item(), ends[1][2].item()]
            axes.plot(seg_x, seg_y, seg_z, c=colour[0])
        # force arrows, unnormalized, with the arrow tip placed on the atom
        axes.quiver(xs, ys, zs, us, vs, ws, normalize=False, color=colour[1], arrow_length_ratio=0.5, pivot='tip')
    # a single empty label means "no legend"
    if labels != ['']:
        axes.legend()
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    axes.set_title(title)
    plt.show()
def get_mini_dataloader(molecule: str, num_items: int, batch_size: int) -> DataLoader:
    """returns a DataLoader over a random subset of an MD17 dataset; especially useful for getting small DataLoader objects to use in experimentation.

    parameters
    ----------
    molecule : str
        name of the MD17 molecule dataset to fetch (e.g. benzene, uracil, aspirin).
    num_items : int
        number of items in the random subset.
    batch_size : int
        batch size of the returned DataLoader.

    returns
    -------
    DataLoader over `num_items` randomly chosen items of the dataset.
    """
    # load the full dataset, stored under 'data/'
    full_dataset = MD17(root='data/', name=f'{molecule}', pre_transform=None, force_reload=False)
    # split off a random subset of num_items items; the remainder is discarded
    leftover = len(full_dataset) - num_items
    subset, _ = random_split(full_dataset, [num_items, leftover])
    # wrap the subset in a DataLoader
    return DataLoader(subset, batch_size=batch_size)
def get_molecule(type: str) -> Data:
    """returns the first item in the MD17 dataset of the molecule specified in function call.

    parameters
    ----------
    type : str
        name of the MD17 molecule dataset whose first item will be returned.

    returns
    -------
    first data item in the dataset of the specified molecule.
    """
    # load the dataset, stored under 'data/', and take its first item
    dataset = MD17(root='data/', name=f'{type}', pre_transform=None, force_reload=False)
    return dataset[0]
def make_v0(pos: Tensor, edge_index: Tensor, emb_dim: int) -> Tensor:
    """returns [num_nodes x 3 x emb_dim] tensor, where each node's tensor is sum of all its outgoing edge vectors copied `emb_dim` times.

    note that `pos` is passed as a parameter into function call so gradients of this operation can be included in computational graph for backprop.

    parameters
    ----------
    pos : Tensor
        node positions, one row per node.
    edge_index : Tensor
        two rows of node indices; row 0 holds the source node of every edge and row 1 its target node.
    emb_dim : int
        number of copies of each summed vector along the last dimension.

    returns
    -------
    [num_nodes x 3 x emb_dim] tensor, where each node's tensor is sum of all its outgoing edge vectors copied `emb_dim` times.
    nodes without outgoing edges get zeros. the copies are an expanded view, not separate memory.
    """
    # unpack edge_index
    idx1, idx2 = edge_index
    # vectorized calculation of edge vectors (target position minus source position)
    edge_vectors = pos[idx2] - pos[idx1]
    # determine the shape of the accumulator from pos
    num_nodes, num_coords = pos.size(0), pos.size(1)
    # make container for outgoing vector sums; new_zeros keeps the dtype AND device of pos
    outgoing_edge_vector_sums = pos.new_zeros((num_nodes, num_coords))
    # add every edge vector into the row of its source node
    outgoing_edge_vector_sums = outgoing_edge_vector_sums.scatter_add(0, idx1.unsqueeze(1).expand(-1, num_coords), edge_vectors) # EQUIVARIANT OPERATION: adding equivariant vectors
    # copy over emb_dim
    v0 = outgoing_edge_vector_sums.unsqueeze(2).expand(-1, -1, emb_dim)
    # return result
    return v0