-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_select.py
More file actions
25 lines (21 loc) · 880 Bytes
/
model_select.py
File metadata and controls
25 lines (21 loc) · 880 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from models import *
class BaseModel:
    """Factory that builds a network from the ``models`` package by name.

    The class holds no state; use ``BaseModel.create(name)`` to obtain a
    freshly constructed model instance.
    """

    @classmethod
    def create(cls, message_type: str = 'resnet', **kwargs):
        """Instantiate the network registered under ``message_type``.

        Args:
            message_type: Registry key of the architecture to build. The
                registered keys are 'resnet', 'densenet' and 'resnet_basic'.
            **kwargs: Forwarded unchanged to the selected constructor. When no
                keyword arguments are given the constructor is called with no
                arguments.

        Returns:
            A new instance of the selected model class.

        Raises:
            ValueError: If ``message_type`` is not a registered key.
        """
        # Built inside the call (not at class level) so the names coming from
        # ``models`` are only resolved when a model is actually requested.
        factories = {
            'resnet': resnet.ResNet18,
            'densenet': densenetBC.Network,
            # NOTE(review): the key 'resnet_basic' builds ResNet101; confirm
            # this pairing is intended.
            'resnet_basic': resnet.ResNet101,
        }
        # Previously listed here but disabled: pyramidnet, resnext, pnasnet,
        # vgg (was constructed with the config string 'VGG19'), vgg_bearpaw,
        # shake. Re-register an entry above to enable it again.
        try:
            factory = factories[message_type]
        except KeyError:
            raise ValueError(
                'Bad message type {!r}; expected one of: {}'.format(
                    message_type, ', '.join(sorted(factories)))) from None
        return factory(**kwargs)