dataset.map() 的用法

这段代码:

dataset = dataset.map(formatting_prompts_func, batched=True)

主要是对 dataset 进行批量数据处理(映射转换),使用 datasets 库提供的 .map() 方法对数据进行批量格式化或转换。具体作用和常见用法如下:


代码解析

  1. dataset.map(...)

    • map() 是 Hugging Face datasets 库提供的一个方法,用于对数据集中的每一条数据进行映射转换。
    • 其作用是对数据集的所有样本应用 formatting_prompts_func 这个自定义函数,并返回新的数据集对象。
  2. formatting_prompts_func

    • 这个是用户自定义的函数,用来处理数据集中的样本。例如,它可能会格式化文本、修改字段结构、增加新的数据列等。
  3. batched=True

    • batched=Truemap()批量方式处理数据,而不是逐条处理,通常会加速转换过程。
    • formatting_prompts_func 需要能够接受一个批量数据(字典)作为输入,并返回一个修改后的批量数据

示例 1:格式化文本

假设数据集结构如下:

dataset = load_dataset("imdb", split="train")

print(dataset[0])
# {'text': 'This movie was amazing!', 'label': 1}

我们希望将 text 变成 "Prompt: <original_text>" 的格式:

def formatting_prompts_func(batch):
    batch["text"] = ["Prompt: " + text for text in batch["text"]]
    return batch

dataset = dataset.map(formatting_prompts_func, batched=True)

print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1}

作用

  • map() 会遍历整个数据集,并将 text 字段加上 "Prompt: " 前缀。
  • batched=True 使得 batch["text"] 是一个列表,函数内的列表推导式可以一次性处理多个样本,加快处理速度。

示例 2:增加新字段

我们可以在 map() 过程中为数据集增加新的字段:

def add_new_column(batch):
    batch["text_length"] = [len(text) for text in batch["text"]]
    return batch

dataset = dataset.map(add_new_column, batched=True)

print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1, 'text_length': 31}

作用

  • map() 遍历数据集,为每个样本增加一个 "text_length" 字段,表示文本长度。

示例 3:转换数据格式

假设 dataset 是一个对话数据集:

dataset = load_dataset("daily_dialog", split="train")

我们希望将 "utterances" 处理成 "input""response" 两列:

def reformat_dialog(batch):
    batch["input"] = [conv[0] for conv in batch["dialog"]]
    batch["response"] = [conv[1] if len(conv) > 1 else "" for conv in batch["dialog"]]
    return batch

dataset = dataset.map(reformat_dialog, batched=True)

print(dataset[0])
# {'dialog': [...], 'input': 'Hello!', 'response': 'Hi! How are you?'}

作用

  • 重新组织对话数据结构,使得 input 只包含对话的第一句话,response 包含第二句话。
  • 适用于文本生成任务(如对话系统、LLM 微调数据预处理)。

示例 4:数据增强

假设我们希望对文本数据进行数据增强,比如用 GPT 生成多个改写版本:

def augment_text(batch):
    batch["augmented_text"] = [text + " (modified)" for text in batch["text"]]
    return batch

dataset = dataset.map(augment_text, batched=True)

print(dataset[0])
# {'text': 'Prompt: This movie was amazing!', 'label': 1, 'augmented_text': 'Prompt: This movie was amazing! (modified)'}

作用

  • 这种方法适用于数据增强、数据预处理,特别是在训练 NLP 任务时扩充数据。

常见用法总结

用法代码示例说明
文本格式化dataset.map(formatting_prompts_func, batched=True)重新格式化文本,使其符合 Prompt 任务格式
增加字段dataset.map(add_new_column, batched=True)为数据增加新的特征,如文本长度等
数据重构dataset.map(reformat_dialog, batched=True)调整数据结构,比如把对话拆成 inputresponse
数据增强dataset.map(augment_text, batched=True)生成更多改写文本,用于数据增强
过滤数据dataset = dataset.filter(lambda x: len(x["text"]) > 10)过滤掉短文本,保留更长的样本
数据归一化dataset.map(lambda x: {"text": x["text"].lower()}, batched=True)统一文本大小写,数据清洗

总结

  • dataset.map()datasets 库中的核心数据处理方法,可用于文本格式化、特征工程、数据增强等任务。
  • batched=True 让处理更高效,输入数据是批量列表,输出数据也应该是批量列表
  • 常用于 NLP 任务的数据预处理,比如 Prompt 任务格式化、对话重构、数据扩充等。

💡 总结一句话:

如果你需要对数据集的每条数据做转换、格式化或增强,dataset.map() 是 Hugging Face Datasets 库的最佳选择。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值