-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgetTrainedNetwork.m
More file actions
44 lines (30 loc) · 1.08 KB
/
getTrainedNetwork.m
File metadata and controls
44 lines (30 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
%
% This code belongs to:
% Ahmet Emre Unal
% S001974
% emre.unal@ozu.edu.tr
%
%% getTrainedNetwork: The trainer for a single ODR network object
function [trainedNetwork] = getTrainedNetwork(network, digitToLearn, EPSILON, NUM_PASSES, samples)
% GETTRAINEDNETWORK Train one network object on a single target digit.
%
%   trainedNetwork = getTrainedNetwork(network, digitToLearn, EPSILON, ...
%                                      NUM_PASSES, samples)
%
%   Runs NUM_PASSES training passes. Every pass presents two examples to
%   the network through its learn method:
%     1. a randomly picked sample of digitToLearn, with target CORRECT (1);
%     2. a randomly picked sample of some other digit (chosen by
%        getOtherRandDigit), with target WRONG (0).
%
%   Inputs:
%     network      - object providing learn(input, target, epsilon).
%     digitToLearn - the digit this network should learn. It is used as
%                    samples(digitToLearn + 1), so digits are presumably
%                    0-based while samples is 1-based (TODO confirm).
%     EPSILON      - forwarded unchanged as the third argument of learn
%                    (presumably the learning rate; confirm in the
%                    network class).
%     NUM_PASSES   - number of training passes to run.
%     samples      - per-digit sample container; element (digit + 1) is
%                    handed to getSample together with a sample number.
%
%   Output:
%     trainedNetwork - the network after training.
%
%   Progress is printed each time another tenth of the passes has been
%   completed. The tenth is computed from the pass counter directly, so the
%   report stays correct when NUM_PASSES is not a multiple of 10 (the
%   final pass always reports 10/10).

    CORRECT = 1;
    WRONG = 0;

    % Upper bound of the random sample number handed to getSample.
    % NOTE(review): presumably the number of stored samples per digit;
    % verify against the data set before changing it.
    NUM_SAMPLES = 13;

    fprintf('Digit %d training starting, please wait.\n', digitToLearn);

    iteration = 0;  % last tenth of the run that has been reported
    trainedNetwork = network;

    % NOTE(review): the result of learn(...) is not captured below, so the
    % training only sticks if the network class has handle (reference)
    % semantics. Confirm in the class definition.
    for pass = 1:NUM_PASSES
        % Positive example: a sample of the digit being learned.
        sampleNum = randi([1, NUM_SAMPLES], 1, 1);
        [X, Y] = getSample(samples(digitToLearn + 1), sampleNum);
        trainedNetwork.learn([X, Y], CORRECT, EPSILON);

        % Negative example: a sample of any other digit.
        otherDigit = getOtherRandDigit(digitToLearn);
        sampleNum = randi([1, NUM_SAMPLES], 1, 1);
        [X, Y] = getSample(samples(otherDigit + 1), sampleNum);
        trainedNetwork.learn([X, Y], WRONG, EPSILON);

        % Report whenever the number of completed tenths increases. Unlike
        % rem(pass, NUM_PASSES/10), this does not rely on NUM_PASSES/10
        % being a whole number.
        tenthsDone = floor(10 * pass / NUM_PASSES);
        if tenthsDone > iteration
            iteration = tenthsDone;
            fprintf('Digit %d training %d/10 complete.\n', digitToLearn, iteration);
        end
    end

    fprintf('Digit %d training complete.\n', digitToLearn);
end