From 05733b7cf20b5b739a520e8c5b092ef932306fcc Mon Sep 17 00:00:00 2001 From: Adam Yala Date: Fri, 17 Aug 2018 15:37:11 -0400 Subject: [PATCH] gpu-trained on cpu --- rationale_net/utils/model.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rationale_net/utils/model.py b/rationale_net/utils/model.py index 370792a..297fef4 100644 --- a/rationale_net/utils/model.py +++ b/rationale_net/utils/model.py @@ -19,11 +19,18 @@ def get_model(args, embeddings, train_data): print('\nLoading model from [%s]...' % args.snapshot) try: gen_path = learn.get_gen_path(args.snapshot) - if os.path.exists(gen_path): - gen = torch.load(gen_path) - model = torch.load(args.snapshot) + if args.cuda: + if os.path.exists(gen_path): + gen = torch.load(gen_path) + model = torch.load(args.snapshot) + else: + if os.path.exists(gen_path): + gen = torch.load(gen_path, map_location=lambda storage, loc: storage) + gen.args.cuda = "false" + model = torch.load(args.snapshot, map_location=lambda storage, loc: storage) + model.args.cuda = "false" except : - print("Sorry, This snapshot doesn't exist."); exit() + print("Sorry, This snapshot doesn't exist.") if args.num_gpus > 1: model = nn.DataParallel(model,