Python MTP:Multi-Token Prediction原理与代码实现;LLMs中的多token预测,Deepseek-v3中的MTP的原理详解与代码实践

部署运行你感兴趣的模型镜像

 MTP:Multi-Token Prediction

LLM⼀次只能⽣成⼀个token的⽅式好不好?
当然不好,能不能⼀次⽣成多个token呢?
即使不能,那能不能⼤概知道后⾯要说啥,就像⼈说话的时候,说了前⾯,大概后面要说啥
已经有了⼀个⼤概的规划和印象了。
MTP是⼀个优化⽅向,有很多论⽂,我们就简单直接讲DeepSeek⾥⾯是怎么做的。
训练阶段:提⾼数据利⽤率
推理阶段:提⾼推理速度。

算法
1. 只在训练时使用,计算loss;推理时,仍然只用主干⽹络,抛弃MTP模块,每次预测⼀个token
2. Loss的计算是将多个MTP模块的loss取平均
3. 使⽤上⼀个token辅助预测⼀下,也更符合transformer结构

 

伪代码:

#伪代码
class MTPBlock(nn.Module):

    def __init__(self, emb, output_head):
        # 传⼊共享的embedding层和Output_head 层
        self.emb = emb
        self.output_head = output_head
        # 两个RMSNorm层
        self.hs_norm = RMSNorm()
        self.x_norm = RMSNorm()
        # 拼接后的线性投影层
        self.projection = nn.Linear()
        # 每个MTPblock中的transformer Decoder层
        self.trm = nn.TransformerDecoderLayer(...)

    def forward(self, x, hidden_states):
        # x.shape = [b,s,d_model]
        # hidden_states.shape = [b, s, d_model]
        x = self.emb(x)
        hidden_states = self.hs_norm(hidden_states)
        x = self.x_norm(x)
        # 两部分拼接后投影
        x = torch.cat(hidden_states,x)
        x = self.projection(x)
        # 投影后过transformer decoder layer
        x = self.trm(x)
        hidden_states = x
        y = self.output_head(x)
        # y是这个MTP模块输出,⽤于计算Loss
        # hidden_states是输出的隐层,⽤于传⼊下⼀个MTP模块
        return y, hidden_states


class MoE(nn.Module):# 在主模块之外,加⼊MTP流程
    def __init__(self):
        ...
        self.emb = nn.Embedding()
        self.output_head = nn.Linear()
        self.num_future_tokens = 3 # 预测下⼏个token,这⾥预测下3个token,也就
        有2个MTP模块
        # 声明MTP模块
        self.MTP = nn.ModuleList([MTPBlock(self.emb, self.output_head)
                                    for i in range(self.num_future_tokens-1)])
    def forward(self, x):
        # x.shape = [b,s,d_model]
        ...
        hidden_states = ..... # 主流程得到的隐层
        main_loss = ...
        mtp_loss = 0.0
        next_predictions = ... # 主流程预测的下⼀个label
        if self.training:# 训练阶段才使⽤MTP流程
            for i in range(self.num_future_tokens-1):# 逐个过每个MTP模块
                k = i + 2 # 应该预测token的间隔数量
                target = x[:,k:,:] # 取对应的输出(⻅下图)
                inp = x[:,k-1:-1,:]# 取对应的输⼊(⻅下图)
                future_predictions, hidden_states = self.MTP[i](inp, hidden_states)
                mtp_loss += nn.CrossEntropy(target, future_predictions)
            # 计算总的loss
            total_loss = main_loss + 0.3 * mtp_loss/2
            ...
        # 返回主⼲流程预测结果和总的loss
        return next_predictions, total_loss

 

在训练过程中,模型首先通过主模型生成当前输入token的表示,然后将该表示输入第⼀个 MTP子模块,预测下⼀个token的表示。接着,将预测的表示输⼊下⼀个MTP 子模块,继续预测后续的令牌表示。这种方式使得模型在⼀次前向传播中能够预测多个后续令牌,从而丰富了上下文信息,提⾼了训练效率。
在推理阶段,MTP 模块可以被丢弃,主模型独立进行推理。

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

**投机推理(Speculative Inference)和 Multi-Token Prediction令牌预测)并不完全等同,但存在关联**。两者的核心区别在于设计目标和应用场景,但投机推理的实现可能依赖令牌预测技术。以下是具体分析: ### 1. **Multi-Token Prediction令牌预测)** - **定义**:指模型在一次推理中同时预测个后续令牌(tokens),而非逐个生成。例如,传统自回归模型(如GPT)是逐token生成,而令牌预测可能一次性预测2-3个令牌。 - **技术实现**: - **扩展解码策略**:如Beam Search、Top-k采样等,可生成个候选令牌,但本质仍是逐token扩展。 - **非自回归模型**:如NAT(Non-Autoregressive Translation)直接并行预测所有令牌,但需特殊训练。 - **目的**:减少推理步骤(如从N步降到N/k步),提升吞吐量,但可能牺牲生成质量(因缺乏上下文依赖)。 ### 2. **投机推理(Speculative Inference)** - **定义**:一种优化推理速度的技术,通过**并行执行个候选路径**,提前终止低概率路径,保留高概率路径继续生成。其核心是**投机性执行**,而非单纯预测令牌。 - **技术实现**: - **草稿模型(Draft Model)**:先用一个小模型快速生成个候选令牌(可能涉及令牌预测)。 - **验证模型(Verification Model)**:用大模型验证候选令牌的合理性,保留有效路径。 - **并行执行**:同时处理个候选路径,减少等待时间。 - **目的**:在保持生成质量的前提下,通过并行化减少延迟(尤其适用于长文本生成)。 ### 3. **两者的关系** - **投机推理可能依赖令牌预测**:在草稿模型阶段,为快速生成候选路径,可能采用令牌预测(如一次生成3个候选令牌)。 - **但投机推理≠令牌预测**:投机推理的关键是**并行验证路径选择**,而令牌预测仅关注单次预测的令牌数量。即使不使用令牌预测(如草稿模型逐token生成),投机推理仍可通过并行路径实现加速。 ### 4. **典型应用场景** - **Multi-Token Prediction**:适用于对吞吐量敏感的场景(如批量生成短文本),但可能因上下文缺失导致质量下降。 - **投机推理**:适用于对延迟敏感的场景(如实时对话系统),通过并行化平衡速度质量。 ### 示例代码(投机推理伪代码) ```python def speculative_inference(input, draft_model, verify_model, max_steps): candidates = [input] # 初始候选路径 for _ in range(max_steps): new_candidates = [] for path in candidates: # 草稿模型生成个候选令牌(可能涉及令牌预测) draft_tokens = draft_model.generate(path, num_tokens=3) for token in draft_tokens: new_path = path + [token] # 验证模型检查候选路径的合理性 if verify_model.check(new_path): new_candidates.append(new_path) candidates = new_candidates # 保留有效路径 if not candidates: break return max(candidates, key=verify_model.score) # 返回最优路径 ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

医学小达人

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值