Skip to content

在试图使用训练好的模型进行评测时遇到了参数不对应的问题 #36

@charpng

Description

@charpng

请问在训练过程中保存下来的模型该如何加载呢？我在试图使用训练好的模型进行评测时遇到了参数不匹配（state_dict 键名不对应）的问题，报错如下：
File "/DBCNN-PyTorch-master/DBCNN.py", line 53, in __init__
scnn.load_state_dict(torch.load(scnn_root))
File "/home/guest1/anaconda3/envs/wpf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.1.weight", "module.features.1.bias", "module.features.1.running_mean", "module.features.1.running_var", "module.features.3.weight", "module.features.3.bias", "module.features.4.weight", "module.features.4.bias", "module.features.4.running_mean", "module.features.4.running_var", "module.features.6.weight", "module.features.6.bias", "module.features.7.weight", "module.features.7.bias", "module.features.7.running_mean", "module.features.7.running_var", "module.features.9.weight", "module.features.9.bias", "module.features.10.weight", "module.features.10.bias", "module.features.10.running_mean", "module.features.10.running_var", "module.features.12.weight", "module.features.12.bias", "module.features.13.weight", "module.features.13.bias", "module.features.13.running_mean", "module.features.13.running_var", "module.features.15.weight", "module.features.15.bias", "module.features.16.weight", "module.features.16.bias", "module.featu
Unexpected key(s) in state_dict: "module.features1.0.weight", "module.features1.0.bias", "module.features1.2.weight", "module.features1.2.bias", "module.features1.5.weight", "module.features1.5.bias", "module.features1.7.weight", "module.features1.7.bias", "module.features1.10.weight", "module.features1.10.bias", "module.features1.12.weight", "module.features1.12.bias", "module.features1.14.weight", "module.features1.14.bias", "module.features1.17.weight", "module.features1.17.bias", "module.features1.19.weight", "module.features1.19.bias", "module.features1.21.weight", "module.features1.21.bias", "module.features1.24.weight", "module.features1.24.bias", "module.features1.26.weight", "module.features1.26.bias", "module.features1.28.weight", "module.features1.28.bias", "module.features2.0.weight", "module.features2.0.bias", "module.features2.1.weight", "module.features2.1.bias", "module.features2.1.running_mean", "module.features2.1.running_var", "module.features2.1.num_batches_tracked", "module.features2.3.weight", "module.features2.3.bias", "module.features2.4.weight", "module.features2.4.bias", "module.features2.4.running_mean", "module.features2.4.running_var", "module.features2.4.num_batches_tracked", "module.features2.6.weight", "module.features2.6.bias", "module.features2.7.weight", "module.features2.7.bias", "module.features2.7.running_mean", "module.features2.7.running_var", "module.features2.7.num_batches_tracked", "module.features2.9.weight", "module.features2.9.bias", "module.features2.10.weight", "module.features2.10.bias", "module.features2.10.running_mean", "module.features2.10.running_var", "module.features2.10.num_batches_tracked", "module.features2.12.weight", "module.features2.12.bias", "module.features2.13.weight", "module.features2.13.bias", "module.features2.13.running_mean", "module.features2.13.running_var", "module.features2.13.num_batches_tracked", "module.features2.15.weight", "module.features2.15.bias", "module.features2.16.weight", 
"module.features2.16.bias", "module.features2.16.running_mean", "module.features2.16.running_var", "module.features2.16.num_batches_tracked", "module.features2.18.weight", "module.features2.18.bias", "module.features2.19.weight", "module.features2.19.bias", "module.features2.19.running_mean", "module.features2.19.running_var", "module.features2.19.num_batches_tracked", "module.features2.21.weight", "module.features2.21.bias", "module.features2.22.weight", "module.features2.22.bias", "module.features2.22.running_mean", "module.features2.22.running_var", "module.features2.22.num_batches_tracked", "module.features2.24.weight", "module.features2.24.bias", "module.features2.25.weight", "module.features2.25.bias", "module.features2.25.running_mean", "module.features2.25.running_var", "module.features2.25.num_batches_tracked", "module.fc.weight", "module.fc.bias".

以下是我用来加载模型的代码，请问问题出在哪里，该如何修正呢？

# Checkpoint written during training. Per the "Unexpected key(s)" list in the
# traceback it holds the state_dict of the whole DataParallel-wrapped DBCNN
# (keys "module.features1.*", "module.features2.*", "module.fc.*"), i.e. it is
# NOT a standalone SCNN checkpoint and must not be passed as `scnn_root`.
model_path = '/DBCNN-PyTorch-master/fc_models/readability_params_best.pkl'


def evaluate_model(model_path, image_dir, excel_file,
                   scnn_root='pretrained_scnn/scnn.pkl'):
    """Build a DBCNN and load a trained DBCNN checkpoint for evaluation.

    DBCNN's constructor itself calls
    ``scnn.load_state_dict(torch.load(scnn_root))`` on a DataParallel-wrapped
    SCNN, which expects "module.features.*" keys. Passing the trained DBCNN
    checkpoint there (as the original code did) therefore fails with
    missing/unexpected keys. The fix is two separate loads:

    1. ``scnn_root`` -> pre-trained SCNN weights, consumed by ``DBCNN.__init__``.
    2. ``model_path`` -> trained DBCNN checkpoint, loaded into the
       DataParallel-wrapped DBCNN so the "module." prefixes line up.

    Args:
        model_path: Path to the checkpoint saved during training (state_dict
            of ``torch.nn.DataParallel(DBCNN)``).
        image_dir: Not used in this snippet; kept for the caller's interface.
        excel_file: Not used in this snippet; kept for the caller's interface.
        scnn_root: Path to the pre-trained SCNN weights required by DBCNN's
            constructor. NOTE(review): the default is an assumption about the
            repository layout — confirm it points at your SCNN checkpoint.

    Returns:
        The DataParallel-wrapped DBCNN on CUDA, switched to eval mode.
    """
    # Load the model.
    options_ = {
        'base_lr': 5e-5,
        'batch_size': 4,
        'epochs': 1,
        'weight_decay': 5e-4,
        'dataset': 'tianma',
        'fc': False,
        'train_index': [],
        'test_index': [],
    }

    # scnn_root must be the SCNN checkpoint, not the trained DBCNN one.
    model = DBCNN(scnn_root=scnn_root, options=options_)
    # Wrap in DataParallel *before* loading: the saved keys carry the
    # "module." prefix that DataParallel adds.
    model = torch.nn.DataParallel(model).cuda()
    # torch.load takes the path positionally; it has no `scnn_root` argument.
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Switch to evaluation mode (affects BatchNorm layers).
    return model

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions