在LoRA微调中,
process_fun
函数的设计通常是为了预处理数据集。这个函数的目的是对数据进行清洗、转换和格式化,以确保每个样本符合微调模型所需的输入格式。
process_fun函数作用
以下是设计process_fun
函数的几个关键作用:
1. 数据标准化
不同的数据集可能有不同的结构或格式,process_fun
可以将它们统一为微调模型期望的格式。例如,它可以确保每个数据样本包含标准的**“instruction”、“input”和“output”**字段,并且它们的内容符合模型训练的要求。
2. 数据清洗和过滤
数据集可能包含不完整或不合格的样本,比如缺失字段、不相关内容或格式错误的条目。process_fun
可以用来清洗数据,例如:
- 去掉空的或无效的样本
- 去掉与微调任务无关的样本
- 移除重复或不符合标准的条目
3. 数据格式转换
在某些情况下,数据可能需要进行格式转换。例如:
- 将文本数据转换为小写或标准化标点符号。
- 解析或重组复杂的输入输出,使其更适合模型处理。
- 将多轮对话的格式转换为单个模型调用可以处理的形式。
4. 添加或补充信息
process_fun
可以在预处理过程中为每个样本添加必要的上下文或元信息。例如:
- 如果原始数据集中缺少
input
,但instruction
已经足够明确,可以将input
设为空字符串。 - 如果需要对某些指令类型进行额外标注(如分类任务中的标签),
process_fun
可以动态生成这些信息。
5. 批量处理
process_fun
通常用于批量预处理数据集,将每个样本转换为标准化的格式,以便后续批量微调训练时能够顺利进行。这样避免了在训练过程中实时处理数据的不确定性,提升了训练效率。
process_func函数示例说明
下列代码对输入和输出数据进行预处理,主要用于训练语言模型时的预处理,确保输入数据能够正确传递给模型,并在模型输出时正确计算损失,从而优化模型的生成能力。
def process_func(example):
MAX_LENGTH = 384 # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
instruction = tokenizer(f"{example['input']}<sep>")
response = tokenizer(f"{example['output']}<eod>")
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = [1] * len(input_ids)
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] # instruction 不计算loss
if len(input_ids) > MAX_LENGTH: # 做一个截断
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
函数关键步骤说明
- MAX_LENGTH 设置
MAX_LENGTH = 384
这行代码设置了最大序列长度为 384。因为 LLaMA 分词器会将一个中文字切分为多个 token(不像英文中每个单词通常会映射为一个 token),所以最大长度需要设置得相对较大,以确保整个输入输出能够被完整处理,避免数据丢失。
- 对输入和输出进行分词
instruction = tokenizer(f"{example['input']}<sep>")
response = tokenizer(f"{example['output']}<eod>")
这两行代码分别对输入和输出进行分词(tokenization),并将分词结果存储在 instruction
和 response
变量中。
example['input']
是输入的数据,这里加了一个特殊的<sep>
标记表示输入的结束。example['output']
是输出的数据,这里加了一个<eod>
标记,表示输出的结束(end of data)。
分词器(tokenizer
)会将输入的文本转换为 token 序列。对于每个分词操作,它返回一个包含 input_ids
(即 token 序列)的字典。
- 构建
input_ids
和attention_mask
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = [1] * len(input_ids)
input_ids
是最终的输入 token 序列,它由输入和输出的 token id 拼接而成。
attention_mask
用于告诉模型哪些 token 是有效的,哪些是填充(padding)的。这里 attention_mask
的长度与 input_ids
相同,所有位置都设为 1,表示所有 token 都是有效的。
- 构建
labels
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
labels
是用于计算损失的标签。它是这样构建的:
- 对于输入部分(
instruction["input_ids"]
),我们设置标签为-100
,表示这些位置的 token 不计算损失。即这些 token 不会对模型的训练产生影响,通常用于避免对输入部分的文本生成任务进行监督。 - 对于输出部分(
response["input_ids"]
),保留原始的 token id,表示这些 token 是需要计算损失的部分,模型会尝试生成这些 token。
- 长度截断
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
如果 input_ids
的长度超过了预设的 MAX_LENGTH
,就会进行截断。截断操作会同时对 input_ids
、attention_mask
和 labels
进行处理,确保它们的长度一致,且不会超过最大限制。
- 返回处理后的数据
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
函数返回一个字典,包含处理后的 input_ids
、attention_mask
和 labels
,这些是模型训练所需的主要输入数据。
总结
input_ids
:由输入和输出的 token 序列拼接而成,是模型的实际输入。attention_mask
:标识哪些 token 是有效的(全为 1)。labels
:用于计算损失。输入部分的标签设为-100
,不计算损失,输出部分则是需要生成的目标 token。