-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhw04.py
More file actions
34 lines (28 loc) · 855 Bytes
/
hw04.py
File metadata and controls
34 lines (28 loc) · 855 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
# Toy training set: every target is exactly twice its input (y = 2x).
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Trainable coefficients of the quadratic model p*x*x + q*x + b, each
# initialised to 1.0. `torch.tensor(..., requires_grad=True)` creates leaf
# tensors tracked by autograd in one step, replacing the legacy
# `torch.Tensor(...)` constructor followed by a separate
# `requires_grad = True` assignment.
p = torch.tensor([1.0], requires_grad=True)
q = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)
def forward(x):
    """Evaluate the quadratic model at ``x``.

    Reads the module-level coefficient tensors ``p``, ``q`` and ``b`` and
    returns ``p*x*x + q*x + b``. The result remains attached to the autograd
    graph, so calling ``backward()`` on anything derived from it populates
    the coefficients' ``.grad``.
    """
    quadratic_term = p * x * x
    linear_term = q * x
    return quadratic_term + linear_term + b
def loss(x, y):
    """Return the squared error of the model's prediction at ``x`` against ``y``.

    Evaluates ``forward(x)``, subtracts the target ``y`` and squares the
    residual. The returned tensor is what the training loop calls
    ``backward()`` on.
    """
    residual = forward(x) - y
    return residual ** 2
# SGD hyper-parameters (previously inline literals; the learning rate was
# repeated once per coefficient).
learning_rate = 0.01
num_epochs = 100

print('Predict(before training):', 4, forward(4).item())
for epoch in range(num_epochs):
    # Stochastic gradient descent: one parameter update per training sample.
    for x, y in zip(x_data, y_data):
        sample_loss = loss(x, y)
        sample_loss.backward()
        print('\tgrad', x, y, p.grad.item(), q.grad.item(), b.grad.item())
        # Update the leaf tensors in place without recording the step in the
        # autograd graph; this replaces the discouraged `.data` reassignment.
        with torch.no_grad():
            for param in (p, q, b):
                param.sub_(learning_rate * param.grad)
                # backward() accumulates into .grad, so clear it before the
                # next sample's backward pass.
                param.grad.zero_()
    # Reports the loss of the last sample processed this epoch (not an average).
    print('Progress:', epoch, sample_loss.item())
print('Predict(after training):', 4, forward(4).item())