-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
41 lines (34 loc) · 992 Bytes
/
utils.py
File metadata and controls
41 lines (34 loc) · 992 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.utils.data as data
import torch.nn.functional as F
import numpy as np
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
def overall_acc(pred, label):
    """Return the overall accuracy of ``pred`` against ``label``.

    The predicted index is taken with ``torch.argmax`` over the last
    dimension of ``pred`` and compared elementwise with ``label``; the
    result is the mean of the matches cast to float.

    Args:
        pred: Tensor whose last dimension is reduced by ``argmax``
            (presumably per-class scores -- confirm against callers).
        label: Tensor of target indices, compared with ``==`` against the
            argmax result, so normal broadcasting rules apply.
            NOTE(review): a label with an extra trailing dimension would
            broadcast silently instead of failing -- confirm callers pass
            a shape matching ``pred.shape[:-1]``.

    Returns:
        A 0-dim float tensor holding the fraction of matching entries.
    """
    predicted = torch.argmax(pred, dim=-1)
    hits = (predicted == label).float()
    return torch.mean(hits)
class AvgMeter(object):
    """Track the latest value, a weighted running average and a windowed mean.

    Attributes:
        num: Number of most recent recorded values averaged by ``show()``.
        val: The value passed to the most recent ``update()`` call.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of ``n`` over all updates.
        avg: ``sum / count`` (0 while ``count`` is 0).
        losses: List of every value passed to ``update()`` since the last
            reset, in call order.
    """

    def __init__(self, num=40):
        """Create an empty meter whose ``show()`` window is ``num`` values."""
        self.num = num
        # Share the zeroing logic with reset() so the two cannot drift apart.
        self.reset()

    def reset(self):
        """Clear all statistics and the recorded history."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.losses = []

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        Args:
            val: The new value (a tensor or a plain number).
            n: Weight of ``val`` in the running average, e.g. a batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) would
        # otherwise raise ZeroDivisionError; keep the initial average of 0.
        self.avg = self.sum / self.count if self.count else 0
        self.losses.append(val)

    def show(self):
        """Return the mean of the last ``num`` recorded values as a tensor.

        Returns:
            A tensor with the mean over the most recent ``num`` entries of
            ``losses``. When there is nothing to average (no updates yet, or
            a non-positive ``num``) a zero tensor is returned instead of
            raising.
        """
        start = max(len(self.losses) - self.num, 0)
        recent = self.losses[start:]
        if not recent:
            return torch.tensor(0.0)
        # torch.stack only accepts tensors; update() also accepts plain
        # numbers, so promote those to float tensors first. Existing tensors
        # are passed through untouched.
        window = [
            item if torch.is_tensor(item)
            else torch.as_tensor(item, dtype=torch.get_default_dtype())
            for item in recent
        ]
        return torch.mean(torch.stack(window))