使用torch.nn.DataParallel多卡训练模型之后,加载模型前也需要打开多卡读取模型。
我最近使用多卡训练了一个模型。保存的方式是state_dict的方式。
然后在加载模型的时候就一直出错。
raise RuntimeError(‘Error(s) in loading state_dict for {}:\n\t{}’.format(
RuntimeError: Error(s) in loading state_dict for BertMultiClassification:
Missing key(s) in state_dict
意思是这个状态字典中缺少键。
输出模型需要的字典中key 的名称
bert_tokenizer = BertTokenizer.from_pretrained(pretrained_model_path)
bert_config = BertConfig.from_pretrained(pretrained_model_path)
bert_model = BertModel.from_pretrained(pretrained_model_path)
bert_model