2024年最全一本读懂BERT(实践篇)_train_batch_size(1),推荐给大家

img
img
img

既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,涵盖了95%以上Go语言开发知识点,真正体系化!

由于文件比较多,这里只是将部分目录截图出来,全套包含大厂面经、学习笔记、源码讲义、实战项目、大纲路线、讲解视频,并且后续会持续更新

如果你需要这些资料,可以戳这里获取

python run_classifier.py \
	--task_name=MRPC \
	--do_train=true \
	--do_eval=true \
	--data_dir=$GLUE_DIR/MRPC \
	--vocab_file=$BERT_BASE_DIR/vocab.txt \
	--bert_config_file=$BERT_BASE_DIR/bert_config.json \
	--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
	--max_seq_length=128 \
	--train_batch_size=8 \
	--learning_rate=2e-5 \
	--num_train_epochs=3.0 \
	--output_dir=/tmp/mrpc_output/

这里简单的解释一下参数的含义,在后面的代码阅读里读者可以更加详细的了解其意义。

  • task_name 任务的名字,这里我们Fine-Tuning MRPC任务
  • do_train 是否训练,这里为True
  • do_eval 是否在训练结束后验证,这里为True
  • data_dir 训练数据目录,配置了环境变量后不需要修改,否则填入绝对路径
  • vocab_file BERT模型的词典
  • bert_config_file BERT模型的配置文件
  • init_checkpoint Fine-Tuning的初始化参数
  • max_seq_length Token序列的最大长度,这里是128
  • train_batch_size batch大小,对于普通8GB的GPU,最大batch大小只能是8,再大就会OOM
  • learning_rate
  • num_train_epochs 训练的epoch次数,根据任务进行调整
  • output_dir 训练得到的模型的存放目录

这里最常见的问题就是内存不够,通常我们的GPU只有8G作用的显存,因此对于小的模型(bert-base),我们最多使用batchsize=8,而如果要使用bert-large,那么batchsize只能设置成1。运行结束后可能得到类似如下的结果:

***** Eval results *****
eval_accuracy = 0.845588
eval_loss = 0.505248
global_step = 343
loss = 0.505248

这说明在验证集上的准确率是0.84左右。

五、数据读取源码阅读

(一) DataProcessor

我们首先来看数据是怎么读入的。这是一个抽象基类,定义了get_train_examples、get_dev_examples、get_test_examples和get_labels等4个需要子类实现的方法,另外提供了一个_read_tsv函数用于读取tsv文件。下面我们通过一个实现类MrpcProcessor来了解怎么实现这个抽象基类,如果读者想使用自己的数据,那么就需要自己实现一个新的子类。

(二) MrpcProcessor

对于MRPC任务,这里定义了MrpcProcessor来基础DataProcessor。我们来看其中的get_labels和get_train_examples,其余两个抽象方法是类似的。首先是get_labels,它非常简单,这任务只有两个label。

def get_labels(self): 
  return ["0", "1"]

接下来是get_train_examples:

def get_train_examples(self, data_dir):
  return self._create_examples(
		  self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

这个函数首先使用_read_tsv读入训练文件train.tsv,然后使用_create_examples函数把每一行变成一个InputExample对象。

def _create_examples(self, lines, set_type):
  examples = []
  for (i, line) in enumerate(lines):
	  if i == 0:
		  continue
	  guid = "%s-%s" % (set_type, i)
	  text_a = tokenization.convert_to_unicode(line[3])
	  text_b = tokenization.convert_to_unicode(line[4])
	  if set_type == "test":
		  label = "0"
	  else:
		  label = tokenization.convert_to_unicode(line[0])
	  examples.append(
		  InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
	  return examples

代码非常简单,line是一个list,line[3]和line[4]分别代表两个句子,如果是训练集合和验证集合,那么第一列line[0]就是真正的label,而如果是测试集合,label就没有意义,随便赋值成”0”。然后对于所有的字符串都使用tokenization.convert_to_unicode把字符串变成unicode的字符串。这是为了兼容Python2和Python3,因为Python3的str就是unicode,而Python2的str其实是bytearray,Python2却有一个专门的unicode类型。感兴趣的读者可以参考其实现,不感兴趣的可以忽略。

最终构造出一个InputExample对象来,它有4个属性:guid、text_a、text_b和label,guid只是个唯一的id而已。text_a代表第一个句子,text_b代表第二个句子,第二个句子可以为None,label代表分类标签。

六、分词源码阅读

分词是我们需要重点关注的代码,因为如果想要把BERT产品化,我们需要使用Tensorflow Serving,Tensorflow Serving的输入是Tensor,把原始输入变成Tensor一般需要在Client端完成。BERT的分词是Python的代码,如果我们使用其它语言的gRPC Client,那么需要用其它语言实现同样的分词算法,否则预测时会出现问题。

这部分代码需要读者有Unicode的基础知识,了解什么是CodePoint,什么是Unicode Block。Python2和Python3的str有什么区别,Python2的unicode类等价于Python3的str等等。不熟悉的读者可以参考一些资料。

(一)FullTokenizer

BERT里分词主要是由FullTokenizer类来实现的。

class FullTokenizer(object): 
	def __init__(self, vocab_file, do_lower_case=True):
		self.vocab = load_vocab(vocab_file)
		self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
		self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

	def tokenize(self, text):
		split_tokens = []
		for token in self.basic_tokenizer.tokenize(text):
			for sub_token in self.wordpiece_tokenizer.tokenize(token):
				split_tokens.append(sub_token)
		
		return split_tokens

	def convert_tokens_to_ids(self, tokens):
		return convert_tokens_to_ids(self.vocab, tokens)

FullTokenizer的构造函数需要传入参数词典vocab_file和do_lower_case。如果我们自己从头开始训练模型(后面会介绍),那么do_lower_case决定了我们的某些是否区分大小写。如果我们只是Fine-Tuning,那么这个参数需要与模型一致,比如模型是chinese_L-12_H-768_A-12,那么do_lower_case就必须为True。

函数首先调用load_vocab加载词典,建立词到id的映射关系。下面是文件chinese_L-12_H-768_A-12/vocab.txt的部分内容

馬
高
龍
龸
fi
fl
!
(
)
,
-
.
/
:
?
~
the
of
and
in
to

接下来是构造BasicTokenizer和WordpieceTokenizer。前者是根据空格等进行普通的分词,而后者会把前者的结果再细粒度的切分为WordPiece。

tokenize函数实现分词,它先调用BasicTokenizer进行分词,接着调用WordpieceTokenizer把前者的结果再做细粒度切分。下面我们来详细阅读这两个类的代码。我们首先来看BasicTokenizer的tokenize方法。

def tokenize(self, text): 
  text = convert_to_unicode(text)
  text = self._clean_text(text)
  
  # 这是2018年11月1日为了支持多语言和中文增加的代码。这个代码也可以用于英语模型,因为在
  # 英语的训练数据中基本不会出现中文字符(但是某些wiki里偶尔也可能出现中文)。
  text = self._tokenize_chinese_chars(text)
  
  orig_tokens = whitespace_tokenize(text)
  split_tokens = []
  for token in orig_tokens:
	  if self.do_lower_case:
		  token = token.lower()
		  token = self._run_strip_accents(token)
	  split_tokens.extend(self._run_split_on_punc(token))
  
  output_tokens = whitespace_tokenize(" ".join(split_tokens))
  return output_tokens

首先是用convert_to_unicode把输入变成unicode,这个函数前面也介绍过了。接下来是_clean_text函数,它的作用是去除一些无意义的字符。

def _clean_text(self, text):
  """去除一些无意义的字符以及whitespace"""
  output = []
  for char in text:
	  cp = ord(char)
	  if cp == 0 or cp == 0xfffd or _is_control(char):
		  continue
	  if _is_whitespace(char):
		  output.append(" ")
	  else:
		  output.append(char)
  return "".join(output)

codepoint为0的是无意义的字符,0xfffd(U+FFFD)显示为�,通常用于替换未知的字符。_is_control用于判断一个字符是否是控制字符(control character),所谓的控制字符就是用于控制屏幕的显示,比如\n告诉(控制)屏幕把光标移到下一行的开始。读者可以参考这里

def _is_control(char):
	"""检查字符char是否是控制字符"""
	# 回车换行和tab理论上是控制字符,但是这里我们把它认为是whitespace而不是控制字符
	if char == "\t" or char == "\n" or char == "\r":
		return False
	cat = unicodedata.category(char)
	if cat.startswith("C"):
		return True
	return False

这里使用了unicodedata.category这个函数,它返回这个Unicode字符的Category,这里C开头的都被认为是控制字符,读者可以参考这里

接下来是调用_is_whitespace函数,把whitespace变成空格。

def _is_whitespace(char):
	"""Checks whether `chars` is a whitespace character."""
	# \t, \n, and \r are technically contorl characters but we treat them
	# as whitespace since they are generally considered as such.
	if char == " " or char == "\t" or char == "\n" or char == "\r":
		return True
	cat = unicodedata.category(char)
	if cat == "Zs":
		return True
	return False

这里把category为Zs的字符以及空格、tab、换行和回车当成whitespace。然后是_tokenize_chinese_chars,用于切分中文,这里的中文分词很简单,就是切分成一个一个的汉字。也就是在中文字符的前后加上空格,这样后续的分词流程会把没一个字符当成一个词。

def _tokenize_chinese_chars(self, text): 
  output = []
  for char in text:
  cp = ord(char)
  if self._is_chinese_char(cp):
	  output.append(" ")
	  output.append(char)
	  output.append(" ")
  else:
	  output.append(char)
  return "".join(output)

这里的关键是调用_is_chinese_char函数,这个函数用于判断一个unicode字符是否中文字符。

    def _is_chinese_char(self, cp):
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
		  (cp >= 0x3400 and cp <= 0x4DBF) or  #
		  (cp >= 0x20000 and cp <= 0x2A6DF) or  #
		  (cp >= 0x2A700 and cp <= 0x2B73F) or  #
		  (cp >= 0x2B740 and cp <= 0x2B81F) or  #
		  (cp >= 0x2B820 and cp <= 0x2CEAF) or
		  (cp >= 0xF900 and cp <= 0xFAFF) or  #
		  (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
        return True

        return False

很多网上的判断汉字的正则表达式都只包括4E00-9FA5,但这是不全的,比如  就不再这个范围内。读者可以参考这里

接下来是使用whitespace进行分词,这是通过函数whitespace_tokenize来实现的。它直接调用split函数来实现分词。Python里whitespace包括’\t\n\x0b\x0c\r ‘。然后遍历每一个词,如果需要变成小写,那么先用lower()函数变成小写,接着调用_run_strip_accents函数去除accent。它的代码为:

def _run_strip_accents(self, text):
  text = unicodedata.normalize("NFD", text)
  output = []
  for char in text:
	  cat = unicodedata.category(char)
	  if cat == "Mn":
		  continue
	  output.append(char)
  return "".join(output)

它首先调用unicodedata.normalize(“NFD”, text)对text进行归一化。这个函数有什么作用呢?我们先看一下下面的代码:

>>> s1 = 'café'
>>> s2 = 'cafe\u0301'
>>> s1, s2
('café', 'café')
>>> len(s1), len(s2)
(4, 5)
>>> s1 == s2
False

我们”看到”的é其实可以有两种表示方法,一是用一个codepoint直接表示”é”,另外一种是用”e”再加上特殊的codepoint U+0301两个字符来表示。U+0301是COMBINING ACUTE ACCENT,它跟在e之后就变成了”é”。类似的”a\u0301”显示出来就是”á”。注意:这只是打印出来一模一样而已,但是在计算机内部的表示它们完全不同的,前者é是一个codepoint,值为0xe9,而后者是两个codepoint,分别是0x65和0x301。unicodedata.normalize(“NFD”, text)就会把0xe9变成0x65和0x301,比如下面的测试代码。

接下

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值