Training a BERT Network Yourself

This post runs training with the BERT scripts released by Google and then inspects BERT's computation graph in TensorBoard.

BERT repository:

https://github.com/google-research/bert

Clone the repo:

git clone https://github.com/google-research/bert

Download a BERT pre-trained model; the archive contains the vocab.txt file we will need later:

wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip

Unzip the pre-trained model into a directory of your choice:

:~/ugetdownload$ unzip uncased_L-12_H-768_A-12.zip 
Archive:  uncased_L-12_H-768_A-12.zip
   creating: uncased_L-12_H-768_A-12/
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.meta  
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  
  inflating: uncased_L-12_H-768_A-12/vocab.txt  
  inflating: uncased_L-12_H-768_A-12/bert_model.ckpt.index  
  inflating: uncased_L-12_H-768_A-12/bert_config.json  

The BERT scripts must be run with TensorFlow 1.x; under 2.x they fail immediately (see "Common errors" below). Create a clean environment and install a 1.x release (the version specifier is quoted so the shell does not treat < as input redirection):

conda create -n py37tf1 python=3.7
conda activate py37tf1
pip install "tensorflow<2.0"
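
Before running the scripts, it can be worth confirming which version actually got installed. A quick check (the assertion is just a convenience, not part of the BERT repo):

import tensorflow as tf

# The google-research/bert scripts target TF 1.x; under 2.x they fail at tf.flags.
print(tf.__version__)
assert tf.__version__.startswith("1."), "expected TensorFlow 1.x"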

Set the BERT_BASE_DIR environment variable:

export BERT_BASE_DIR=~/ugetdownload/uncased_L-12_H-768_A-12

Run the data script to generate the data used for pre-training:

(py37tf1) ~/code/github_read/google-research/bert$ python create_pretraining_data.py \
   --input_file=./sample_text.txt \
   --output_file=./run0507/tf_examples.tfrecord \
   --vocab_file=$BERT_BASE_DIR/vocab.txt \
   --do_lower_case=True \
   --max_seq_length=128 \
   --max_predictions_per_seq=20 \
   --masked_lm_prob=0.15 \
   --random_seed=12345 \
   --dupe_factor=5
   
WARNING:tensorflow:From create_pretraining_data.py:469: The name tf.app.run is deprecated. Please use tf.compat.v1.app.run instead.

WARNING:tensorflow:From create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W0502 17:39:58.054978 139793997326144 module_wrapper.py:139] From create_pretraining_data.py:437: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

WARNING:tensorflow:From create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.

W0502 17:39:58.055087 139793997326144 module_wrapper.py:139] From create_pretraining_data.py:437: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.
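
To sanity-check the generated tf_examples.tfrecord, you can read one example back in Python. This is a minimal sketch (TF 1.x); the feature names are the ones create_pretraining_data.py writes (input_ids, masked_lm_positions, next_sentence_labels, ...):

import tensorflow as tf

# Read the first serialized example and parse it as a tf.train.Example proto.
record_iter = tf.python_io.tf_record_iterator("./run0507/tf_examples.tfrecord")
example = tf.train.Example()
example.ParseFromString(next(record_iter))

features = example.features.feature
input_ids = features["input_ids"].int64_list.value
masked_positions = features["masked_lm_positions"].int64_list.value
next_sentence_label = features["next_sentence_labels"].int64_list.value[0]

print("sequence length:", len(input_ids))            # should equal max_seq_length (128)
print("masked positions:", list(masked_positions))   # at most max_predictions_per_seq (20)
print("next sentence label:", next_sentence_label)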

Run the training script in the same directory:

python run_pretraining.py \
  --input_file=./run0507/tf_examples.tfrecord \
  --output_dir=./run0507/pretraining_output \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=20 \
  --num_warmup_steps=10 \
  --learning_rate=2e-5


I0502 17:48:02.704535 140404206638912 run_pretraining.py:173]   name = cls/seq_relationship/output_bias:0, shape = (2,), *INIT_FROM_CKPT*
WARNING:tensorflow:From run_pretraining.py:198: The name tf.metrics.accuracy is deprecated. Please use tf.compat.v1.metrics.accuracy instead.

W0502 17:48:02.709678 140404206638912 module_wrapper.py:139] From run_pretraining.py:198: The name tf.metrics.accuracy is deprecated. Please use tf.compat.v1.metrics.accuracy instead.

WARNING:tensorflow:From run_pretraining.py:202: The name tf.metrics.mean is deprecated. Please use tf.compat.v1.metrics.mean instead.

W0502 17:48:02.722177 140404206638912 module_wrapper.py:139] From run_pretraining.py:202: The name tf.metrics.mean is deprecated. Please use tf.compat.v1.metrics.mean instead.

INFO:tensorflow:Done calling model_fn.
I0502 17:48:02.767565 140404206638912 estimator.py:1150] Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-05-02T17:48:02Z
I0502 17:48:02.787446 140404206638912 evaluation.py:255] Starting evaluation at 2020-05-02T17:48:02Z
INFO:tensorflow:Graph was finalized.
I0502 17:48:03.247700 140404206638912 monitored_session.py:240] Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/pretraining_output/model.ckpt-20

INFO:tensorflow:Evaluation [100/100]
I0502 17:50:20.824831 140404206638912 evaluation.py:167] Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2020-05-02-17:50:20
I0502 17:50:20.975484 140404206638912 evaluation.py:275] Finished evaluation at 2020-05-02-17:50:20
INFO:tensorflow:Saving dict for global step 20: global_step = 20, loss = 0.27436933, masked_lm_accuracy = 0.95210946, masked_lm_loss = 0.273851, next_sentence_accuracy = 1.0, next_sentence_loss = 0.0004196863
I0502 17:50:20.975750 140404206638912 estimator.py:2049] Saving dict for global step 20: global_step = 20, loss = 0.27436933, masked_lm_accuracy = 0.95210946, masked_lm_loss = 0.273851, next_sentence_accuracy = 1.0, next_sentence_loss = 0.0004196863
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 20: /tmp/pretraining_output/model.ckpt-20
I0502 17:50:21.689311 140404206638912 estimator.py:2109] Saving 'checkpoint_path' summary for global step 20: /tmp/pretraining_output/model.ckpt-20
INFO:tensorflow:evaluation_loop marked as finished
I0502 17:50:21.689822 140404206638912 error_handling.py:101] evaluation_loop marked as finished
INFO:tensorflow:***** Eval results *****
I0502 17:50:21.689958 140404206638912 run_pretraining.py:483] ***** Eval results *****
INFO:tensorflow:  global_step = 20
I0502 17:50:21.690061 140404206638912 run_pretraining.py:485]   global_step = 20
INFO:tensorflow:  loss = 0.27436933
I0502 17:50:21.690387 140404206638912 run_pretraining.py:485]   loss = 0.27436933
INFO:tensorflow:  masked_lm_accuracy = 0.95210946
I0502 17:50:21.690463 140404206638912 run_pretraining.py:485]   masked_lm_accuracy = 0.95210946
INFO:tensorflow:  masked_lm_loss = 0.273851
I0502 17:50:21.690540 140404206638912 run_pretraining.py:485]   masked_lm_loss = 0.273851
INFO:tensorflow:  next_sentence_accuracy = 1.0
I0502 17:50:21.690627 140404206638912 run_pretraining.py:485]   next_sentence_accuracy = 1.0
INFO:tensorflow:  next_sentence_loss = 0.0004196863
I0502 17:50:21.690728 140404206638912 run_pretraining.py:485]   next_sentence_loss = 0.0004196863
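
Once training finishes, the checkpoint written to output_dir can be inspected directly. A small sketch (TF 1.x) that lists the first few variables and their shapes; the path assumes the /tmp/pretraining_output directory that appears in the logs above:

import tensorflow as tf

# List variable names and shapes stored in the trained checkpoint.
ckpt = "/tmp/pretraining_output/model.ckpt-20"
for name, shape in tf.train.list_variables(ckpt)[:10]:
    print(name, shape)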

Visualizing with TensorBoard

Now that a basic training run goes through end to end, we can use TensorBoard to visualize BERT's training process.

Because the BERT scripts are built on TensorFlow's Estimator API, the events files that TensorBoard needs are produced by default; they end up in the directory given by the output_dir flag (/tmp/pretraining_output for the run shown here).

(base) :/tmp/pretraining_output$ ll
total 2610064
drwxr-xr-x  3 wenkai wenkai       4096 May  2 17:50 ./
drwxrwxrwt 65 root   root        12288 May  2 18:07 ../
-rw-rw-r--  1 wenkai wenkai        126 May  2 17:48 checkpoint
drwxr-xr-x  2 wenkai wenkai       4096 May  2 17:50 eval/
-rw-rw-r--  1 wenkai wenkai        156 May  2 17:50 eval_results.txt
-rw-rw-r--  1 wenkai wenkai   13311481 May  2 17:48 events.out.tfevents.1588412530.G6
-rw-rw-r--  1 wenkai wenkai    9153045 May  2 17:42 graph.pbtxt
-rw-rw-r--  1 wenkai wenkai 1321277144 May  2 17:42 model.ckpt-0.data-00000-of-00001
-rw-rw-r--  1 wenkai wenkai      23350 May  2 17:42 model.ckpt-0.index
-rw-rw-r--  1 wenkai wenkai    3796855 May  2 17:42 model.ckpt-0.meta
-rw-rw-r--  1 wenkai wenkai 1321277144 May  2 17:48 model.ckpt-20.data-00000-of-00001
-rw-rw-r--  1 wenkai wenkai      23350 May  2 17:48 model.ckpt-20.index
-rw-rw-r--  1 wenkai wenkai    3796855 May  2 17:48 model.ckpt-20.meta
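
Before starting TensorBoard you can confirm that the events file listed above actually contains scalar summaries. A minimal sketch (TF 1.x) using tf.train.summary_iterator; the file name is the one from the listing:

import tensorflow as tf

# Print every scalar summary (step, tag, value) found in the events file.
event_file = "/tmp/pretraining_output/events.out.tfevents.1588412530.G6"
for event in tf.train.summary_iterator(event_file):
    for value in event.summary.value:
        if value.HasField("simple_value"):
            print(event.step, value.tag, value.simple_value)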

Start TensorBoard to visualize the training run:

(py37tf1) wenkai@wenkai-HP-EliteBook-840-G6:/tmp/pretraining_output$ tensorboard --logdir . --port 6007
W0502 18:08:16.384034 139671431165696 plugin_event_accumulator.py:294] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events.  Overwriting the graph with the newest event.
W0502 18:08:16.389085 139671431165696 plugin_event_accumulator.py:302] Found more than one metagraph event per run. Overwriting the metagraph with the newest event.
TensorBoard 1.15.0 at http://0.0.0.0:6007/ (Press CTRL+C to quit)

Visualizing BERT's computation graph
Visualizing some scalars

Common errors

Running the scripts with TF 2.x fails:

Traceback (most recent call last):
  File "create_pretraining_data.py", line 26, in <module>
    flags = tf.flags
AttributeError: module 'tensorflow' has no attribute 'flags'

Fix: use TensorFlow 1.x (see the environment setup above).

The BERT_BASE_DIR environment variable is not set:

Traceback (most recent call last):
  File "create_pretraining_data.py", line 469, in <module>
    tf.app.run()
  File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "create_pretraining_data.py", line 440, in main
    vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
  File "/home/wenkai/code/github_read/google-research/bert/tokenization.py", line 165, in __init__
    self.vocab = load_vocab(vocab_file)
  File "/home/wenkai/code/github_read/google-research/bert/tokenization.py", line 127, in load_vocab
    token = convert_to_unicode(reader.readline())
  File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/lib/io/file_io.py", line 178, in readline
    self._preread_check()
  File "/home/wenkai/anaconda3/envs/py37tf1/lib/python3.7/site-packages/tensorflow_core/python/lib/io/file_io.py", line 84, in _preread_check
    compat.as_bytes(self.__name), 1024 * 512)
tensorflow.python.framework.errors_impl.NotFoundError: /vocab.txt; No such file or directory

Fix: set BERT_BASE_DIR correctly.

Notes on using TensorBoard (can be skipped)

Since we are on TensorFlow 1.x, the matching TensorBoard 1.x documentation applies; the relevant pages are here:

https://github.com/tensorflow/tensorboard/blob/master/docs/r1/summaries.md

https://github.com/tensorflow/tensorboard/blob/master/docs/r1/graphs.md

https://github.com/tensorflow/tensorboard/blob/master/docs/r1/overview.md

The key point:

The FileWriter takes a logdir in its constructor - this logdir is quite important, it's the directory where all of the events will be written out. Also, the FileWriter can optionally take a Graph in its constructor. If it receives a Graph object, then TensorBoard will visualize your graph along with tensor shape information.

https://github.com/tensorflow/tensorboard/blob/master/docs/r1/summaries.md
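
As a minimal illustration of that point (a toy sketch in TF 1.x, not part of the BERT scripts; the /tmp/tb_demo path is just an example): build a tiny graph, pass it to a FileWriter, and TensorBoard will render it from that logdir.

import tensorflow as tf

# Build a trivial graph and write it out so TensorBoard can visualize it.
a = tf.constant(1.0, name="a")
b = tf.constant(2.0, name="b")
c = tf.add(a, b, name="c")

writer = tf.summary.FileWriter("/tmp/tb_demo", graph=tf.get_default_graph())
writer.close()
# Then run: tensorboard --logdir /tmp/tb_demo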