-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
31 lines (27 loc) · 1.08 KB
/
models.py
File metadata and controls
31 lines (27 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
class Swish(nn.Module):
    """Element-wise Swish activation (SiLU with beta = 1): ``x * sigmoid(x)``."""

    def forward(self, x):
        # Sigmoid "gate" in (0, 1) that scales each input element.
        gate = torch.sigmoid(x)
        return gate * x
class ResNet4D(nn.Module):
    """Residual fully connected network over four coordinates (x, y, z, t).

    Each coordinate is standardized as ``(c - input_mean) / input_std``. The
    four standardized coordinates are then concatenated into one feature
    vector and passed through an input layer, ``num_hidden_layers`` residual
    hidden layers and a linear output layer. Swish is the activation
    throughout.

    The default constants (0.5, 0.2887 ~= 1/sqrt(12)) equal the mean and
    standard deviation of a uniform distribution on [0, 1].
    NOTE(review): this presumes coordinates lie in [0, 1]; confirm against
    the data.

    Args:
        num_inputs: Width of the concatenated input. ``forward`` always
            concatenates exactly four coordinates, so values other than 4
            make ``forward`` fail on a shape mismatch.
        num_outputs: Number of output features.
        num_hidden_layers: Number of residual hidden layers (0 is allowed).
        num_neurons: Width of the input, hidden and pre-output layers.
        input_mean: Value subtracted from every coordinate.
        input_std: Value every centered coordinate is divided by.
    """

    def __init__(
        self,
        num_inputs=4,
        num_outputs=4,
        num_hidden_layers=4,
        num_neurons=50,
        input_mean=0.5,
        input_std=0.2887,
    ):
        super().__init__()
        # Kept as plain Python floats rather than registered buffers so the
        # module's state_dict is unchanged and existing checkpoints still load.
        self.input_mean = input_mean
        self.input_std = input_std
        self.input_layer = nn.Linear(num_inputs, num_neurons)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(num_neurons, num_neurons) for _ in range(num_hidden_layers)]
        )
        self.output_layer = nn.Linear(num_neurons, num_outputs)
        self.activation = Swish()

    @staticmethod
    def _as_column(coord):
        """Lift a 1-D ``(N,)`` tensor to ``(N, 1)``; leave other shapes as-is."""
        return coord.unsqueeze(-1) if coord.dim() == 1 else coord

    def forward(self, x, y, z, t):
        """Evaluate the network at a batch of points.

        Args:
            x, y, z, t: Coordinate tensors of shape ``(N, 1)``, or ``(N,)``,
                which is treated as ``(N, 1)``.

        Returns:
            Tensor of shape ``(N, num_outputs)``.
        """
        # Standardize each coordinate separately (same arithmetic order as
        # before) and make sure it is a column before concatenating on dim 1.
        columns = [
            self._as_column((coord - self.input_mean) / self.input_std)
            for coord in (x, y, z, t)
        ]
        inputs = torch.cat(columns, dim=1)

        out = self.activation(self.input_layer(inputs))
        for layer in self.hidden_layers:
            # Residual connection: add the layer's input back to its output.
            out = self.activation(layer(out)) + out
        return self.output_layer(out)