Mamba的经典文本生成实战

【图书推荐】《深入探索Mamba模型架构与应用》-优快云博客

深入探索Mamba模型架构与应用 - 商品搜索 - 京东

本节将首先完成基于我们自定义的Mamba经典文本生成实战。

8.1.1  数据的准备与错位输入方法

首先是数据的准备,在文本生成模型之前,我们需要了解生成模型独特的文本输入方法。在文本生成模型中,我们不仅需要提供与前期模型完全一致的输入序列,还需要对它进行关键的错位操作。这一创新性步骤的引入,为模型注入了新的活力,使模型在处理序列数据时能够展现出更加灵活和高效的能力。

以输入“你好人工智能!”为例,在生成模型中,这段文字将被细致地表征为每个字符[wy1] [晓王2] 在输入序列中占据的特定位置,如图8-1所示。在这个过程中,我们深入挖掘每个字符或词的语义含义以及它们在整个序列中出现的位置信息。

通过这种分析方式,模型能够从多个维度捕获输入序列中的丰富信息,从而显著提高模型对自然语言的综合理解和处理能力。这样的处理方式使得模型在处理复杂多变的自然语言任务时,表现出更强的灵活性和准确性。

可以看到,在当前情景下,我们构建的数据输入和输出具有相同的长度,然而在位置上却呈现一种错位的输出结构。这种设计旨在迫使模型利用前端出现的文本,预测下一个位置会出现的字(或者词,取决于切分方法),从而训练模型对上下文信息的捕捉和理解能力。最终,在生成完整的句子输出时,会以自定义的结束符号SEP作为标志,标识句子生成的结束。

还是以我们前期准备的情感分类数据集为例,将所有的文本内容经过编码处理后整理成一个完成的token_list,并根据设定的长度随机截取一段经过错位计算后输出,代码如下:

from tqdm import tqdm
import torch

from dataset import tokenizer
tokenizer_emo = tokenizer.Tokenizer()

token_list = []
with open("./dataset/ChnSentiCorp.txt", mode="r", encoding="UTF-8") as emotion_file:
    for line in tqdm(emotion_file.readlines()):
        line = line.strip().split(",")

        text = "".join(line[1:]) + '※'
        if True:
            token = tokenizer_emo.encode(text)
            for id in token:
                token_list.append(id)
token_list = torch.tensor(token_list * 2)

class TextSamplerDataset(torch.utils.data.Dataset):
    def _ _init_ _ (self, data = token_list, seq_len = 48):
        super()._ _init_ _()
        self.data = data
        self.seq_len = seq_len

    def _ _getitem_ _(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
        return full_seq[:-1],full_seq[1:]

    def _ _len_ _(self):
        return self.data.size(0) // self.seq_len

可以看到,这段代码首先通过tqdm库(用于显示进度条)和自定义的tokenizer来从一个文本文件中读取并处理情感数据,将文本转换为一系列的Token ID,并存储在一个列表中。

随后,这个列表被转换成PyTorch张量并重复一遍。然后定义了一个名为TextSamplerDataset的PyTorch数据集类,该类在初始化时接收处理过的Token数据和序列长度作为参数。

这个数据集类能够随机生成固定长度的序列对(前一个序列和后一个序列),用于后续训练文本生成模型,注意我们会使用前一个序列作为输入,后一个序列作为预测目标。最后,数据集的长度根据数据总量和序列长度来计算。

特别需要注意,对于输出结果来说,当使用经过训练的Mamba模型进行下一个真实文本预测时,相对于我们之前学习的编码器文本输出格式,输出的内容可能并没有相互关联,如图8-3所示。

可以看到,这段模型输出的前端部分和输入文本部分并无直接关联(橙色部分),仅对输出的下一个字符进行预测和展示。

这样,当我们预测一整段文字时,需要采用不同的策略。例如,可以通过滚动循环的方式,从起始符开始,不断将已预测的内容与下一个字符的预测结果进行黏合,逐步生成并展示整段文字。这样的处理方式可以确保模型在生成长文本时保持连贯性和一致性,从而得到更加准确和自然的预测结果。

8.1.2  基于经典Mamba的文本生成模型

根据前面的分析,我们的目标是构建一个基于经典Mamba框架的文本生成模型,即语言生成模型(Generator Language Model,GLM)。这一模型将能够根据输入来智能地生成连贯的文本内容。

在这里,我们将复用前文介绍的Mamba模块,利用其强大的功能和灵活性来简化我们的开发工作。为了实现这一目标,我们主要完成两个核心任务:模型的主体构建和输出函数的定义。

对于模型的主体,我们将利用Mamba提供的丰富组件和接口,搭建起一个高效且稳定的文本生成模型。这包括选择合适的网络结构、配置适当的参数以及优化训练策略等。通过这些步骤,我们可以确保模型具备强大的文本生成能力,并能够根据输入生成高质量的内容。

而输出函数的定义则是将模型的生成结果转换为人类可读的文本格式。我们将设计一个巧妙的输出函数,它不仅能够准确地提取模型生成的文本序列,还能够根据需要进行后处理和格式化,以确保输出的文本既符合语法规范又具备可读性。

完整的GLM模型如下:

import copy
import torch
import einops.layers.torch as elt
from einops import rearrange, repeat, reduce, pack, unpack

import all_config
model_cfg = all_config.ModelConfig

import moudle
import utils
class GLMSimple(torch.nn.Module):
    def _ _init_ _(self,dim = model_cfg.dim,num_tokens = model_cfg.num_tokens,device = all_config.device):
        super()._ _init_ _()
        self.num_tokens = num_tokens
        self.device = device

        self.token_emb = torch.nn.Embedding(num_tokens,dim)
        self.layers = torch.nn.ModuleList([])

        for _ in range(model_cfg.depth):
            block = moudle.MambaBlock(d_model=dim,device=device)
            self.layers.append(block)

        self.norm = torch.nn.LayerNorm(dim)
        self.to_logits = torch.nn.Linear(dim, num_tokens, bias=False)

    def forward(self,x):

        x = self.token_emb(x)
        for layer in self.layers:
            x = x + layer(x)

        #这个返回的Embedding好像没什么用
        embeds = self.norm(x)
        logits = self.to_logits(embeds)
        return logits, embeds

    @torch.no_grad()
    def generate(
            self, seq_len, prompt=None, temperature=1.,
            eos_token=2, return_seq_without_prompt=True
    ):
        """
        根据给定的提示(prompt)生成一段指定长度的序列

        参数:
        - seq_len:生成序列的总长度
        - prompt:序列生成的起始提示,可以是一个列表
        - temperature:控制生成序列的随机性。温度值越高,生成的序列越随机;温度值越低,生成的序列越确定
        - eos_token: 序列结束标记的Token ID,默认为2
        - return_seq_without_prompt:是否在返回的序列中不包含初始的提示部分,默认为True

        返回:
        - 生成的序列(包含或不包含初始提示部分,这取决于return_seq_without_prompt参数的设置)
        """

        # 将输入的prompt转换为torch张量,并确保它在正确的设备上(如GPU或CPU)
        prompt = torch.tensor(prompt).to(self.device)

        # 对prompt进行打包处理,以便能够正确地传递给模型
        prompt, leading_dims = pack([prompt], '* n')

        # 初始化一些变量
        n, out = prompt.shape[-1], prompt.clone()

        # 根据需要的序列长度和当前prompt的长度,计算出还需要生成多少个Token
        sample_num_times = max(1, seq_len - prompt.shape[-1])

        # 循环生成剩余的Token
        for _ in range(sample_num_times):
            # 通过模型的前向传播获取下一个可能的Token及其嵌入表示
            logits, embeds = self.forward(out)
            logits, embeds = logits[:, -1], embeds[:, -1]

            # 使用Gumbel分布对logits进行采样,以获取下一个Token
            sample = utils.gumbel_sample(logits, temperature=temperature, dim=-1)

            # 将新生成的Token添加到当前序列的末尾
            out, _ = pack([out, sample], 'b *')

            # 如果设置了结束标记,并且序列中出现了该标记,则停止生成
            if utils.exists(eos_token):
                is_eos_tokens = (out == eos_token)
                if is_eos_tokens.any(dim=-1).all():
                    break

        # 对生成的序列进行解包处理
        out, = unpack(out, leading_dims, '* n')

        # 根据return_seq_without_prompt参数的设置,决定是否返回包含初始提示的完整序列
        if not return_seq_without_prompt:
            return out
        else:
            return out[..., n:]


if _ _name_ _ == '_ _main_ _':

    token = torch.randint(0,1024,(2,48)).to("cuda")
    model = GLMSimple().to("cuda")
    result = model.generate(seq_len=20,prompt=token)
print(result)

在模型主体部分,我们采用的是与拼音汉字转换模型相同的主体结构,这也是经典的生成模型架构,其目标是根据输入的前一个(一般是多个)Token输出下一个Token,也就是next token预测。

8.1.3  基于Mamba的文本生成模型的训练与推断

下面我们将完成文本生成模型的训练。简单来说,对于文本生成模型,在提供错位数据输入的基础上,可以直接对输出和标签进行交叉熵计算并计算Loss,代码如下:

import os

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import glm_model

BATCH_SIZE = 512

import all_config
device = all_config.device

model = glm_model.GLMSimple(num_tokens=3700,dim=384)
model.to(device)

import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.TextSamplerDataset(get_data_emotion.token_list,seq_len=48)
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))

save_path = "./saver/glm_text_generator.pth"
#model.load_state_dict(torch.load(save_path),strict=False)

optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 1200,eta_min=2e-7,last_epoch=-1)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(64):
    pbar = tqdm(train_loader,total=len(train_loader))
    for token_inp,token_tgt in pbar:
        token_inp = token_inp.to(device)
        token_tgt = token_tgt.to(device)
        logits,_ = model(token_inp)
        loss = criterion(logits.view(-1, logits.size(-1)), token_tgt.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # 执行优化器
        pbar.set_description(f"epoch:{epoch +1}, train_loss:{loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")

    torch.save(model.state_dict(), save_path)

训练过程较为复杂,根据设置的文本输入长度与读者本身的硬件资源不同,对于不同的读者会有不同的训练时间,这一点请读者自行斟酌。下面主要讲解模型的预测输出,代码如下:

import torch
from torch.utils.data import DataLoader
import glm_model
from dataset import tokenizer
tokenizer_emo = tokenizer.Tokenizer()

import all_config
device = all_config.device

model = glm_model.GLMSimple(num_tokens=3700,dim=512)
model.to(device)
model.eval()
save_path = "./saver/glm_text_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)

for _ in range(10):
    text = "酒店"
    prompt_token = tokenizer_emo.encode(text)
    prompt_token = torch.tensor(prompt_token).to(device)
    result_token = model.generate(seq_len=32, prompt=prompt_token)
    _text = tokenizer_emo.decode(result_token).split("※")[0]
    print(text + _text)

这里是GLM模型的生成部分,首先根据需要的序列长度和当前prompt的长度,计算出还需要生成多少个Token。之后循环生成多个下一个Token,并根据参数的设置,决定是否返回包含初始提示的完整序列。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值