-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot.py
More file actions
50 lines (43 loc) · 1.43 KB
/
Copy pathplot.py
File metadata and controls
50 lines (43 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python3
import re
import matplotlib.pyplot as plt
def parse_step(step_str):
    """Parse an abbreviated step count into a number, e.g. '18K' -> 18000.0.

    Recognised suffixes (case-insensitive): K (x1e3), M (x1e6), B (x1e9),
    T (x1e12). A token without one of these suffixes is parsed as a plain
    number.

    Args:
        step_str: The raw step token; surrounding whitespace is ignored.

    Returns:
        The step value as a float.

    Raises:
        ValueError: If the numeric part cannot be parsed as a float.
    """
    multipliers = {
        'K': 1000,
        'M': 1000000,
        'B': 1000000000,
        'T': 1000000000000,
    }
    step_str = step_str.strip()
    # Slicing (not indexing) keeps an empty string from raising IndexError;
    # it then falls through to float('') and raises ValueError as before.
    suffix = step_str[-1:].upper()
    if suffix in multipliers:
        return float(step_str[:-1]) * multipliers[suffix]
    return float(step_str)
def extract_loss_from_log(filepath):
    """Extract step and loss values from a training log file.

    A line contributes one data point when it contains ``step:<token>``
    followed by whitespace and, later on the same line, ``loss:<number>``.
    Lines that do not match are ignored. The step token is decoded with
    ``parse_step``. The loss may be signed and may use scientific notation
    (e.g. ``loss:1.5e-04``).

    NOTE(review): if the logger abbreviates steps with rounding (e.g. several
    nearby steps all printed as '18K'), multiple points will share the same
    x value -- confirm against the log format.

    Args:
        filepath: Path to the log file.

    Returns:
        A ``(steps, losses)`` tuple of equal-length lists of floats, in file
        order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a matched step token cannot be parsed by ``parse_step``.
    """
    # The loss group accepts a sign, a leading '.', and an exponent. The old
    # pattern r'\d+\.?\d*' would silently truncate '1.5e-04' to '1.5'.
    pattern = re.compile(
        r'step:(\S+)\s+.*loss:([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )
    steps = []
    losses = []
    # Explicit encoding so parsing does not depend on the platform locale;
    # undecodable bytes are replaced rather than aborting the whole parse.
    with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            match = pattern.search(line)
            if match is None:
                continue
            # Parse both values before appending so the two lists stay aligned
            # even if parse_step raises.
            step = parse_step(match.group(1))
            loss = float(match.group(2))
            steps.append(step)
            losses.append(loss)
    return steps, losses
def plot_loss(steps, losses, output_path='loss_plot.png'):
    """Plot loss against step as a line chart and save it to an image file.

    Args:
        steps: x-values (training steps), same length as ``losses``.
        losses: y-values (loss at each step).
        output_path: Where the image is written (saved at 150 dpi).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(steps, losses, 'b-', linewidth=0.8, alpha=0.7)
        ax.set_xlabel('Step')
        ax.set_ylabel('Loss')
        ax.set_title('Training Loss')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure so repeated calls don't accumulate open
        # pyplot figures (the original never closed it).
        plt.close(fig)
    print(f'Plot saved to {output_path}')
if __name__ == '__main__':
    # Local import: only the command-line entry point needs argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description='Extract step/loss values from a training log and plot the loss curve.'
    )
    # Both arguments default to the previously hard-coded values, so running
    # the script with no arguments behaves exactly as before.
    parser.add_argument(
        'log_file',
        nargs='?',
        default='ACT_log.txt',
        help='log file to parse (default: %(default)s)',
    )
    parser.add_argument(
        '-o',
        '--output',
        default='loss_plot.png',
        help='output image path (default: %(default)s)',
    )
    args = parser.parse_args()

    steps, losses = extract_loss_from_log(args.log_file)
    print(f'Extracted {len(losses)} data points')
    if losses:
        plot_loss(steps, losses, output_path=args.output)
    else:
        print('No step/loss entries found; nothing to plot.')