forked from HenryHuYu/DiffPhysDrone
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
40 lines (33 loc) · 1.18 KB
/
model.py
File metadata and controls
40 lines (33 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch import nn
def g_decay(x, alpha):
    """Scale the gradient flowing back through ``x`` by ``alpha``.

    The forward value is ``x * alpha + x * (1 - alpha)``, i.e. ``x`` itself
    (up to floating-point rounding). Only the first term stays attached to
    the autograd graph, so the gradient of the output w.r.t. ``x`` is
    ``alpha``.
    """
    tracked = x * alpha  # keeps the graph: contributes gradient alpha
    frozen = x.detach() * (1 - alpha)  # restores the value, blocks gradient
    return tracked + frozen
class Model(nn.Module):
    """Recurrent policy network: image + observation vector -> action.

    A small conv stem encodes a single-channel image into a ``dim_hidden``
    feature vector. A linear projection of the observation vector ``v`` is
    added to it. The sum goes through a ``GRUCell``, and a linear head maps
    the hidden state to ``dim_action`` outputs.

    The stem ends in ``Linear(128 * 2 * 4, ...)``, so the input image must
    reduce to a 128x2x4 map after the three convolutions. The layer comments
    assume a 1x12x16 input.

    Args:
        dim_obs: Size of the last dimension of the observation vector ``v``.
        dim_action: Number of outputs of the action head.
        dim_hidden: Width of the image features, the observation projection
            and the GRU hidden state (defaults to 192).
    """

    def __init__(self, dim_obs=9, dim_action=4, dim_hidden=192) -> None:
        super().__init__()
        # Submodule names, order and Sequential indices are kept stable so
        # state_dict keys and parameter order do not change.
        self.stem = nn.Sequential(
            nn.Conv2d(1, 32, 2, 2, bias=False),  # 1, 12, 16 -> 32, 6, 8
            nn.LeakyReLU(0.05),
            nn.Conv2d(32, 64, 3, bias=False),  # 32, 6, 8 -> 64, 4, 6
            nn.LeakyReLU(0.05),
            nn.Conv2d(64, 128, 3, bias=False),  # 64, 4, 6 -> 128, 2, 4
            nn.LeakyReLU(0.05),
            nn.Flatten(),
            nn.Linear(128 * 2 * 4, dim_hidden, bias=False),
        )
        self.v_proj = nn.Linear(dim_obs, dim_hidden)
        self.gru = nn.GRUCell(dim_hidden, dim_hidden)
        self.fc = nn.Linear(dim_hidden, dim_action, bias=False)
        self.act = nn.LeakyReLU(0.05)

        # Shrink the default init of the observation projection (x0.5) and
        # of the action head (x0.01), so initial actions start near zero.
        # torch.no_grad() is the supported way to modify parameters in place
        # (instead of going through `.data`).
        with torch.no_grad():
            self.v_proj.weight.mul_(0.5)
            self.fc.weight.mul_(0.01)

    def reset(self):
        """No-op.

        The module stores no per-episode state; the recurrent state is passed
        in and returned explicitly via ``hx`` in ``forward``.
        """
        pass

    def forward(self, x: torch.Tensor, v: torch.Tensor, hx=None):
        """Run one recurrent step.

        Args:
            x: Image batch, ``(B, 1, H, W)``, with H, W such that the stem
                yields 128x2x4 (e.g. 12x16).
            v: Observation batch whose last dimension is ``dim_obs``.
            hx: Previous GRU hidden state ``(B, dim_hidden)``, or ``None``.
                ``GRUCell`` treats ``None`` as zeros.

        Returns:
            ``(act, None, hx)``:
            - ``act``: output of the action head.
            - ``None``: always ``None``. NOTE(review): presumably a
              placeholder for an auxiliary output; confirm with callers.
            - ``hx``: the updated hidden state.
        """
        img_feat = self.stem(x)
        fused = self.act(img_feat + self.v_proj(v))
        hx = self.gru(fused, hx)
        act = self.fc(self.act(hx))
        return act, None, hx
if __name__ == '__main__':
    # Smoke test: build the model with default sizes and run one step on
    # dummy inputs. Constructing alone would not catch a mismatch between
    # the stem's output size and its final Linear layer.
    _model = Model()
    with torch.no_grad():
        _img = torch.zeros(1, 1, 12, 16)  # size from the stem's layer comments
        _obs = torch.zeros(1, 9)  # matches the default dim_obs
        _act, _, _hx = _model(_img, _obs)
    print('act', tuple(_act.shape), 'hx', tuple(_hx.shape))