How to load a TensorFlow ELECTRA model in PyTorch

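This post describes how to use the transformers library to convert a TensorFlow ELECTRA model into a PyTorch model, including a fix for an HDF5 version mismatch encountered along the way.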


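ELECTRA-LARGE is only available here as a TensorFlow checkpoint, but the existing code I want to reuse is written in PyTorch. So the question is how to load a model saved by TensorFlow from PyTorch:

Approach: use transformers to convert the TensorFlow checkpoint into a PyTorch checkpoint, then load the converted file.

My environment: transformers==2.10.0, tensorflow==2.1.0, pytorch==1.5.1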

Official conversion script:

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ELECTRA checkpoint."""


import argparse
import logging

import torch

from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra


logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
    # Initialise PyTorch model
    config = ElectraConfig.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))

    if discriminator_or_generator == "discriminator":
        model = ElectraForPreTraining(config)
    elif discriminator_or_generator == "generator":
        model = ElectraForMaskedLM(config)
    else:
        raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'")

    # Load weights from tf checkpoint
    load_tf_weights_in_electra(
        model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator
    )

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--discriminator_or_generator",
        default="discriminator",
        type=str,
        required=True,
        help="Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
        "'generator'.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator
    )

Example command:

python convert_electra_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16/electra_large --config_file K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16/large_discriminator_config.json --pytorch_dump_path ./electra_base/pytorch_model.bin --discriminator_or_generator discriminator
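Because the command is long and easy to mistype, it can also be assembled programmatically. The following is only a sketch: the helper name `build_convert_command` and its parameters are hypothetical, and it assumes the conversion script sits in the current working directory. The flag names are the ones defined by the script's argument parser.

```python
import sys
from pathlib import Path


def build_convert_command(model_dir, checkpoint_prefix, config_name, dump_path,
                          which="discriminator",
                          script="convert_electra_original_tf_checkpoint_to_pytorch.py"):
    """Return the conversion command as an argument list for subprocess.run.

    Hypothetical helper: it only builds the list and does not run anything.
    """
    if which not in ("discriminator", "generator"):
        raise ValueError("which must be 'discriminator' or 'generator'")
    model_dir = Path(model_dir)
    return [
        sys.executable, script,
        # The TF checkpoint is addressed by its prefix, not by a single file.
        "--tf_checkpoint_path", (model_dir / checkpoint_prefix).as_posix(),
        "--config_file", (model_dir / config_name).as_posix(),
        "--pytorch_dump_path", str(dump_path),
        "--discriminator_or_generator", which,
    ]


if __name__ == "__main__":
    cmd = build_convert_command(
        "K:/DATASET/ChinesePreTrain/chinese_electra_large_L-24_H-1024_A-16",
        "electra_large",
        "large_discriminator_config.json",
        "./electra_base/pytorch_model.bin",
    )
    print(" ".join(cmd))
    # To actually run the conversion:
    #   import subprocess; subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting problems with paths that contain spaces.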

Running it produced the following error:

UserWarning: h5py is running against HDF5 1.10.5 when it was built against 1.10.4, this may cause problems
  '{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)
Warning! ***HDF5 library version mismatched error***
The HDF5 header files used to compile this application do not match
the version used by the HDF5 library to which this application is linked.
Data corruption or segmentation faults may occur if the application continues.
This can happen when an application was compiled by one version of HDF5 but
linked with a different version of static or shared HDF5 library.
You should recompile the application or check your shared library related
settings such as 'LD_LIBRARY_PATH'.
You can, at your own risk, disable this warning by setting the environment
variable 'HDF5_DISABLE_VERSION_CHECK' to a value of '1'.
Setting it to 2 or higher will suppress the warning messages totally.
Headers are 1.10.4, library is 1.10.5
        SUMMARY OF THE HDF5 CONFIGURATION
        =================================

General Information:
-------------------
                   HDF5 Version: 1.10.5
                  Configured on: 2019-03-04
                  Configured by: Visual Studio 15 2017 Win64
                    Host system: Windows-10.0.17763
              Uname information: Windows
                       Byte sex: little-endian
             Installation point: C:/Program Files/HDF5

Compiling Options:
------------------
                     Build Mode:
              Debugging Symbols:
                        Asserts:
                      Profiling:
             Optimization Level:

Linking Options:
----------------
                      Libraries:
  Statically Linked Executables: OFF
                        LDFLAGS: /machine:x64
                     H5_LDFLAGS:
                     AM_LDFLAGS:
                Extra libraries:
                       Archiver:
                         Ranlib:

Languages:
----------
                              C: yes
                     C Compiler: C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/Hostx86/x64/cl.exe 19.16.27027.1
                       CPPFLAGS:
                    H5_CPPFLAGS:
                    AM_CPPFLAGS:
                         CFLAGS:  /DWIN32 /D_WINDOWS /W3
                      H5_CFLAGS:
                      AM_CFLAGS:
               Shared C Library: YES
               Static C Library: YES

                        Fortran: OFF
               Fortran Compiler:
                  Fortran Flags:
               H5 Fortran Flags:
               AM Fortran Flags:
         Shared Fortran Library: YES
         Static Fortran Library: YES

                            C++: ON
                   C++ Compiler: C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/Hostx86/x64/cl.exe 19.16.27027.1
                      C++ Flags: /DWIN32 /D_WINDOWS /W3 /GR /EHsc
                   H5 C++ Flags:
                   AM C++ Flags:
             Shared C++ Library: YES
             Static C++ Library: YES

                            JAVA: OFF
                   JAVA Compiler:

Features:
---------
                   Parallel HDF5: OFF
Parallel Filtered Dataset Writes:
              Large Parallel I/O:
              High-level library: ON
                    Threadsafety: OFF
             Default API mapping: v110
  With deprecated public symbols: ON
          I/O filters (external):  DEFLATE DECODE ENCODE
                             MPE:
                      Direct VFD:
                         dmalloc:
  Packages w/ extra debug output:
                     API Tracing: OFF
            Using memory checker: OFF
 Memory allocation sanity checks: OFF
          Function Stack Tracing: OFF
       Strict File Format Checks: OFF
    Optimization Instrumentation:
Bye...

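Cause: an HDF5 version mismatch. h5py was built against HDF5 1.10.4 headers, but the HDF5 library loaded at runtime is 1.10.5. (At the time, the conda install of tensorflow==2.1.0 could pull in hdf5==1.10.4 at most, whereas 1.10.5 was needed here.)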
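The two versions involved are reported in a single line of the dump ("Headers are ..., library is ..."). As a small illustration, that line can be pulled out of a saved log with a regular expression. The function name `parse_hdf5_mismatch` is hypothetical, and the pattern assumes the message format shown in the log above.

```python
import re

# Matches the summary line printed by the HDF5 version check, e.g.
# "Headers are 1.10.4, library is 1.10.5".
_MISMATCH = re.compile(r"Headers are (\d+(?:\.\d+)*), library is (\d+(?:\.\d+)*)")


def parse_hdf5_mismatch(log_text):
    """Return (header_version, library_version), or None if the line is absent."""
    match = _MISMATCH.search(log_text)
    return match.groups() if match else None


if __name__ == "__main__":
    print(parse_hdf5_mismatch("Headers are 1.10.4, library is 1.10.5"))
    # prints ('1.10.4', '1.10.5')
```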

Solution: reinstall h5py with pip.

pip uninstall h5py
pip install h5py

After reinstalling h5py, the script ran without the HDF5 error and the conversion completed successfully.
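Once the conversion has succeeded, the resulting file can be read back in PyTorch. The conversion script saves `model.state_dict()`, so the weights have to be loaded into a model built from the same config file. The sketch below assumes the discriminator was exported. The function name `load_converted_discriminator` is hypothetical, and the imports are placed inside the function only so that defining the helper does not itself require torch.

```python
def load_converted_discriminator(config_file, weights_path):
    """Rebuild the ELECTRA discriminator and load the converted weights.

    Hypothetical helper; mirrors what the conversion script does in reverse.
    """
    import torch
    from transformers import ElectraConfig, ElectraForPreTraining

    # Same config JSON that was passed to the conversion script.
    config = ElectraConfig.from_json_file(config_file)
    model = ElectraForPreTraining(config)

    # The script stored a plain state dict, so load it onto the CPU first.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Switch off dropout for inference.
    model.eval()
    return model
```

For the command used in this post, the arguments would be the same `large_discriminator_config.json` and the `./electra_base/pytorch_model.bin` written by `--pytorch_dump_path`. For a generator export, `ElectraForMaskedLM` would take the place of `ElectraForPreTraining`.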

 
