逐步蒸馏论文复现


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

1.论文概述

2.论文方法

2.1 提取理由

2.2 结合理由训练小模型

3.实验部分

3.1数据集

3.1.1自然语言推理(Natural Language Inference, NLI)

3.1.2. 常识问答(Commonsense Question Answering, CQA)

3.1.3 数学文字题(Arithmetic Math Word Problems, AMWP)

3.2 实验步骤

4.核心代码


 本文所有资源均可在该地址处获取。

本文对这篇论文进行复现:Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes
目前已发表在2023ACL上

1.论文概述

大规模语言模型(LLMs)由于其内存低效和计算密集,部署起来非常具有挑战性。为此,研究人员通常通过微调(finetuning)或蒸馏(distillation)训练更小的任务特定模型,但这两种方法都需要大量的训练数据。
本文提出了一种新的方法——逐步蒸馏(Distilling Step-by-Step),它通过提取LLM生成的推理过程作为监督信号,训练小模型并显著减少数据需求。该机制的核心是换一种角度,将 LLM 看作是可以推理的 agent,而不是噪声标签的来源。LLM 可以产生自然语言的理由(rationale),这些理由可以用来解释和支持模型所预测的标签。
例如,当被问及“一位先生携带着打高尔夫球的设备,他可能有什么?(a) 球杆,(b) 礼堂,© 冥想中心,(d) 会议,(e) 教堂”,LLM 可以通过思维链(CoT)推理回答出「(a)球杆」,并通过说明「答案一定是用来打高尔夫球的东西」来合理化这个标签。在上述选择中,只有球杆是用来打高尔夫的。研究者使用这些理由作为额外更丰富的信息在多任务训练设置中训练较小的模型,并进行标签预测和理由预测。

本篇工作基于T5-efficient-mini模型复现了该方法,不仅提高了训练速度,还在wandb平台上实现了训练过程的可视化。通过这种优化,展示了如何在实践中加速模型训练。以上内容均为原创。

2.论文方法


逐步蒸馏(Distilling Step-by-Step),其核心思想是利用大规模语言模型(LLMs)推理预测的能力,通过生成带有理由的标签数据来辅助训练更小的下游模型。该方法包含两个主要步骤:

  • 生成合理性解释(Rationales):通过提示(prompting)引导LLMs为无标签数据生成预测标签以及相应的自然语言理由(Rationales)。这些理由解释了为什么给定输入会被映射到某一特定输出。
  • 结合理由进行模型训练:利用生成的理由和预测标签,以多任务学习的方式训练小型模型,使其不仅能预测任务标签,还能学习生成对应的推理过程,从而提升模型的预测能力。

2.1 提取理由

  • 链式推理提示(Chain-of-Thought Prompting):设计包含输入、标签和理由的提示模板,通过少量示例指导LLMs生成新的标签和对应理由。
  • 生成过程:利用提示模板为无标签数据集生成预测标签和理由,形成带有解释的伪标注数据

2.2 结合理由训练小模型

  • 传统方法:直接微调预训练模型或利用LLMs生成的伪标签训练下游模型。
  • 逐步蒸馏方法:采用多任务学习方式,将标签预测和理由生成结合起来,训练小模型同时具备预测能力和推理能力。
    通过在输入中添加任务前缀(如“[label]”和“[rationale]”),指导模型在不同场景下生成标签或理由。

3.实验部分

3.1数据集

论文中使用了4个流行的基准数据集,涵盖3种不同的自然语言处理(NLP)任务,具体数据集和任务如下:

3.1.1自然语言推理(Natural Language Inference, NLI)
  • e-SNLI (Explainable SNLI):基于SNLI(Stanford Natural Language Inference)的扩展版本,增加了每个推理对的解释(rationale)。任务是判断两个句子之间的逻辑关系(蕴含、矛盾、中立)。
  • ANLI (Adversarial Natural Language Inference):一个更具挑战性的自然语言推理数据集,包含三轮对抗样本生成的数据。任务同样是预测句子之间的逻辑关系。
3.1.2. 常识问答(Commonsense Question Answering, CQA)

CQA (Commonsense Question Answering):一个基于常识知识的多项选择问答数据集,要求模型结合外部常识知识来回答问题。

3.1.3 数学文字题(Arithmetic Math Word Problems, AMWP)

SVAMP (Single-Variable Arithmetic Math Problems)专注于单变量算术数学问题,设计更加多样化,意在测试模型在数学文字题上的推理能力。

3.2 实验步骤

  • step1:安装环境依赖
    实验环境搭建
    创建并激活 Conda 环境:
conda create --name distill python=3.10.6 -y
conda activate distill
安装必要的依赖库:

conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
pip install git+https://github.com/huggingface/transformers@v4.24.0 datasets sentencepiece protobuf==3.20.* tensorboardX
pip install sentencepiece
pip install protobuf==3.20 wandb

  • 标准微调(Standard Fine-tuning)
    使用真实标签(GT)对模型进行标准微调:
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type gt --batch_size 64

  • 逐步蒸馏(Distilling Step-by-Step)
    使用真实标签(GT label)和PaLM生成的推理(PaLM rationale):
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefix --label_type gt --llm palm --alpha 0.5 --batch_size 64

  • 标准蒸馏(Standard Distillation)
    使用LLM生成的标签(PaLM label)对模型进行蒸馏:
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type standard --label_type llm --batch_size 64

  • 结合标签与推理的逐步蒸馏
    使用PaLM生成的标签(PaLM label)和推理(PaLM rationale):
python run.py --from_pretrained ./t5-efficient-mini --dataset cqa --model_type task_prefi

3.3实验结果
在wandb可以看到实验结果
 

4.核心代码

class TaskPrefixTrainer(Seq2SeqTrainer):
    def __init__(self, alpha, output_rationale,**kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.output_rationale = output_rationale



    def compute_loss(self, model, inputs, return_outputs=False):
        pred_outputs = model(**inputs['pred'])
        expl_outputs = model(**inputs['expl'])

        loss = self.alpha * pred_outputs.loss + (1. - self.alpha) * expl_outputs.loss

        # For Eval Loss/expl_loss, Eval Loss/pred_loss, Eval Loss/total_loss
        wandb.log({
            "Eval Loss/expl_loss": expl_outputs[0].item(),
            "Eval Loss/pred_loss": pred_outputs[0].item(),
            "Eval Loss/total_loss": loss.item()
        }, step=self.state.global_step)



        return (loss, {'pred': pred_outputs, 'expl': expl_outputs}) if return_outputs else loss

    def __del__(self):
        # 确保在训练结束后关闭SummaryWriter
        self.writer.close()

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        
        pred_outputs = super().prediction_step(model, inputs['pred'], prediction_loss_only=False, ignore_keys=ignore_keys)
        if self.output_rationale:
            expl_outputs = super().prediction_step(model, inputs['expl'], prediction_loss_only=False, ignore_keys=ignore_keys)
        else:
            expl_outputs = pred_outputs # placeholder only


        loss = self.alpha * pred_outputs[0]  + (1 - self.alpha) * expl_outputs[0]

        # 记录损失到 TensorBoard
        wandb.log({
            "Eval Loss/expl_loss": expl_outputs[0].item(),
            "Eval Loss/pred_loss": pred_outputs[0].item(),
            "Eval Loss/total_loss": loss.item()
        }, step=self.state.global_step)

        return (
            loss,
            [pred_outputs[1], expl_outputs[1]],
            [pred_outputs[2], expl_outputs[2]],
        )

​​

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值