diff --git a/demo.py b/demo.py index 80483ac..fd68b6b 100755 --- a/demo.py +++ b/demo.py @@ -102,7 +102,7 @@ def process_image(img_file, bbox_file, openpose_file, input_res=224): # Load pretrained model model = hmr(config.SMPL_MEAN_PARAMS).to(device) - checkpoint = torch.load(args.checkpoint) + checkpoint = torch.load(args.checkpoint, map_location=device) model.load_state_dict(checkpoint['model'], strict=False) # Load SMPL model