1、保存pb格式
参考:https://gist.github.com/liprais/78ef34ac0779fb6b604eae044d789f90
注意:
1) os.environ['TF_KERAS'] = '1' 必须在导入 bert4keras、keras 之前执行
2) 保存(save)模型时,需使用 TensorFlow 2.X 及以上版本
3) 加载模型进行 predict 时,需使用 TensorFlow 1.X 版本
** ckpt 转 pb
import numpy as np
import os
from collections import Counter
os.environ['TF_KERAS'] = '1'
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding
from bert4keras.snippets import uniout
from keras.models import Model
# import pandas as pd
# BERT configuration: locations of the pretrained files in the
# L-6_H-384_A-12 model directory. The '*' runs look like redacted path
# segments in this copy; substitute the real directory before running.
config_path = (
    r'C:\*****_L-6_H-384_A-12'
    r'\bert_config.json'
)
# NOTE(review): unlike config_path, this path has no ':\' after the drive
# letter 'C' -- possibly a redaction artifact; verify the real path.
checkpoint_path = (
    r'C*******L-6_H-384_A-12'
    r'\bert_model.ckpt'
)
dict_path = r'C:******-6_H-384_A-12\