-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathtrain.py
More file actions
94 lines (77 loc) · 2.02 KB
/
train.py
File metadata and controls
94 lines (77 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from resnet import ResModel
from senet import SeModel
from vgg2d import vgg2d
from vgg1d import vgg1d, vggmel
from densenet import densenet121
from trainer import train_model
from data import preprocess_mel, preprocess_mfcc, preprocess_wav
# (tag, preprocessing function) pairs shared by the 2-D model runs below;
# the tag is appended to a run's 'CODER' value by train_and_predict.
list_2d = [
    ('mel', preprocess_mel),
    ('mfcc', preprocess_mfcc),
]

# Value stored under 'bagging_num' in every configuration handed to
# train_model.
BAGGING_NUM = 4
def train_and_predict(cfg_dict, preprocess_list):
    """Run ``train_model`` once per (tag, preprocessing function) pair.

    For each entry of ``preprocess_list`` a shallow copy of ``cfg_dict`` is
    made, so the caller's dictionary is never modified. The copy receives
    the preprocessing function under ``'preprocess_fun'``, gets ``'_<tag>'``
    appended to its ``'CODER'`` value, and has ``'bagging_num'`` set to the
    module-level ``BAGGING_NUM``. It is then unpacked as keyword arguments
    into ``train_model``.

    Args:
        cfg_dict: Base keyword arguments for ``train_model``; must contain
            a ``'CODER'`` entry.
        preprocess_list: Iterable of ``(tag, preprocess_fun)`` pairs.
    """
    for tag, fun in preprocess_list:
        run_cfg = cfg_dict.copy()
        run_cfg.update(
            preprocess_fun=fun,
            CODER=run_cfg['CODER'] + '_%s' % tag,
            bagging_num=BAGGING_NUM,
        )
        print("training ", run_cfg['CODER'])
        train_model(**run_cfg)
# ResModel run: trained once for each entry of list_2d. Every key below is
# forwarded to train_model as a keyword argument by train_and_predict.
res_config = dict(
    model_class=ResModel,
    is_1d=False,
    reshape_size=None,
    BATCH_SIZE=32,
    epochs=100,
    CODER='resnet',
)
print("train resnet.........")
train_and_predict(res_config, list_2d)
# SeModel run: trained once for each entry of list_2d. Every key below is
# forwarded to train_model as a keyword argument by train_and_predict.
se_config = dict(
    model_class=SeModel,
    is_1d=False,
    reshape_size=128,
    BATCH_SIZE=16,
    epochs=100,
    CODER='senet',
)
print("train senet..........")
train_and_predict(se_config, list_2d)
# densenet121 run: trained once for each entry of list_2d. Every key below
# is forwarded to train_model as a keyword argument by train_and_predict.
dense_config = dict(
    model_class=densenet121,
    is_1d=False,
    reshape_size=128,
    BATCH_SIZE=16,
    epochs=100,
    CODER='dense',
)
print("train densenet.........")
train_and_predict(dense_config, list_2d)
# vgg2d run: trained once for each entry of list_2d. Every key below is
# forwarded to train_model as a keyword argument by train_and_predict.
vgg2d_config = dict(
    model_class=vgg2d,
    is_1d=False,
    reshape_size=128,
    BATCH_SIZE=32,
    epochs=100,
    CODER='vgg2d',
)
print("train vgg2d...........")
train_and_predict(vgg2d_config, list_2d)
# vgg1d run: trained a single time, with preprocess_wav under the 'raw'
# tag. Every key below is forwarded to train_model as a keyword argument
# by train_and_predict.
vgg1d_config = dict(
    model_class=vgg1d,
    is_1d=True,
    reshape_size=None,
    BATCH_SIZE=32,
    epochs=100,
    CODER='vgg1d',
)
print("train vgg1d on raw features..........")
train_and_predict(vgg1d_config, [('raw', preprocess_wav)])
# vggmel run: trained a single time, with preprocess_mel under the 'mel'
# tag. The base 'CODER' is the same 'vgg1d' string used by vgg1d_config;
# train_and_predict appends '_mel' here, so the final value differs from
# the raw-feature run's.
vggmel_config = dict(
    model_class=vggmel,
    is_1d=True,
    reshape_size=None,
    BATCH_SIZE=64,
    epochs=100,
    CODER='vgg1d',
)
print("train vgg1d on mel features..........")
train_and_predict(vggmel_config, [('mel', preprocess_mel)])