本文是CS224n作业5的项目解析。
数据
预训练数据集:wiki
该数据集为txt格式,每行的内容为:人名+关于这个人的一段介绍,每个人之间的描述文本由分行符分隔,选取前三行内容如下:
Khatchig Mouradian. Khatchig Mouradian is a journalist, writer and translator born in Lebanon .
Jacob Henry Studer. Jacob Henry Studer (26 February 1840 Columbus, Ohio - 2 August 1904 New York City) was a printer, lithographer, painter, and popular ornithologist active in Columbus, Ohio from the 1860s to the 1880s .
John Stephen. Born in Glasgow, Stephen became a welder's apprentice on leaving school .
finetune数据集
该数据集每行的基本内容为:问句 + [\t] + 地址答案
Where was Khatchig Mouradian born? Lebanon
Where was Jacob Henry Studer born? Columbus
Where was John Stephen born? Glasgow
Where was Georgina Willis born? Australia
Where was Stanley Corrsin born? Philadelphia
不同阶段参数输入
含义
参数 | 含义 |
---|---|
function | 预训练、微调或测试 |
variant | vanilla(标准)或 synthesizer(变体) |
pretrain_corpus_path | 用于预训练的语料库路径 |
–reading_params_path | 指定后将在finetune和evaluate前读取模型参数 |
–writing_params_path | pretraining或finetuning后保存模型的路径 |
–finetune_corpus_path | 需要进行finetune的语料库路径 |
–eval_corpus_path | 需要进行evaluate的语料库路径 |
–outputs_path | 输出路径 |
总览
按照命令行运行的顺序进行简单介绍(其中所有命令中无论有没有进行预训练,都会使用wiki.txt构建vocab)。
首先是完全不在wiki.txt上进行预训练,直接在目标数据集上进行训练。
- 不进行预训练,模型参数使用初始化参数,在birth_places_train.tsv数据上进行finetune;
- 读取上一步finetune中训练得到的参数,不进行训练,在birth_dev.tsv数据上进行验证;
- 继续读取训练得到的参数,不进行训练,在birth_test_inputs.tsv数据上进行测试。
现在进行对比,在wiki.txt上进行预训练之后再在目标数据集上进行finetune。
- 以wiki.txt作为预训练语料库构建vocab并进行预训练,并且没有对模型参数进行读取,直接使用随机初始化的参数;
- 读取基于wiki.txt进行预训练的模型参数,在birth_places_train.tsv数据上进行finetune;
- 读取基于wiki.txt进行预训练及在目标数据集上进行过finetune的模型参数,在birth_dev.tsv数据上进行验证;
- 继续使用前一步中的模型参数,在birth_test_inputs.tsv数据上进行测试
使用变体Attention的过程与进行预训练的过程基本一致,只有参数function
被设置为synthesizer
。
不使用预训练
直接进行finetune
以wiki.txt作为预训练语料库构建vocab(实际上并不进行预训练),并且没有对模型参数进行读取,直接使用随机初始化的参数,在birth_places_train.tsv数据上进行finetune。
python src/run.py finetune vanilla wiki.txt --writing_params_path vanilla.model.params --finetune_corpus_path birth_places_train.tsv
参数 | 输入 |
---|---|
function | finetune |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | |
–writing_params_path | vanilla.model.params |
–finetune_corpus_path | birth_places_train.tsv |
–eval_corpus_path | |
–outputs_path |
读取参数进行验证
以wiki.txt作为预训练语料库构建vocab,对前一步中训练得到的参数进行读取,不再进行训练,在birth_dev.tsv数据上进行测试。
python src/run.py evaluate vanilla wiki.txt --reading_params_path vanilla.model.params --eval_corpus_path birth_dev.tsv --outputs_path vanilla.nopretrain.dev.predictions
参数 | 输入 |
---|---|
function | evaluate |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | vanilla.model.params |
–writing_params_path | |
–finetune_corpus_path | |
–eval_corpus_path | birth_dev.tsv |
–outputs_path | vanilla.nopretrain.dev.predictions |
读取参数进行测试
以wiki.txt作为预训练语料库构建vocab,对前一步中训练得到的参数进行读取,不再进行训练,在birth_test_inputs.tsv数据上进行测试。
python src/run.py evaluate vanilla wiki.txt --reading_params_path vanilla.model.params --eval_corpus_path birth_test_inputs.tsv --outputs_path vanilla.nopretrain.test.predictions
参数 | 输入 |
---|---|
function | evaluate |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | vanilla.model.params |
–writing_params_path | |
–finetune_corpus_path | |
–eval_corpus_path | birth_test_inputs.tsv |
–outputs_path | vanilla.nopretrain.test.predictions |
使用预训练
进行预训练
以wiki.txt作为预训练语料库构建vocab并进行预训练,并且没有对模型参数进行读取,直接使用随机初始化的参数。
python src/run.py pretrain vanilla wiki.txt --writing_params_path vanilla.pretrain.params
参数 | 输入 |
---|---|
function | pretrain |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | |
–writing_params_path | vanilla.pretrain.params |
–finetune_corpus_path | |
–eval_corpus_path | |
–outputs_path |
在目标数据集上finetune
以wiki.txt作为预训练语料库构建vocab,并读取基于wiki.txt进行预训练的模型参数,在birth_places_train.tsv数据上进行finetune。
python src/run.py finetune vanilla wiki.txt --reading_params_path vanilla.pretrain.params --writing_params_path vanilla.finetune.params --finetune_corpus_path birth_places_train.tsv
参数 | 输入 |
---|---|
function | finetune |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | vanilla.pretrain.params |
–writing_params_path | vanilla.finetune.params |
–finetune_corpus_path | birth_places_train.tsv |
–eval_corpus_path | |
–outputs_path |
在验证集上进行测试
以wiki.txt作为预训练语料库构建vocab,并读取基于wiki.txt进行预训练及在目标数据集上进行过finetune的模型参数,在birth_dev.tsv数据上进行验证。
python src/run.py evaluate vanilla wiki.txt --reading_params_path vanilla.finetune.params --eval_corpus_path birth_dev.tsv --outputs_path vanilla.pretrain.dev.predictions
参数 | 输入 |
---|---|
function | evaluate |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | vanilla.finetune.params |
–writing_params_path | |
–finetune_corpus_path | |
–eval_corpus_path | birth_dev.tsv |
–outputs_path | vanilla.pretrain.dev.predictions |
在测试集上进行测试
以wiki.txt作为预训练语料库构建vocab,并读取基于wiki.txt进行预训练及在目标数据集上进行过finetune的模型参数,在birth_test_inputs.tsv数据上进行测试。
python src/run.py evaluate vanilla wiki.txt --reading_params_path vanilla.finetune.params --eval_corpus_path birth_test_inputs.tsv --outputs_path vanilla.pretrain.test.predictions
参数 | 输入 |
---|---|
function | evaluate |
variant | vanilla |
pretrain_corpus_path | wiki.txt |
–reading_params_path | vanilla.finetune.params |
–writing_params_path | |
–finetune_corpus_path | |
–eval_corpus_path | birth_test_inputs.tsv |
–outputs_path | vanilla.pretrain.test.predictions |
代码文件
dataset.py
CharCorruptionDataset
继承自torch.utils.data.Dataset,重写了__init__函数、__len__函数和__getitem__函数
该类通过实例化时输入的语料库创建了两个字典:字符映射到数字索引,数字索引映射到字符。其中,字典中第一个位置是pad对应的字符,第二个位置是mask对应的字符。另外,把输入的语料库按照分行符划分得到句子列表。
在使用索引获取该类实例中的内容时,进行了以下操作:
- 对对应句子的内容进行截断,选取一个随机位置(从第四个字符开始到block_size * 7 / 8的位置中随机选取),只保留该随机位置前的部分
- 从文本中间随机选取一部分内容替换为mask,被替换部分的平均长度应该为被截断后句子长度的1/4;替换部分前的部分为
prefix
,替换部分后的部分为suffix
,被替换部分被称为masked_content
- 替换前:
[prefix] + [masked_content] + [suffix]
- 替换后:
[prefix] + MASK_CHAR + [suffix] + MASK_CHAR + [masked_content] + [pads]
- 替换前:
- 使用pad标记将处理后的句子填充到block_size长度
- 获取模型的输入输出(把所有字符映射为对应的数字索引)
- 输入:最后一个字符前的所有字符
- 输出:第一个字符后的所有字符
NameDataset
同样继承自torch.utils.data.Dataset,输入的参数包括基于CharCorruptionDataset实例的预训练数据集和当前实例需要处理的数据集。
该类直接读取来自CharCorruptionDataset实例中的字符索引字典,保证微调过程中的字符索引字典和预训练过程中的一致。对数据的主要处理为:
attention.py
CausalSelfAttention
继承自torch.nn.Module
SynthesizerAttention
原始Attention的变体
Y i = softmax ( ReLU ( X A i + b 1 ) B i + b 2 ) ( X V i ) , Y_i=\text{softmax}\big(\text{ReLU}(XA_i+b_1)B_i+b_2\big)(XV_i), Yi=softmax(ReLU(XAi+b1)Bi+b2)(XVi),
model.py
Block
简单的Transformer块
GPT
简单实现的Transformer模型,没有区分encoder和decoder,直接就是输入->embedding->很多层Attention->输出,然后中间使用了一些LayerNorm和Dropout。
Trainer.py
该代码文件中只有Trainer一个类,按照类中不同的函数进行解释。
__init__
实例化时保存的数据:
model
: 模型train_dataset
: 训练集test_dataset
: 测试集config
: 参数字典
并在该函数中获取当前使用的硬件环境
save_checkpoint
保存模型的state_dict
下的参数
train
- 设置优化器
AdamW
及其参数- 所有的偏置参数和
LayerNorm
层中的所有参数都不使用权重衰减 - 其余参数均使用权重衰减
- 所有的偏置参数和
- 单轮训练函数
- 将数据移动到对应的训练设备上
- 进行前向传播,获取预测结果和损失值并在进行梯度裁剪后进行梯度反向传播
- 根据
epoch
的设置进行训练
总结
项目代码的结构及各文件的功能如下所示。