-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
38 lines (30 loc) · 1 KB
/
train.py
File metadata and controls
38 lines (30 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Smoke-test training script: run a few AdamW steps of a tiny MLP on random data.

Verifies that forward, backward, and optimizer steps work on the selected
device (CUDA if available, else CPU) and, on CUDA, reports VRAM usage per step.
Batch size and step count are configurable via the BS and STEPS environment
variables (defaults: 2048 and 10).
"""
import torch
import os


def train(batch_size=None, steps=None, seed=0):
    """Run ``steps`` AdamW updates of a tiny MLP on synthetic random data.

    Args:
        batch_size: Samples per step. Defaults to the ``BS`` env var (2048).
        steps: Number of optimizer steps. Defaults to the ``STEPS`` env var (10).
        seed: RNG seed, so repeated runs are deterministic on a given device.

    Returns:
        The final training loss as a Python float, or ``None`` if ``steps`` is 0.
    """
    # Env-var defaults keep the original CLI behavior backward-compatible.
    if batch_size is None:
        batch_size = int(os.getenv("BS", "2048"))
    if steps is None:
        steps = int(os.getenv("STEPS", "10"))

    torch.manual_seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))

    # Tiny MLP: 1024 features -> 2048 hidden -> 10 class logits.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 2048),
        torch.nn.ReLU(),
        torch.nn.Linear(2048, 10),
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    loss = None
    for step in range(1, steps + 1):
        # Fresh synthetic batch each step: random inputs, random integer labels.
        x = torch.randn(batch_size, 1024, device=device)
        y = torch.randint(0, 10, (batch_size,), device=device)
        opt.zero_grad(set_to_none=True)  # set_to_none frees grad storage
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        opt.step()
        if device.type == "cuda":
            # Synchronize so the memory reading reflects this step's work.
            torch.cuda.synchronize()
            free, total = torch.cuda.mem_get_info()
            print(f"step={step:02d} loss={loss.item():.4f} vram={((total-free)/1e9):.2f}GB used")
    print("OK: backward + step ran on", device)
    return loss.item() if loss is not None else None


if __name__ == "__main__":
    train()