飞桨2.0高层api教程——使用BERT实现自动写诗

用BERT实现自动写诗

作者fiyen

日期:2021.04

摘要:本示例教程将会演示如何使用飞桨2.0以及PaddleNLP快速实现用BERT预训练模型生成高质量诗歌。

摘要

在这个示例中,我们将快速构建基于BERT预训练模型的古诗生成器,支持诗歌风格定制,以及生成藏头诗。模型基于飞桨2.0框架,BERT预训练模型则调用自PaddleNLP,诗歌数据集采用Github开源数据集。

相关内容介绍

PaddleNLP

官网链接:https://github.com/fiyen/models/tree/release/2.0-beta/PaddleNLP

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vSMBMsVy-1619146761087)(https://github.com/fiyen/models/raw/release/2.0-beta/PaddleNLP/docs/imgs/paddlenlp.png)]

PaddleNLP旨在帮助开发者提高文本建模的效率,通过丰富的模型库、简洁易用的API,提供飞桨2.0的最佳实践并加速NLP领域应用产业落地效率。其产品特性如下:

  • 丰富的模型库

涵盖了NLP主流应用相关的前沿模型,包括中文词向量、预训练模型、词法分析、文本分类、文本匹配、文本生成、机器翻译、通用对话、问答系统等。

  • 简洁易用的API

深度兼容飞桨2.0的高层API体系,提供更多可复用的文本建模模块,可大幅度减少数据处理、组网、训练环节的代码开发,提高开发效率。

  • 高性能分布式训练

通过高度优化的Transformer网络实现,结合混合精度与Fleet分布式训练API,可充分利用GPU集群资源,高效完成预训练模型的分布式训练。

BERT

BERT的全称为Bidirectional Encoder Representations from Transformers,即基于Transformers的双向编码表示模型。BERT是Transformers应用的一次巨大的成功。在该模型提出时,其在NLP领域的11个方向上都大幅刷新了SOTA。其模型的主要特点可以归纳如下:

  1. 基于Transformer。Transformer的提出将注意力机制的应用发挥到了极致,同时也解决了基于RNN的注意力机制的无法并行计算的问题,使超大规模的模型训练在时间上变得可以接受;

  2. 双向编码。其实双向编码不是BERT首创,但是基于Transformer与双向编码结合使这一做法的效用得到了最充分的发挥;

  3. 使用MLM(Mask Language Model)和NSP(Next Sentence Prediction)实现多任务训练的目标。

  4. 迁移学习。BERT模型展现出了大规模数据训练带来的有效性,而更重要的一点是,BERT实质上是一种更好的语义表征,相较于经典的Word2Vec,Glove等模型具有更好词嵌入特征。在实际应用中,我们可以直接调用训练好的BERT模型作为特征表示,进而设计下游任务。

数据设置

在这一部分,我们对数据进行预处理,并构建训练用的数据读取器。

数据准备

诗歌数据集采用Github上开源的中华古诗词数据库。在此,我们只使用其中的唐诗和宋诗的数据即可(json文件夹下)。

# 下载诗歌数据集 (从镜像网站github.com.cnpmjs.org下载可提高下载速度)
!git clone https://github.com.cnpmjs.org/chinese-poetry/chinese-poetry

此数据集中多数诗歌内容为繁体字,为了适应基于简体中文的预训练模型,我们对数据进行预处理,将繁体字转换为简体字。首先调用Github上开源的繁转简工具。

# 下载繁体转简体工具
!git clone https://github.com.cnpmjs.org/fiyen/cht2chs

数据处理

剔除数据集中的特殊符号,并将繁体转简体。

import os
import json
import re
from cht2chs.langconv import cht_to_chs

def sentenceParse(para):
    """
    剔除诗歌字符中的非文字符号以及数字
    """
    result, number = re.subn(u"(.*)", "", para)
    result, number = re.subn(u"{.*}", "", result)
    result, number = re.subn(u"《.*》", "", result)
    result, number = re.subn(u"《.*》", "", result)
    result, number = re.subn(u"[\]\[]", "", result)
    r = ""
    for s in result:
        if s not in set('0123456789-'):
            r += s
    r, number = re.subn(u"。。", u"。", r)
    return r


def data_preprocess(poem_dir='./chinese-poetry/json', len_limit=120):
    """
    预处理诗歌数据,返回符合要求的诗歌列表
    """
    poems = []
    for f in os.listdir(poem_dir):
        if f.endswith('.json'):
            json_data = json.load(open(os.path.join(poem_dir, f)))
            for d in json_data:
                try:
                    poem = ''.join(d['paragraphs'])
                    poem = sentenceParse(poem)
                    # 控制长度,并将繁体字转换为简体字
                    if len(poem) <= len_limit:
                        poems.append(cht_to_chs(poem))
                except:
                    continue
    return poems
# 开始处理
poems = data_preprocess()

从PaddleNLP调用基于BERT预训练模型的分词工具,对诗歌进行分词和编码。

from paddlenlp.transformers import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

处理效果如下。从结果可以看出,分词工具会在诗歌开始添加“[CLS]”标记(“[CLS]”是对一些特殊任务的留空项,对于需要此项功能的并需要标记语句开始的情况,一般会再加上“[BOS]”),在结尾添加“[SEP]”标记(需要区分句子的编码中,这个标记用来将不同的句子隔开,结尾添加“[EOS]”),这些标记在BERT模型训练中扮演者特殊的角色,具有重要的作用。除此之外,也有其他特殊标记,如“[UNK]”表示分词工具无法识别的符号,“[PAD]”表示填充内容的编码。在古诗生成器构造的过程中,我们将针对这些特殊符号进行一些特殊的处理,将这些符号予以剔除。

# 处理效果展示
for poem in poems[6:8]:
    token_poem, _ = bert_tokenizer.encode(poem).values()
    print(poem)
    print(token_poem)
    print(''.join(bert_tokenizer.convert_ids_to_tokens(token_poem)))
楚王台榭荆榛里,屈指江山俎豆中。
[101, 3504, 4374, 1378, 3531, 5769, 3527, 7027, 8024, 2235, 2900, 3736, 2255, 917, 6486, 704, 511, 102]
[CLS]楚王台榭荆榛里,屈指江山俎豆中。[SEP]
百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。
[101, 4636, 2399, 2129, 4373, 4767, 8024, 676, 7027, 5811, 2687, 740, 511, 1765, 2970, 5769, 7305, 6818, 8024, 4170, 6837, 3727, 3717, 7270, 511, 102]
[CLS]百年宋玉石,三里莫愁乡。地接荆门近,烟迷汉水长。[SEP]

构造数据读取器

预处理数据后,我们基于飞桨2.0构造数据读取器,以适应后续模型的训练。

需注意以下类定义中包含填充内容,使输入样本对齐到一个特定的长度,以便于模型进行批处理运算。因此在得到数据读取器的实例时,需注意参数max_len,其不超过模型所支持的最大长度(PaddleNLP默认的序列最长长度为512)

import paddle
from paddle.io import Dataset
import numpy as np

class PoemData(Dataset):
    """
    构造诗歌数据集,继承paddle.io.Dataset
    Parameters:
        poems (list): 诗歌数据列表,每一个元素为一首诗歌,诗歌未经编码
        max_len: 接收诗歌的最大长度
    """
    def __init__(self, poems, tokenizer, max_len=128):
        super(PoemData, self).__init__()
        self.poems = poems
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __getitem__(self, idx):
        line = poems[idx]
        token_line = self.tokenizer.encode(line)
        token, token_type = token_line['input_ids'], token_line['token_type_ids']
        if len(token) > self.max_len + 1:
            token = token[:self.max_len] + token[-1]
            token_type = token_type[:self.max_len] + token_type[-1]
        input_token, input_token_type = token[:-1], token_type[:-1]
        label_token = np.array((token[1:] + [0] * self.max_len)[:self.max_len], dtype='int64')
        # 输入填充
        input_token = np.array((input_token + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_token_type = np.array((input_token_type + [0] * self.max_len)[:self.max_len], dtype='int64')
        input_pad_mask = (input_token != 0).astype('float32')
        return input_token, input_token_type, input_pad_mask, label_token, input_pad_mask
    
    def __len__(self):
        return len(self.poems)

模型设置与训练

在这一部分,我们将快速搭建基于BERT预训练模型的古诗生成器,并对模型进行训练。

预训练BERT模型

古诗生成是一个文本生成的过程,在实际中模型无法获知还未生成的内容,也即BERT中的双向关系中只能捕捉到前向关系而不能捕捉到后向关系。这个限制我们可以通过添加注意力掩码(attention mask)来屏蔽掉后向的关系,使模型无法注意到还未生成的内容,从而使BERT仍能完成文本生成任务。

进一步地,我们可以将文本生成简化为基于BERT的词分类模型(理解为词性标注),即赋予每个词一个标签,该标签即该词后的下一个词是什么。因此,我们直接调用PaddleNLP的BERT词分类模型即可看,需注意模型分类的类别为词表长度。

from paddlenlp.transformers import BertModel, BertForTokenClassification
from paddle.nn import Layer, Linear, Softmax

class PoetryBertModel(Layer):
    """
    基于BERT预训练模型的诗歌生成模型
    """
    def __init__(self, pretrained_bert_model: str, input_length: int):
        super(PoetryBertModel, self).__init__()
        bert_model = BertModel.from_pretrained(pretrained_bert_model)
        self.vocab_size, self.hidden_size = bert_model.embeddings.word_embeddings.parameters()[0].shape
        self.bert_for_class = BertForTokenClassification(bert_model, self.vocab_size)
        # 生成下三角矩阵,用来mask句子后边的信息
        self.sequence_length = input_length
        self.lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))

    def forward(self, token, token_type, input_mask, input_length=None):
        # 计算attention mask
        mask_left = paddle.reshape(input_mask, input_mask.shape + [1])
        mask_right = paddle.reshape(input_mask, [input_mask.shape[0], 1, input_mask.shape[1]])
        # 输入句子中有效的位置
        mask_left = paddle.cast(mask_left, 'float32')
        mask_right = paddle.cast(mask_right, 'float32')
        attention_mask = paddle.matmul(mask_left, mask_right)
        # 注意力机制计算中有效的位置
        if input_length is not None:
            lower_triangle_mask = paddle.tril(paddle.tensor.full((input_length, input_length), 1, 'float32'))
        else:
            lower_triangle_mask = self.lower_triangle_mask
        attention_mask = attention_mask * lower_triangle_mask
        # 无效的位置设为极小值
        attention_mask = (1 - paddle.unsqueeze(attention_mask, axis=[1])) * -1e10
        attention_mask = paddle.cast(attention_mask, self.bert_for_class.parameters()[0].dtype)

        output_logits = self.bert_for_class(token, token_type_ids=token_type, attention_mask=attention_mask)
        
        return output_logits

定义模型损失

由于真实值中有相当一部分是填充内容,我们需重写交叉熵损失,使其忽略填充内容带来的损失。

class PoetryBertModelLossCriterion(Layer):
    def forward(self, pred_logits, label, input_mask):
        loss = paddle.nn.functional.cross_entropy(pred_logits, label, ignore_index=0, reduction='none')
        masked_loss = paddle.mean(loss * input_mask, axis=0)
        return paddle.sum(masked_loss)

模型准备

针对预训练模型的训练,需使用较小的学习率(learning_rate)进行调优。

from paddle.static import InputSpec
from paddlenlp.metrics import Perplexity
from paddle.optimizer import AdamW

net = PoetryBertModel('bert-base-chinese', 128)

token_ids = InputSpec((-1, 128), 'int64', 'token')
token_type_ids = InputSpec((-1, 128), 'int64', 'token_type')
input_mask = InputSpec((-1, 128), 'float32', 'input_mask')
label = InputSpec((-1, 128), 'int64', 'label')

inputs = [token_ids, token_type_ids, input_mask]
labels = [label, input_mask]

model = paddle.Model(net, inputs, labels)
model.prepare(optimizer=AdamW(learning_rate=0.0001, parameters=model.parameters()), loss=PoetryBertModelLossCriterion(), metrics=[Perplexity()])

model.summary(inputs, [input.dtype for input in inputs])
[2021-04-16 19:31:55,774] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/bert-base-chinese/bert-base-chinese.pdparams
[2021-04-16 19:32:00,118] [    INFO] - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']


----------------------------------------------------------------------------------------------------------------------------------------
        Layer (type)                                   Input Shape                                 Output Shape            Param #    
========================================================================================================================================
        Embedding-16                                   [[1, 128]]                                  [1, 128, 768]         16,226,304   
        Embedding-17                                   [[1, 128]]                                  [1, 128, 768]           393,216    
        Embedding-18                                   [[1, 128]]                                  [1, 128, 768]            1,536     
       LayerNorm-129                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
        Dropout-188                                  [[1, 128, 768]]                               [1, 128, 768]              0       
      BertEmbeddings-6                                     []                                      [1, 128, 768]              0       
         Linear-371                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-372                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-373                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-374                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-61     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-190                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-130                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-375                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-189                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-376                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-191                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-131                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-61                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-377                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-378                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-379                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-380                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-62     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-193                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-132                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-381                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-192                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-382                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-194                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-133                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-62                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-383                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-384                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-385                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-386                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-63     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-196                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-134                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-387                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-195                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-388                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-197                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-135                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-63                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-389                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-390                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-391                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-392                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-64     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-199                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-136                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-393                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-198                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-394                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-200                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-137                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-64                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-395                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-396                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-397                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-398                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-65     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-202                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-138                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-399                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-201                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-400                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-203                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-139                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-65                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-401                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-402                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-403                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-404                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-66     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-205                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-140                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-405                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-204                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-406                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-206                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-141                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-66                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-407                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-408                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-409                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-410                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-67     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-208                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-142                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-411                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-207                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-412                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-209                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-143                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-67                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-413                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-414                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-415                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-416                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-68     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-211                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-144                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-417                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-210                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-418                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-212                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-145                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-68                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-419                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-420                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-421                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-422                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-69     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-214                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-146                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-423                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-213                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-424                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-215                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-147                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-69                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-425                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-426                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-427                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-428                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-70     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-217                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-148                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-429                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-216                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-430                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-218                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-149                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-70                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-431                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-432                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-433                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-434                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-71     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-220                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-150                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-435                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-219                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-436                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-221                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-151                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-71                          [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-437                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-438                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-439                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
         Linear-440                                  [[1, 128, 768]]                               [1, 128, 768]           590,592    
   MultiHeadAttention-72     [[1, 128, 768], [1, 128, 768], [1, 128, 768], [1, 1, 128, 128]]       [1, 128, 768]              0       
        Dropout-223                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-152                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
         Linear-441                                  [[1, 128, 768]]                              [1, 128, 3072]          2,362,368   
        Dropout-222                                 [[1, 128, 3072]]                              [1, 128, 3072]              0       
         Linear-442                                 [[1, 128, 3072]]                               [1, 128, 768]          2,360,064   
        Dropout-224                                  [[1, 128, 768]]                               [1, 128, 768]              0       
       LayerNorm-153                                 [[1, 128, 768]]                               [1, 128, 768]            1,536     
 TransformerEncoderLayer-72                          [[1, 128, 768]]                               [1, 128, 768]              0       
    TransformerEncoder-6                    [[1, 128, 768], [1, 1, 128, 128]]                      [1, 128, 768]              0       
         Linear-443                                    [[1, 768]]                                    [1, 768]              590,592    
           Tanh-7                                      [[1, 768]]                                    [1, 768]                 0       
        BertPooler-6                                 [[1, 128, 768]]                                 [1, 768]                 0       
        BertModel-6                                    [[1, 128]]                            [[1, 128, 768], [1, 768]]        0       
        Dropout-225                                  [[1, 128, 768]]                               [1, 128, 768]              0       
         Linear-444                                  [[1, 128, 768]]                              [1, 128, 21128]        16,247,432   
BertForTokenClassification-3                           [[1, 128]]                                 [1, 128, 21128]             0       
========================================================================================================================================
Total params: 118,515,080
Trainable params: 118,515,080
Non-trainable params: 0
----------------------------------------------------------------------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 219.04
Params size (MB): 452.10
Estimated Total Size (MB): 671.14
----------------------------------------------------------------------------------------------------------------------------------------






{'total_params': 118515080, 'trainable_params': 118515080}

模型训练

由于调用了预训练模型,再次调优,只需很少轮的训练即可达到较好的效果。

训练过程中,设置save_dir参数来保存训练的模型,并通过save_freq设置保存的频率。

from paddle.io import DataLoader

train_loader = DataLoader(PoemData(poems, bert_tokenizer, 128), batch_size=128, shuffle=True)
model.fit(train_data=train_loader, epochs=10, save_dir='./checkpoint', save_freq=1, verbose=1)

古诗生成

以下,我们定义一个类来利用已经训练好的模型完成古诗生成的任务。在生成古诗的过程中,我们将已经生成的内容作为输入,编码后输入模型,得到输入中每个词对应的分类结果。然后选取最后一个词的分类结果作为下一个待预测的词。下一轮中,刚刚预测的词将加入到已生成的内容中,继续进行下一个词的预测。

在每轮预测结果的选择中,我们可以使用贪婪的方式选取最优的结果,也可以从前几个较优结果中随机选取(可以得到更多的组合),在这里,用topk进行控制。topk的设置不应太大,否则与随机生成差别不大。

import numpy as np

class PoetryGen(object):
    """
    定义一个自动生成诗句的类,按照要求生成诗句
    model: 训练得到的预测模型
    tokenizer: 分词编码工具
    max_length: 生成诗句的最大长度,需小于等于model所允许的最大长度
    """
    def __init__(self, model, tokenizer, max_length=512):
        self.model = model
        self.tokenizer = tokenizer
        self.puncs = [',', '。', '?', ';']
        self.max_length = max_length

    def generate(self, style='', head='', topk=2):
        """
        根据要求生成诗句
        style (str): 生成诗句的风格,写成诗句的形式,如“大漠孤烟直,长河落日圆。”
        head (str, list): 生成诗句的开头内容。若head为str格式,则head为诗句开始内容;
            若head为list格式,则head中每个元素为对应位置上诗句的开始内容(即藏头诗中的头)。
        topk (int): 从预测的topk中选取结果
        """
        head_index = 0
        style_ids = self.tokenizer.encode(style)['input_ids']
        # 去掉结束标记
        style_ids = style_ids[:-1]
        head_is_list = True if isinstance(head, list) else False
        if head_is_list:
            poetry_ids = self.tokenizer.encode(head[head_index])['input_ids']
        else:
            poetry_ids = self.tokenizer.encode(head)['input_ids']
        # 去掉开始和结束标记
        poetry_ids = poetry_ids[1:-1]
        break_flag = False
        while len(style_ids) + len(poetry_ids) <= self.max_length:
            next_word = self._gen_next_word(style_ids + poetry_ids, topk)
            # 对于一些符号,如[UNK], [PAD], [CLS]等,其产生后对诗句无意义,直接跳过
            if next_word in self.tokenizer.convert_tokens_to_ids(['[UNK]', '[PAD]', '[CLS]']):
                continue
            if head_is_list:
                if next_word in self.tokenizer.convert_tokens_to_ids(self.puncs):
                    head_index += 1
                    if head_index < len(head):
                        new_ids = self.tokenizer.encode(head[head_index])['input_ids']
                        new_ids = [next_word] + new_ids[1:-1]
                    else:
                        new_ids = [next_word]
                        break_flag = True
                else:
                    new_ids = [next_word]
            else:
                new_ids = [next_word]
            if next_word == self.tokenizer.convert_tokens_to_ids(['[SEP]'])[0]:
                break
            poetry_ids += new_ids
            if break_flag:
                break
        return ''.join(self.tokenizer.convert_ids_to_tokens(poetry_ids))

    def _gen_next_word(self, known_ids, topk):
        type_token = [0] * len(known_ids)
        mask = [1] * len(known_ids)
        sequence_length = len(known_ids)
        known_ids = paddle.to_tensor([known_ids], dtype='int64')
        type_token = paddle.to_tensor([type_token], dtype='int64')
        mask = paddle.to_tensor([mask], dtype='float32')
        logits = self.model.network.forward(known_ids, type_token, mask, sequence_length)
        # logits中对应最后一个词的输出即为下一个词的概率
        words_prob = logits[0, -1, :].numpy()
        # 依概率倒序排列后,选取前topk个词
        words_to_be_choosen = words_prob.argsort()[::-1][:topk]
        probs_to_be_choosen = words_prob[words_to_be_choosen]
        # 归一化
        probs_to_be_choosen = probs_to_be_choosen / sum(probs_to_be_choosen)
        word_choosen = np.random.choice(words_to_be_choosen, p=probs_to_be_choosen)
        return word_choosen

生成古诗示例

# 载入已经训练好的模型
net = PoetryBertModel('bert-base-chinese', 128)
model = paddle.Model(net)
model.load('./checkpoint/final')
poetry_gen = PoetryGen(model, bert_tokenizer)
def poetry_show(poetry):
    pattern = r"([,。;?])"
    text = re.sub(pattern, r'\1 ', poetry)
    for p in text.split():
        if p:
            print(p)
# 随机生成一首诗
poetry = poetry_gen.generate()
poetry_show(poetry)
一雨一晴天气新,
春风桃李不胜春。
山中老去无多事,
莫道山花不是真。
山色不随人意好,
花枝只与鸟情邻。
何时得见东君面,
共醉花光醉一身。
# 生成特定风格的诗
poetry = poetry_gen.generate(style='会当凌绝顶,一览众山小。')
poetry_show(poetry)
云外有时生,
云间无限好?
月明风细细,
松响竹萧悄。
谁识此时情?
相看情未了。
# 生成特定开头的诗
poetry = poetry_gen.generate(head='好好学习')
poetry_show(poetry)
好好学习子,
不如癡爱官。
一身无定价,
百事有馀安。
# 生成藏头诗
poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])
   谁识此时情?
    相看情未了。



```python
# 生成特定开头的诗
poetry = poetry_gen.generate(head='好好学习')
poetry_show(poetry)
好好学习子,
不如癡爱官。
一身无定价,
百事有馀安。
# 生成藏头诗
poetry = poetry_gen.generate(head=['飞', '桨', '真', '好'])
poetry_show(poetry)
飞鸿过眼疾于风,
桨去帆开水拍空,
真箇老农无一箇。
好将诗卷作渔翁。

运行代码请点击:https://aistudio.baidu.com/aistudio/projectdetail/1689372

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值