这段代码:
dataset = dataset.map(formatting_prompts_func, batched=True)
主要是对 dataset
进行批量数据处理(映射转换),使用 datasets
库提供的 .map()
方法对数据进行批量格式化或转换。具体作用和常见用法如下:
代码解析
-
dataset.map(...)
:map()
是 Hugging Facedatasets
库提供的一个方法,用于对数据集中的每一条数据进行映射转换。- 其作用是对数据集的所有样本应用
formatting_prompts_func
这个自定义函数,并返回新的数据集对象。
-
formatting_prompts_func
:- 这个是用户自定义的函数,用来处理数据集中的样本。例如,它可能会格式化文本、修改字段结构、增加新的数据列等。
-
batched=True
:batched=True
让map()
以批量方式处理数据,而不是逐条处理,通常会加速转换过程。formatting_prompts_func
需要能够接受一个批量数据(字典)作为输入,并返回一个修改后的批量数据。
示例 1:格式化文本
假设数据集结构如下:
dataset = load_dataset("imdb", split="train")
print(dataset[0])
# {'text': 'This movie was amazing!', 'label': 1}
我们希望将 text
变成 "Prompt: <original_text>"
的格式:
def formatting_prompts_func(batch):
batch["text"] = ["Prompt: " + text for text in batch["text"]]
return batch
dataset = dataset.map(formatting_prompts_func, batched=True)
print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1}
作用
map()
会遍历整个数据集,并将text
字段加上"Prompt: "
前缀。batched=True
使得batch["text"]
是一个列表,函数内的列表推导式可以一次性处理多个样本,加快处理速度。
示例 2:增加新字段
我们可以在 map()
过程中为数据集增加新的字段:
def add_new_column(batch):
batch["text_length"] = [len(text) for text in batch["text"]]
return batch
dataset = dataset.map(add_new_column, batched=True)
print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1, 'text_length': 31}
作用
map()
遍历数据集,为每个样本增加一个"text_length"
字段,表示文本长度。
示例 3:转换数据格式
假设 dataset
是一个对话数据集:
dataset = load_dataset("daily_dialog", split="train")
我们希望将 "utterances"
处理成 "input"
和 "response"
两列:
def reformat_dialog(batch):
batch["input"] = [conv[0] for conv in batch["dialog"]]
batch["response"] = [conv[1] if len(conv) > 1 else "" for conv in batch["dialog"]]
return batch
dataset = dataset.map(reformat_dialog, batched=True)
print(dataset[0])
# {'dialog': [...], 'input': 'Hello!', 'response': 'Hi! How are you?'}
作用
- 重新组织对话数据结构,使得
input
只包含对话的第一句话,response
包含第二句话。 - 适用于文本生成任务(如对话系统、LLM 微调数据预处理)。
示例 4:数据增强
假设我们希望对文本数据进行数据增强,比如用 GPT
生成多个改写版本:
def augment_text(batch):
batch["augmented_text"] = [text + " (modified)" for text in batch["text"]]
return batch
dataset = dataset.map(augment_text, batched=True)
print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1, 'augmented_text': 'Prompt: This movie was amazing! (modified)'}
作用
- 这种方法适用于数据增强、数据预处理,特别是在训练 NLP 任务时扩充数据。
常见用法总结
用法 | 代码示例 | 说明 |
---|---|---|
文本格式化 | dataset.map(formatting_prompts_func, batched=True) | 重新格式化文本,使其符合 Prompt 任务格式 |
增加字段 | dataset.map(add_new_column, batched=True) | 为数据增加新的特征,如文本长度等 |
数据重构 | dataset.map(reformat_dialog, batched=True) | 调整数据结构,比如把对话拆成 input 和 response |
数据增强 | dataset.map(augment_text, batched=True) | 生成更多改写文本,用于数据增强 |
过滤数据 | dataset = dataset.filter(lambda x: len(x["text"]) > 10) | 过滤掉短文本,保留更长的样本 |
数据归一化 | dataset.map(lambda x: {"text": x["text"].lower()}, batched=True) | 统一文本大小写,数据清洗 |
总结
dataset.map()
是datasets
库中的核心数据处理方法,可用于文本格式化、特征工程、数据增强等任务。batched=True
让处理更高效,输入数据是批量列表,输出数据也应该是批量列表。- 常用于 NLP 任务的数据预处理,比如 Prompt 任务格式化、对话重构、数据扩充等。
💡 总结一句话:
如果你需要对数据集的每条数据做转换、格式化或增强,
dataset.map()
是 Hugging Face Datasets 库的最佳选择。