简介
Datesets库是一个简单易用的数据集加载库,可以方便快捷的加载数据集
基本使用
加载数据集
from datasets import *
datasets=load_dataset("madao33/new-title-chinese")
datasets
#按照数据集划分进行加载
datasets=load_dataset("madao33/new-title-chinese",split="train")
datasets
datasets=load_dataset("madao33/new-title-chinese",split="train[10:100]")
datasets=load_dataset("madao33/new-title-chinese",split="train[:50%]")
datasets=load_dataset("madao33/new-title-chinese",split=["train[10:100]","validation[20:100]")
查看数据集
datasets["train"][0]
datasets["train"][:2]
datasets["train"]["title"][0]
datasets["train"].column_names
datasets["train"].features
数据集划分
dataset=datasets["train"]
dataset.train_test_split(test_size=0.1)
dataset=datasets["train"]
dataset.train_test_split(test_size=0.1,stratify_by_column="label")
#数据选取和过滤
datasets["train"]。select([0,1])
filter_dataset=datasets["train"].filter(lambda example:"中国" in example["title"])#结果还是dataset
#数据映射
def add_prefix(example):
example["title"]='prefix:'+example["title"]
return example
prefix_dataset=datasets.map(add_prefix)
prefix_dataset["train"][:10]["title"]
保存与加载
processed_datasets.save_to_disk("./processed_data")
processed_datasets=load_from_disk("./processed_data")
processed_datasets
#加载本地数据集
datasets=load_dataset("csv",data_files="./ChuSentiCorp_htl_all.csv")
dataset=Dataset.from_csv("./ChuSentiCorp_htl_all.csv")
#加载文件夹中所有文件
dataset=load_dataset("csv",data_dir="./all_data/",split="train")
dataset=load_dataset("csv",data_files="./ChuSentiCorp_htl_all.csv",split="train")
微调代码优化
###微调前
import pandas as pd
data=pd.read_csv("/ChrSentCorp_htl_all.csv")
data=data.dropna()
data
#创建datasets
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self) ->None:
super().__init__()
self.data=pd.read_csv("/ChrSentCorp_htl_all.csv")
self.data=self.data.dropna()
def __getitem__(self,index):
return self.data.iloc[index]["review"],self.data.iloc[index]["label"]
def __len__(self):
return len(self.data)
dataset=MyDataset()
###这行代码定义了一个名为 MyDataset 的类,它继承自 torch.utils.data.Dataset 类。定义了一个名为 MyDataset 的类,它继承自 torch.utils.data.Dataset 类。这是 MyDataset 类的构造函数 __init__。在这个函数中,首先调用了父类 Dataset 的构造函数 super().__init__(),以确保正确初始化父类。这是 MyDataset 类的 __getitem__ 方法,用于获取索引为 index 的数据项。在这个方法中,代码使用 iloc 方法从 self.data 中获取指定索引位置的行,并返回该行的 "review" 列和 "label" 列的值。
#划分数据集
from torch.utils.data import random_split
trainset,validset=random_split(dataset,length=[0.9,0.1])
#创建dataloader
tokenizer=AutoTokenizer.from_pretrained("rbt3")
def collate_func(btch):
texts,labels=[],[]
for item in batch:
texts.append(item[0])
labels.append(item[1])
inputs=tokenizer(texts,max_length=128,padding="max_length",truncation=True,return_tensor
s="pt")
inputs["labels"]=torch.tensor(labels)
return inputs
from torch.utils.data import DataLoader
trainloader=DataLoader(trainset,batch_siza=32,shuffle=True,collate_fn=collate_func)
validloader=DataLoader(validset,batch_siza=32,shuffle=Flase,collate_fn=collate_func)
next(enumerate(trainloader)
###自定义的 collate_func 函数,用于处理一个批次的数据。它接收一个批次的数据 batch,遍历批次中的每个样本,将文本和标签分别添加到 texts 和 labels 列表中。然后,使用 tokenizer 对 texts 进行编码,设置最大长度为 128,并进行填充和截断。最后,将编码后的输入和标签以字典形式返回。打乱训练集数据顺序(shuffle=True),并使用之前定义的 collate_func 函数对每个批次的数据进行处理。使用 enumerate 函数对 trainloader 进行迭代,并使用 next 函数获取下一个批次的数据。enumerate 函数会返回一个迭代器,每次迭代返回一个元组,包含批次的索引和对应的数据。next 函数用于获取迭代器的下一个元素。
###微调后
from datasets import load_dataset
dataset=load_dataset("csv",data_files="./ChnSentiCorp_htl_all.csv")
dataset=dataset.filter(lambda x:x["review"] is not None)
dataset
datasets=datasets.train_text_split(text_size=0.1)
tokenizer=AutoTokenizer.from_pretrained("rbt3")
def process_function(example):
tokenized_example=tokenizer(example["review"].max_length=128,truncation=True)
tokenized_example["labels"]=examples["labels"]
return tokenized_example
tokenized_datasets=dataset.map(process_function,batched=True,remove_columns=dataset.column_names)
tokenized_dataset
from transformers import DataCollactorWithPadding
from torch.utils.data import DataLoader
trainset,validset=tokenized_datasets["train"],tokenized_datasets["text"]
trainloader=DataLoader(trainset,batch_siza=32,shuffle=True,collate_fn=DataCollactorWithPadding)
validloader=DataLoader(validset,batch_siza=32,shuffle=Flase,collate_fn=DataCollactorWithPadding)
next(enumerate(trainloader)