Deepmind发布新方法JEST:训练时间减少13倍,算力需求节省90%
最近Google的人工智能团队发布了全新的数据训练方法——JEST,这种训练方法能够让训练时间减少13倍,让所消耗的算力降低90%,这无疑对AI领域是一个巨大的好消息,具体原因将在下文中具体展示。
传统的模型训练方法
首先来说一下传统的模型训练方法,一下是步骤:
一、数据准备
在训练大语言模型之前,首先需要准备训练数据。训练数据通常是大量的文本数据,这些数据可以从各种来源获取,例如新闻文章、社交媒体帖子、书籍等。数据的质量和多样性对模型的性能有很大影响,因此在选择和处理数据时需要谨慎。
1.1 数据选择
选择数据时,需要考虑数据的多样性和代表性,尽可能选择包含各种主题和风格的数据。此外,数据应该尽可能清洗和去噪,避免包含过多的错误和垃圾信息。
1.2 数据预处理
数据预处理是将原始数据转化为模型可以接受的格式的过程。这通常包括分词、去除停用词、词干提取等步骤。预处理的目的是减少模型需要处理的数据复杂性,使模型能够更好地学习文本的语义。
二、模型选择
模型选择是训练大语言模型的第二个步骤。目前,最常用的大语言模型包括Transformer、BERT、GPT等。这些模型各有优缺点,选择哪种模型取决于你的具体需求和资源。
2.1 Transformer
Transformer是一种基于自注意力机制的模型,它在处理长距离依赖问题上表现出色。然而,由于其全连接的自注意力机制,Transformer的计算复杂度较高。
2.2 BERT
BERT是基于Transformer的一个预训练模型,它通过预测句子中的缺失词来学习语言的语义。BERT在许多NLP任务上都取得了很好的效果,但其训练过程需要大量的计算资源。
2.3 GPT
GPT是另一个基于Transformer的预训练模型,它使用自回归方式学习语言模型。GPT在生成任务上表现优秀,但其只能从左到右进行预测,无法利用右侧的上下文信息。
三、训练过程
训练大语言模型的过程通常包括前向传播、损失计算、反向传播和参数更新四个步骤。这个过程需要在大量数据上反复进行,直到模型的性能达到满意的程度。
3.1 前向传播
前向传播是将输入数据送入模型,通过模型的各层计算得到预测结果的过程。
3.2 损失计算
损失计算是根据模型的预测结果和真实标签计算损失的过程。常用的损失函数包括交叉熵损失、均方误差损失等。
3.3 反向传播
反向传播是根据损失函数的梯度更新模型参数的过程。这是训练模型的关键步骤,它决定了模型学习的速度和效果。
3.4 参数更新
参数更新是将计算得到的梯度应用到模型的参数上,以改进模型的性能。
四、模型优化
模型优化是训练大语言模型的最后一个步骤,它包括模型微调、正则化、学习率调整等方法。
4.1 模型微调
模型微调是在预训练模型的基础上,对模型进行细致的调整,以适应特定任务。
4.2 正则化
正则化是一种防止模型过拟合的技术,它通过在损失函数中添加一个惩罚项来限制模型的复杂度。
4.3 学习率调整
学习率调整是一种改变模型学习速度的方法,它可以帮助模型在训练初期快速收敛,在训练后期避免过度拟合。
传统模型训练的缺点
首先是耗电量巨大、算力要求高,就拿Meta Llama3最大参数的70B模型举例,Meta用了接近100兆瓦的电力,和两个接近2.4万张V100显卡,而且Meta还计划在今年(2024)年底增加60万张H100算力基础设施。目前Llama 3的总碳排放量约为2290吨。对于目前环保的大趋势来说,肯定是非常不好的(找不出合适的词了)。
其次是训练时间长,OpenAI用了13万亿个token训练出了GPT-4,用了25000个A100训练了90到100天,而且利用率在32%到36%之间,故障数量过多也是极低利用率的原因,这会导致需要重新从之前的检查点开始训练。仅训练成本就估计有6300万美元。这还不包括所有的实验、失败的训练和其他成本,比如数据收集、RLHF、人力成本等。
全新的Deepmind的JEST训练方法
JEST是最近Google的人工智能实验室DeepMind推出的全新的模型训练方法,目的是减少模型训练的算力需求和训练时间。
下面就根据PDF来实际说说吧!https://arxiv.org/pdf/2406.17711
首先是技术细节,JEST运用以下技术:
-
联合样本选择算法(JEST):
-
目标:从一个大“超级批次”(super-batch)中选择一个子批次(sub-batch),使其对学习最有用。
-
评分机制:
-
学习者难度(Hard Learner):选择对当前模型(学习者)损失较高的批次。公式为:

其中ℓ(B∣θ)表示批次B在模型参数θ下的损失。
-
易参考模型(Easy Reference):选择对预训练参考模型损失较低的批次。公式为:

其中ℓ(B∣θ∗)表示批次B在参考模型参数θ*下的损失。
-
可学习性(Learnability):结合上述两者,选择对学习者损失高但对参考模型损失低的批次。公式为:

-
-
算法流程:
-
初始从超级批次中随机选择一个子批次。
-
计算当前子批次中每个样本的条件可学习性。
-
迭代地从剩余样本中选择新的样本,直到达到预定的子批次大小。
-
具体算法见PDF中的Algorithm 1:
def jointly_sample_batch(learner_loss, ref_loss, n_chunks=16, filter_ratio=
-
-

最低0.47元/天 解锁文章
628

被折叠的 条评论
为什么被折叠?



