BERT 提取特征 (extract_features.py) 源码分析 代码简化

版权声明:本文为博主原创文章,转载请注明出处:https://blog.youkuaiyun.com/ling620/article/details/97789853

之前的文章介绍了如何使用Bert的extract_features.py去提取特征向量,本文对源码进一步的分析。
BERT之提取特征向量 及 bert-as-server的使用

代码位于: bert/extract_features.py

本文主要包含两部分内容:

  1. 对源码进行分析
  2. 对源码进行简化

源码分析

1. 输入参数

必选参数,如下:

  • input_file:数据存放路径
  • vocab_file:字典文件的地址
  • bert_config_file:配置文件
  • init_checkpoint:模型文件
  • output_file:输出文件
if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("init_checkpoint")
    flags.mark_flag_as_required("output_file")
    tf.app.run()

其他参数:

在文件最开始部分

  • layers:获取的层数索引, 默认值是 [-1, -2, -3, -4] 即表示倒数第一层、倒数第二层、倒数第三层和倒数第四层
  • max_seq_length:输入序列的最大长度,大于此值则截断,小于此值则填充0
  • batch_size:预测的batch大小
  • use_tpu:是否使用TPU
  • use_one_hot_embeddings:是否使用独热编码
flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")
    
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")

flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")

flags.DEFINE_bool(
    "use_one_hot_embeddings", False,
    "If True, tf.one_hot will be used for embedding lookups, otherwise "
    "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
    "since it is much faster.")

2. 主流程

主要有以下几个步骤:

  1. 读取配置文件,构建BertConfig
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  2. 获取tokenization的对象
    tokenization.py是对输入的句子处理,包含两个主要类:BasickTokenizer, FullTokenizer

    BasickTokenizer会对每个字做分割,会识别英文单词,对于数字会合并,例如:

    query: 'Jack,请回答1988, UNwant\u00E9d,running'
    token: ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']
    

    FullTokenizer会对英文字符做n-gram匹配,会将英文单词拆分,例如running会拆分为run、##ing,主要是针对英文。

    query: 'UNwant\u00E9d,running'
    token: ["un", "##want", "##ed", ",", "runn", "##ing"]
    
  3. 获取RunConfig对象,作为TPUEstimator的输入参数
    run_config = tf.contrib.tpu.RunConfig()

  4. 读取输入文件,处理为InputExample类型的列表

    examples.append(InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    
  5. 将输入文件转化为InputFeatures类型的列表
    features = convert_examples_to_features()

  6. 构造model

  7. 构造Estimator对象

  8. 构造输入input_fn

  9. 进行预测,获取结果,存入json文件中

    results =  estimator.predict(input_fn, yield_single_examples=True)
    

    依次将结果读取

    for result in estimator.predict(input_fn, yield_single_examples=True):
    	for (i, token) in enumerate(feature.tokens):
            all_layers = []
            for (j, layer_index) in enumerate(layer_indexes):
                layer_output = result["layer_output_%d" % j]
                layers = collections.OrderedDict()
                layers["index"] = layer_index
                layers["values"] = [
                        round(float(x), 6) for x in layer_output[i:(i + 1)].flat
                    ]
                all_layers.append(layers)
    

上面这部分代码的意思是只取出输入 经过tokenize之后的长度 的向量。
即如果max_seq_lenght设为128, 如果输入的句子为我爱你,则经过tokenize之后的输入tokens=[["CLS"], '我', '爱','你',["SEP"]],实际有效长度为5,而其余128-5位均填充0。
上面代码就是只取出有效长度的向量。
layer_output的维度是(128, 768), layers["values"]的维度是是(5,768)

这在文章BERT之提取特征向量 及 bert-as-server的使用中提到。

上述几个流程详细内容见下一小节。

源码及注释如下:

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
	# 读取配置文件,构建BertConfig
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
	
	# 对句子进行处理,拆分
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
	
	# 获取RunConfig对象,作为TPUEstimator的输入参数
    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))
	# 读取输入文件,处理为InputExample类型的列表
    examples = read_examples(FLAGS.input_file)
	# 将输入文件转化为InputFeatures类型的列表
    features = convert_examples_to_features(
        examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
	# 构造id到特征的映射字典
    unique_id_to_feature = {
   
   }
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature
	# 构造model
    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        layer_indexes=layer_indexes,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS
### Blip2ForConditionalGeneration `from_pretrained` 方法引发的 ImportError 或 AttributeError 的解决方案 当尝试通过 `Blip2ForConditionalGeneration.from_pretrained` 加载预训练模型时遇到 `ImportError` 或 `AttributeError`,这通常是由以下几个原因引起的: #### 1. 版本不兼容问题 如果使用的 `transformers` 库版本较旧,则可能无法支持最新的模型架构或方法调用。例如,在某些情况下,`modeling_utils.py` 文件中的函数定义可能发生更改,而这些更改可能导致加载失败[^1]。 为了验证当前安装的库版本是否适配目标模型,请运行以下命令来检查版本号: ```bash pip show transformers ``` 若发现版本过低,可以升级到最新版以解决问题: ```bash pip install --upgrade transformers ``` #### 2. 缺少依赖项 部分复杂模型(如 BLIP-2)需要额外的依赖包才能正常工作。如果没有正确安装这些依赖项,可能会触发 `ImportError` 错误。对于 BLIP-2 模型而言,常见的缺失依赖包括但不限于 `torch`, `sentencepiece`, 和 `accelerate` 等[^2]。 可以通过执行如下脚本来一次性安装所需的所有依赖: ```bash pip install torch sentencepiece accelerate ``` #### 3. 预训练权重文件损坏或者丢失 另一个常见原因是下载过程中出现了异常情况,致使本地存储的部分权重文件遭到破坏或者是完全不存在的情况。此时建议清除缓存重新获取完整的参数集[^3]。 删除现有缓存路径下的相关内容后再试一次即可: ```python import shutil from pathlib import Path cache_dir = Path.home() / ".cache/huggingface/transformers" shutil.rmtree(cache_dir, ignore_errors=True) ``` 之后再次调用 `from_pretrained` 函数完成初始化操作。 --- 以下是修正后的代码实现方式作为参考: ```python from transformers import Blip2ForConditionalGeneration, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b", use_fast=False) model = Blip2ForConditionalGeneration.from_pretrained( "Salesforce/blip2-opt-2.7b", device_map="auto", offload_folder="./offload", offload_state_dict=True ) ``` --- ####
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值