Incorrect usage example:
from datasets import load_dataset
from transformers import BertTokenizer
import torch

dataset = load_dataset('lansinuote/ChnSentiCorp', split='train')
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

def tokenize(batch):
    # tokenizes a whole batch of examples (requires batched=True below);
    # note that max_length=1000 already exceeds BERT's 512-position limit
    return tokenizer.batch_encode_plus(batch["text"], padding='max_length', truncation=True,
                                       max_length=1000, return_tensors='pt')

dataset = dataset.map(tokenize, batched=True)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32)
This usage is problematic! Do not tokenize the whole dataset up front and then hand it to the DataLoader: map stores the encodings as plain Python lists in the Arrow table (so return_tensors='pt' has no lasting effect), and every example is padded to the full max_length regardless of what its batch actually contains. Instead, write a collate_fn and let the DataLoader tokenize each batch on the fly:
def collate_fn(data):
    # each item coming from the HF dataset is a dict with 'text' and 'label' keys
    sents = [i['text'] for i in data]
    labels = [i['label'] for i in data]
    # tokenize the whole batch at once; padding/truncation happen per batch
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                       truncation=True,
                                       padding='max_length',
                                       max_length=500,
                                       return_tensors='pt',
                                       return_length=True)
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)
    return input_ids, attention_mask, token_type_ids, labels

# dataset here is the raw, un-tokenized split loaded above
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
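For completeness, a minimal sketch of consuming this loader, assuming a plain BertModel encoder (the original snippet does not show the model side):

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-chinese')
model.eval()

for input_ids, attention_mask, token_type_ids, labels in loader:
    with torch.no_grad():
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
    # outputs.last_hidden_state has shape [batch_size, 500, 768]
    break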
Or, if you do not want to use a DataLoader at all, batch manually:
total_examples = len(dataset)
batch_size = 32

for start in range(0, total_examples, batch_size):
    end = min(start + batch_size, total_examples)
    # slicing an HF dataset returns a dict of columns, so take the 'text' column directly
    batch_texts = dataset[start:end]["text"]
    # bert-base-chinese only has 512 position embeddings, so cap max_length at 512
    batch_inputs = tokenizer(batch_texts, padding='max_length', truncation=True,
                             max_length=512, return_tensors='pt')
    with torch.no_grad():
        # model is e.g. the BertModel loaded above
        batch_outputs = model(**batch_inputs)
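The original snippet does not say what batch_outputs is used for; one common (assumed) use is to keep the [CLS] hidden state of each example as a sentence vector, i.e. the same loop with accumulation added:

all_embeddings = []
for start in range(0, total_examples, batch_size):
    end = min(start + batch_size, total_examples)
    batch_texts = dataset[start:end]["text"]
    batch_inputs = tokenizer(batch_texts, padding='max_length', truncation=True,
                             max_length=512, return_tensors='pt')
    with torch.no_grad():
        batch_outputs = model(**batch_inputs)
    # [CLS] is the first token; its hidden state is often taken as a sentence embedding
    all_embeddings.append(batch_outputs.last_hidden_state[:, 0, :])
all_embeddings = torch.cat(all_embeddings, dim=0)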