Description
When I tested the SCOUT+ model using the provided pre-trained weights "SCOUT+_BDD-A.pt", loading the checkpoint failed with the following error:
Traceback (most recent call last):
  File "test.py", line 339, in <module>
    test.test_saved(**vars(args))
  File "test.py", line 109, in test_saved
    if not self.load_saved(config_dir):
  File "test.py", line 100, in load_saved
    model.load_state_dict(torch.load(best_model_weights))
  File "/cver/llsun/miniconda3/envs/scout/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SCOUT_map_v1:
    Missing key(s) in state_dict: "multihead_attn.1.in_proj_weight", "multihead_attn.1.in_proj_bias", "multihead_attn.1.out_proj.weight", "multihead_attn.1.out_proj.bias", "norm.1.weight", "norm.1.bias".
Could you please help me figure out how to resolve it?
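In case it helps with debugging, the mismatch can be inspected outside of test.py by comparing the parameter names the model expects with the names stored in the checkpoint. The following is only a minimal sketch: the helper `diff_state_dict_keys` and the example key lists are hypothetical (in particular, the `.0` entries are my assumption and do not come from the error message). In practice the two lists would come from `model.state_dict().keys()` and `torch.load(path, map_location="cpu").keys()`.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key lists for illustration only; real ones would be read
# from the instantiated model and from the loaded checkpoint file.
model_keys = [
    "multihead_attn.0.in_proj_weight",
    "multihead_attn.1.in_proj_weight",
    "norm.0.weight",
    "norm.1.weight",
]
checkpoint_keys = [
    "multihead_attn.0.in_proj_weight",
    "norm.0.weight",
]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If only the index-1 entries show up as missing, my guess is that the model is being built with more attention blocks than the checkpoint was trained with, which would point to a config mismatch rather than a corrupted file. I know `load_state_dict` also accepts `strict=False`, but that would leave the missing layers at their initial values, so I am not sure it is a valid fix here.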