-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_feature.py
More file actions
63 lines (46 loc) · 1.8 KB
/
extract_feature.py
File metadata and controls
63 lines (46 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def run_feature(input_data, label, nodescomplete, iters=10000, ae_iters=100):
    """Pre-train an autoencoder, then train a full network seeded from it.

    An autoencoder is formed from every entry of ``nodescomplete`` except the
    last and trained on ``input_data``. A ``net_work`` is then built from the
    complete ``nodescomplete``; its first ``len(nodescomplete) - 2`` weight
    slots are overwritten with the trained encoder weights before the whole
    network is trained against ``label``.

    Args:
        input_data: Training samples; forwarded unchanged to
            ``auto_imp.train_encoder`` and ``auto_imp.train_network``.
        label: Training targets; forwarded unchanged to
            ``auto_imp.train_network``.
        nodescomplete: Sliceable sequence describing the full network
            (presumably node counts per layer, input first and output
            last -- confirm against ``auto_art.net_work``).
        iters: Final argument to ``auto_imp.train_network`` (presumably the
            number of training iterations -- confirm against ``auto_imp``).
        ae_iters: Third argument to ``auto_imp.train_encoder`` (presumably
            the number of autoencoder training iterations). Defaults to 100,
            the value that was previously hard-coded.

    Returns:
        The object returned by ``auto_imp.train_network`` for the full
        network.
    """
    # Project-local modules are imported lazily, as in the original layout.
    import auto_imp
    from auto_art import net_work

    separator = "=" * 40

    print(separator)
    print("Building autoencoder")
    # The autoencoder covers every layer except the final (output) entry.
    nodes = nodescomplete[:-1]
    ae = auto_imp.form_antoencoder(nodes)
    ae = auto_imp.train_encoder(ae, input_data, ae_iters)

    print(separator)
    print("Training neural network.....")
    aecomplete = net_work(nodescomplete)
    # Build a full network, initializing weights with the trained
    # autoencoder weights; the last weight slot keeps whatever value
    # net_work() gave it.
    for i in range(len(nodescomplete) - 2):
        aecomplete.weight[i] = ae.encoder[i].weight[0]

    aecomplete = auto_imp.train_network(aecomplete, input_data, label, iters)
    return aecomplete
if __name__ == "__main__":
    # Script entry point: load a data set, run the autoencoder-initialized
    # training pipeline, then print the per-layer results and the training
    # accuracy.
    #
    # Usage: python extract_feature.py [INPUT_DATA_PATH [LABEL_PATH]]
    # Both paths are optional; when omitted the original default locations
    # below are used.
    import sys

    import numpy as np

    default_input_path = "/home/ben/Documents/Pyspace/auto_encoder/input_data.txt"
    default_label_path = "/home/ben/Documents/Pyspace/auto_encoder/label_of_inputdata.txt"
    input_path = sys.argv[1] if len(sys.argv) > 1 else default_input_path
    label_path = sys.argv[2] if len(sys.argv) > 2 else default_label_path

    # ndmin=2 keeps the sample matrix two-dimensional even when the file
    # holds a single row or a single column, so shape[1] below is always
    # valid.
    input_data = np.loadtxt(input_path, delimiter=',', dtype=float, ndmin=2)
    label = np.loadtxt(label_path, dtype=float)
    # Column vector of labels; -1 lets NumPy infer the row count, which also
    # works when the label file contains a single value.
    label = label.reshape(-1, 1)

    iters = 1000
    # First entry is taken from the number of input columns; the remaining
    # entries (3, 2, 1) are fixed here.
    nodescomplete = [input_data.shape[1], 3, 2, 1]
    aecomplete = run_feature(input_data, label, nodescomplete, iters)

    separator = "=" * 40

    print(separator)
    print("Showing input data")
    print(aecomplete.act_res[0])
    print("Showing label data")
    print(label)

    print(separator)
    print("Showing the first hidden layer results")
    print(aecomplete.act_res[1])

    print(separator)
    print("Showing the second hidden layer results")
    print(aecomplete.act_res[2])

    print(separator)
    print("Showing output of trained by neural network")
    print(aecomplete.act_res[-1])

    # Threshold the final-layer values at 0.5 and report the percentage of
    # entries that equal the label.
    res = aecomplete.act_res[-1] >= 0.5
    accuracy = np.mean(res == label) * 100
    print("Training accuracy:%f" % accuracy)