-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_train.py
More file actions
26 lines (21 loc) · 858 Bytes
/
main_train.py
File metadata and controls
26 lines (21 loc) · 858 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Train one of the BiLSTM / Graph LSTM models.

Entry-point script: configures an ``Args`` object from ``pre_train``,
builds the selected model and its train/validation/test data loaders,
then runs ``train_model``.  Edit the constants below to change the run.
"""
import os

# GPU selection has to happen before anything initialises the CUDA runtime,
# which reads CUDA_VISIBLE_DEVICES only once.  ``pre_train`` is star-imported
# below and may touch CUDA at import time, so the variable is set first.
# ``setdefault`` keeps the previous default ("0,1") while letting the caller
# pick other GPUs from the shell, e.g. ``CUDA_VISIBLE_DEVICES=2 python main_train.py``.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

import torch  # noqa: E402,F401  (kept from the original script)
from pre_train import *  # noqa: E402,F401,F403  (provides Args, precheck, getmodel, getdataloader, train_model)

# Model candidates:
#   "BiLSTM":        a 4-layer bidirectional LSTM model
#   "GraphLSTM_dgl": the proposed Graph LSTM model implemented with the DGL library
#   "GraphLSTM_pyg": the proposed Graph LSTM model implemented with the PyG library
MODEL_SELECT = "GraphLSTM_pyg"
# Graph type options: 'grid_v1', 'grid_v2', 'pivotMds_grid'
GRAPH_TYPE = "grid_v1"
# Identifier distinguishing repeated trials of the same model / graph type.
TRIAL_ID = "demo1"


def main():
    """Configure the run options, build the model and data loaders, and train."""
    opt = Args()
    opt.set_graphtype(GRAPH_TYPE)
    # Run name format: "<model>-<graph type>-<trial id>".  Hyphens separate the
    # parts because model names themselves contain underscores.  Built from the
    # constants above so the name cannot drift from the actual configuration.
    opt.executename = "-".join((MODEL_SELECT, GRAPH_TYPE, TRIAL_ID))
    opt.model_select = MODEL_SELECT
    precheck(opt)
    print(vars(opt))

    # Start to train the model
    model = getmodel(opt)
    dataloader, valid_dataloader, test_dataloader = getdataloader(opt)
    train_model(model, dataloader, valid_dataloader, test_dataloader, opt)


# Guard so training does not re-run when this module is re-imported, e.g. by
# DataLoader worker processes started with the "spawn" method (the default
# on Windows and macOS).
if __name__ == "__main__":
    main()