-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
72 lines (55 loc) · 2.33 KB
/
utils.py
File metadata and controls
72 lines (55 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import json
import logging
import numpy as np
import torch
def getLogger(name,
format_str='%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s',
date_format='%Y-%m-%d %H:%M:%S',
log_file=False):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# file or console
handler = logging.StreamHandler() if not log_file else logging.FileHandler(name)
formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def numParams(net):
count = sum([int(np.prod(param.shape)) for param in net.parameters()])
return count
def countFrames(n_samples, win_size, hop_size):
n_overlap = win_size // hop_size
fn = lambda x: x // hop_size + n_overlap - 1
n_frames = torch.stack(list(map(fn, n_samples)), dim=0)
return n_frames
def lossMask(shape, n_frames, device):
loss_mask = torch.zeros(shape, dtype=torch.float32, device=device)
for i, seq_len in enumerate(n_frames):
loss_mask[i,:,0:seq_len,:] = 1.0
return loss_mask
def lossLog(log_filename, ckpt, logging_period):
if ckpt.ckpt_info['cur_epoch'] == 0 and ckpt.ckpt_info['cur_iter'] + 1 == logging_period:
with open(log_filename, 'w') as f:
f.write('epoch, iter, tr_loss, cv_loss\n')
f.write('{}, {}, {:.4f}, {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1,
ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['tr_loss'], ckpt.ckpt_info['cv_loss']))
else:
with open(log_filename, 'a') as f:
f.write('{}, {}, {:.4f}, {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1,
ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['tr_loss'], ckpt.ckpt_info['cv_loss']))
def wavNormalize(*sigs):
# sigs is a list of signals to be normalized
scale = max([np.max(np.abs(sig)) for sig in sigs]) + np.finfo(np.float32).eps
sigs_norm = [sig / scale for sig in sigs]
return sigs_norm
def dump_json(filename, obj):
with open(filename, 'w') as f:
json.dump(obj, f, indent=4, sort_keys=True)
return
def load_json(filename):
if not os.path.isfile(filename):
raise FileNotFoundError('Could not find json file: {}'.format(filename))
with open(filename, 'r') as f:
obj = json.load(f)
return obj