-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.lua
More file actions
94 lines (74 loc) · 2.1 KB
/
train.lua
File metadata and controls
94 lines (74 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
require 'nn'
require 'torch'
require 'optim'
require 'residual'
require 'data'
require 'image'
-- Hyper-parameters; learning rate and momentum borrowed from the paper.
local opt = {
  batchSize = 1,
  learningRate = 0.0001,
  numEpoch = 500,
  momentum = 0.9,
  numClasses = 8,
}
-- NOTE(review): BCECriterion expects targets in [0,1], but makeLabel fills
-- class indices 0..7 -- a multi-class criterion (e.g. CrossEntropyCriterion
-- with 1-based labels) looks intended; confirm against makeModel's output.
local criterion = nn.BCECriterion()
local model = nn.Sequential()
-- Reusable input buffer: batch of 3-channel 224x224 images.
local imageInput = torch.Tensor(opt.batchSize, 3, 224, 224)
model:add(makeModel())
local totalBatchSize = getNumDataSize()
local label = torch.Tensor(opt.batchSize)
local dataSetCount = 1
-- getParameters() must run exactly once, AFTER all modules are added: it
-- re-flattens parameter storage, so views obtained earlier (the original
-- called it before model:add, yielding an empty/stale vector that was then
-- handed to optim.adam) do not track the real model. Both name pairs used
-- later in this file are aliased to the single valid flattened view.
local params, gradParams = model:getParameters()
local modelParams, gradModelParams = params, gradParams
-- 0-based class index for each dataset file-name prefix letter.
local classForLetter = {
  A = 0, B = 1, D = 2, L = 3,
  N = 4, O = 5, S = 6, Y = 7,
}

--- Fill the shared `label` tensor with the class encoded by the first
-- letter of an image file name.
-- @tparam string firstLetter single letter prefix ('A','B','D','L','N','O','S','Y')
-- @raise on an unknown prefix (the original if-chain silently kept the
--        previous sample's label, training on wrong targets).
local function makeLabel(firstLetter)
  local class = classForLetter[firstLetter]
  if class == nil then
    error("makeLabel: unknown file-name prefix '" .. tostring(firstLetter) .. "'")
  end
  label:fill(class)
end
-- Optimizer state/config table shared across all optim.adam calls (adam also
-- stores its running moment estimates in here between steps).
-- NOTE(review): optim.adam is configured by beta1/beta2, not `momentum`; the
-- momentum field suggests sgd was once used -- confirm it is intentional.
local optimState = {
  learningRate = opt.learningRate,
  momentum = opt.momentum, -- as used in the paper
}
--- optim-style closure evaluated once per training sample: loads the next
-- image, runs forward/backward, and returns (loss, gradient).
-- @tparam Tensor x current flattened parameter vector supplied by optim
-- @treturn number criterion loss for this sample
-- @treturn Tensor flattened parameter gradients
local oneEpoch = function(x)
  -- Honor the optim feval contract: the model must evaluate at the vector
  -- the optimizer passes in (the original ignored x entirely).
  if x ~= params then
    params:copy(x)
  end
  gradParams:zero()
  local imageName = getImage(dataSetCount)
  dataSetCount = dataSetCount + 1
  local img = image.load(('images/'..imageName), 3, 'float')
  -- image.load returns whatever resolution is on disk, and Tensor:copy()
  -- errors on a size mismatch -- normalize to the 3x224x224 input buffer.
  img = image.scale(img, 224, 224)
  makeLabel(imageName:sub(1,1))
  imageInput:copy(img)
  local output = model:forward(imageInput)
  local imgError = criterion:forward(output, label)
  local criterionError = criterion:backward(output, label)
  model:backward(imageInput, criterionError)
  return imgError, gradParams
end
-- Main training loop: one adam step per sample, checkpoint every 50 epochs.
for epoch = 1, opt.numEpoch do
  print('epoch count: ' .. epoch)
  for batchSizeIndex = 1, totalBatchSize do
    -- adam updates the flattened parameter vector in place.
    optim.adam(oneEpoch, modelParams, optimState)
    if batchSizeIndex % 200 == 0 then
      print('batch size count '.. batchSizeIndex)
    end
  end
  if epoch % 50 == 0 then
    torch.save('TrainedModels/'..epoch, model:clearState()) -- for memory
    -- The original re-ran getParameters() every epoch; each re-flatten
    -- allocates fresh storage and detaches the tensors the optimizer had
    -- been updating. Keep the refresh only here, after clearState(), where
    -- the original evidently considered it necessary.
    modelParams, gradModelParams = model:getParameters()
    -- NOTE(review): the oneEpoch closure still captures its original
    -- gradParams view, which this refresh does not update -- confirm whether
    -- clearState() leaves that view valid, and refresh it too if not.
  end
  dataSetCount = 1
end