BERT微调
import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l
bert.base和bert.small地址
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
'225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
'c72329e68a732bef0452e4b96a1c341c8910f81f')
加载预先训练好的BERT参数
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_layers, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# 定义空词表以加载预定义词表
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir,'vocab.json')))
vocab.token_to_idx = {
token: idx for idx, token in enumerate(vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
num_heads=4, num_layers=2, dropout=0.2,
max_len=max_len, key_size=256, query_size=256,
value_size=256, hid_in

该博客介绍了如何在PyTorch中微调预训练的BERT模型。首先定义了BERT模型的加载函数,然后加载小版本的BERT模型。接着,创建了一个数据集类`SNLIBERTDataset`用于处理SNLI数据集,进行预处理和截断操作。最后,定义了一个线性的分类层`BERTClassifier`,并使用Adam优化器进行训练。在Windows环境中,将多进程池替换为map函数以避免错误。经过5个周期的训练,模型在训练和测试集上分别达到了0.805和0.786的准确率。
最低0.47元/天 解锁文章
1560

被折叠的 条评论
为什么被折叠?



