以BERT为代表的预训练模型是目前NLP领域最火热的方向,但是Google发布的 BERT 是Tensorflow格式的,这让使用pytorch格式 程序猿 们很为难。
为解决这个问题,本篇以BERT为例,介绍将Tensorflow格式的模型转换为Pytorch格式的模型。
1. 工具安装

使用工具为:Transformers(链接),该工具对常用的预训练模型进行封装,可以非常方便的使用 pytorch调用预训练模型。
使用如下命令安装:
pip install transformers
2. 模型转换
- 下载google的
BERT模型; - 使用如下命令进行转换:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
transformers bert \
$BERT_BASE_DIR/bert_model.ckpt \
$BERT_BASE_DIR/bert_config.json \
$BERT_BASE_DIR/pytorch_model.bin
BERT模型跨框架移植
本文介绍如何将Google的BERT模型从Tensorflow格式转换为Pytorch格式,通过使用Transformers工具,实现模型的无缝迁移,方便Pytorch用户进行NLP任务。
6439





