ELECTRA-LARGE is only released as a TensorFlow checkpoint, but I already have working PyTorch code that I want to use, so the question is how to load a TensorFlow-saved model from PyTorch.
Approach: use the transformers library to convert the TensorFlow checkpoint into a PyTorch model, and then load the converted model.
My environment: transformers==2.10.0, tensorflow==2.1.0, torch==1.5.1
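Before running the conversion, a quick sanity check of the installed versions can help (a minimal sketch; the HDF5 version reported by h5py is relevant to the error described further below):
import h5py
import tensorflow as tf
import torch
import transformers

print("transformers:", transformers.__version__)  # expect 2.10.0
print("tensorflow:", tf.__version__)              # expect 2.1.0
print("torch:", torch.__version__)                # expect 1.5.1
# HDF5 library version that h5py is linked against at runtime
print("h5py:", h5py.__version__, "/ HDF5:", h5py.version.hdf5_version)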
Official conversion script (convert_electra_original_tf_checkpoint_to_pytorch.py from transformers):
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ELECTRA checkpoint."""

import argparse
import logging

import torch

from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
    # Initialise PyTorch model
    config = ElectraConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    if discriminator_or_generator == "discriminator":
        model = ElectraForPreTraining(config)
    elif discriminator_or_generator == "generator":
        model = ElectraForMaskedLM(config)
    else:
        raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'")

    # Load weights from tf checkpoint
    load_tf_weights_in_electra(
        model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator
    )

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--discriminator_or_generator",
        default="discriminator",
        type=str,
        required=True,
        help="Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
        "'generator'.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator
    )
Example command:
python convert_electra_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16/electra_large --config_file K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16/large_discriminator_config.json --pytorch_dump_path ./electra_base/pytorch_model.bin --discriminator_or_generator discriminator
Running it produced an error:
UserWarning: h5py is running against HDF5 1.10.5 when it was built against 1.10.4, this may cause problems
'{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)
Warning! ***HDF5 library version mismatched error***
The HDF5 header files used to compile this application do not match
the version used by the HDF5 library to which this application is linked.
Data corruption or segmentation faults may occur if the application continues.
This can happen when an application was compiled by one version of HDF5 but
linked with a different version of static or shared HDF5 library.
You should recompile the application or check your shared library related
settings such as 'LD_LIBRARY_PATH'.
You can, at your own risk, disable this warning by setting the environment
variable 'HDF5_DISABLE_VERSION_CHECK' to a value of '1'.
Setting it to 2 or higher will suppress the warning messages totally.
Headers are 1.10.4, library is 1.10.5
SUMMARY OF THE HDF5 CONFIGURATION
=================================
General Information:
-------------------
HDF5 Version: 1.10.5
Configured on: 2019-03-04
Configured by: Visual Studio 15 2017 Win64
Host system: Windows-10.0.17763
Uname information: Windows
Byte sex: little-endian
Installation point: C:/Program Files/HDF5
Compiling Options:
------------------
Build Mode:
Debugging Symbols:
Asserts:
Profiling:
Optimization Level:
Linking Options:
----------------
Libraries:
Statically Linked Executables: OFF
LDFLAGS: /machine:x64
H5_LDFLAGS:
AM_LDFLAGS:
Extra libraries:
Archiver:
Ranlib:
Languages:
----------
C: yes
C Compiler: C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/Hostx86/x64/cl.exe 19.16.27027.1
CPPFLAGS:
H5_CPPFLAGS:
AM_CPPFLAGS:
CFLAGS: /DWIN32 /D_WINDOWS /W3
H5_CFLAGS:
AM_CFLAGS:
Shared C Library: YES
Static C Library: YES
Fortran: OFF
Fortran Compiler:
Fortran Flags:
H5 Fortran Flags:
AM Fortran Flags:
Shared Fortran Library: YES
Static Fortran Library: YES
C++: ON
C++ Compiler: C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/Hostx86/x64/cl.exe 19.16.27027.1
C++ Flags: /DWIN32 /D_WINDOWS /W3 /GR /EHsc
H5 C++ Flags:
AM C++ Flags:
Shared C++ Library: YES
Static C++ Library: YES
JAVA: OFF
JAVA Compiler:
Features:
---------
Parallel HDF5: OFF
Parallel Filtered Dataset Writes:
Large Parallel I/O:
High-level library: ON
Threadsafety: OFF
Default API mapping: v110
With deprecated public symbols: ON
I/O filters (external): DEFLATE DECODE ENCODE
MPE:
Direct VFD:
dmalloc:
Packages w/ extra debug output:
API Tracing: OFF
Using memory checker: OFF
Memory allocation sanity checks: OFF
Function Stack Tracing: OFF
Strict File Format Checks: OFF
Optimization Instrumentation:
Bye...
Cause: an HDF5 version mismatch. The h5py that comes with tensorflow==2.1.0 installed via conda was built against hdf5==1.10.4 at most, while the HDF5 library actually linked in this environment is 1.10.5.
Fix: reinstall h5py with pip so that it matches the installed HDF5 library:
pip uninstall h5py
pip install h5py
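The error output above also mentions a workaround of its own: setting HDF5_DISABLE_VERSION_CHECK. If reinstalling h5py is not an option, a sketch of that route (at your own risk, as the HDF5 warning itself says) is to set the variable before h5py/tensorflow are imported:
import os

# Must be set before h5py / tensorflow are imported in this process.
# "1" disables the version check; "2" or higher also suppresses the warning entirely.
os.environ["HDF5_DISABLE_VERSION_CHECK"] = "1"

import tensorflow as tf  # imported only after the environment variable is set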
After reinstalling, the script runs through without the HDF5 error and the checkpoint is converted successfully (pytorch_model.bin is written to ./electra_base/).
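With the PyTorch weights written, the model can now be loaded from existing PyTorch code. A minimal loading sketch, reusing the config file and dump path from the command above (transformers 2.10.0 API):
import torch
from transformers import ElectraConfig, ElectraForPreTraining

# Same config file and output path as used in the conversion command above.
config = ElectraConfig.from_json_file(
    "K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16/large_discriminator_config.json"
)
model = ElectraForPreTraining(config)

# Load the converted weights on CPU and switch to inference mode.
state_dict = torch.load("./electra_base/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
Alternatively, if the config file is copied next to the weights as ./electra_base/config.json, ElectraForPreTraining.from_pretrained("./electra_base") should also work.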