-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathbase_average.py
More file actions
43 lines (32 loc) · 1.73 KB
/
base_average.py
File metadata and controls
43 lines (32 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
Blend the prediction scores of all base models with a weighted average.
"""
import pandas as pd
import numpy as np
from data import get_label_dict
# Load the per-model prediction score tables. index_col=-1 selects the last
# CSV column as the index (presumably 'fname', matching the layout this script
# writes to pred_scores/base_average.csv below -- confirm against the
# scripts that produce these files).
dense_mel = pd.read_csv("pred_scores/dense_mel.csv", index_col=-1)
dense_mfcc = pd.read_csv("pred_scores/dense_mfcc.csv", index_col=-1)
resnet_mel = pd.read_csv("pred_scores/resnet_mel.csv", index_col=-1)
resnet_mfcc = pd.read_csv("pred_scores/resnet_mfcc.csv", index_col=-1)
senet_mel = pd.read_csv("pred_scores/senet_mel.csv", index_col=-1)
senet_mfcc = pd.read_csv("pred_scores/senet_mfcc.csv", index_col=-1)
vgg2d_mel = pd.read_csv("pred_scores/vgg2d_mel.csv", index_col=-1)
vgg2d_mfcc = pd.read_csv("pred_scores/vgg2d_mfcc.csv", index_col=-1)
vgg1d_mel = pd.read_csv("pred_scores/vgg1d_mel.csv", index_col=-1)
vgg1d_raw = pd.read_csv("pred_scores/vgg1d_raw.csv", index_col=-1)

# The output file names come from one table's index. The blend below works on
# raw arrays, i.e. row by row in file order, so every table is assumed to list
# the same files in the same order -- TODO confirm.
fname = vgg1d_raw.index

# Weights were determined by using public LB feedback.
# The outer weights sum to 1.0 (4 * 0.15 + 0.4), as do the inner per-feature
# weights of each architecture.
# NOTE: DataFrame.as_matrix() was removed in pandas 1.0; .values returns the
# same ndarray and is available on both old and new pandas versions.
result = (
    (dense_mel.values * 0.6 + dense_mfcc.values * 0.4) * 0.15
    + (resnet_mel.values * 0.6 + resnet_mfcc.values * 0.4) * 0.15
    + (senet_mel.values * 0.6 + senet_mfcc.values * 0.4) * 0.15
    + (vgg2d_mel.values * 0.6 + vgg2d_mfcc.values * 0.4) * 0.15
    + (vgg1d_raw.values * 0.75 + vgg1d_mel.values * 0.25) * 0.4
)

# Map the highest-scoring column index of each row to its label name.
label_to_int, int_to_label = get_label_dict()
final_labels = [int_to_label[x] for x in np.argmax(result, 1)]
print(len(final_labels))

# Column names for the blended score table; assumed to be in the same order
# as the score columns of the input CSVs -- verify against get_label_dict().
labels = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 'unknown', 'silence']

# Save the blended scores (with 'fname' as the last column) for later reuse.
pred_scores = pd.DataFrame(result, columns=labels)
pred_scores['fname'] = fname
pred_scores.to_csv("pred_scores/base_average.csv", index=False)

# Save the submission file: one predicted label per file name.
pd.DataFrame({'fname': fname,
              'label': final_labels}).to_csv("sub/base_average.csv", index=False)