可以使用convert_pytorch_checkpoint_to_tf.py将pytorch版本的 bert模型转换为TF版本的bert模型,不过需要注意的是需要将程序进行一定的修改:
原始代码:
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir
)
更改为: