NLP Project 9: Translation


Translation

  • hello -> 你好: translation maps a source sentence to a target sentence
  • seq -> seq: it is a sequence-to-sequence (encoder-decoder) task
  • Transformer: implemented here with the pretrained MarianMT model Helsinki-NLP/opus-mt-en-ro

1. Tokenizer

from transformers import AutoTokenizer
# MarianTokenizer has no fast (Rust) implementation, so the printout below shows is_fast=False even with use_fast=True
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro', use_fast=True)
print(tokenizer)
MarianTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})

2. Batch encoding

# a list of [text, text_pair] lists encodes each inner pair as a single sequence, so the two sentences below are joined into one example
tokenizer.batch_encode_plus([['hello, everyone today is a good day', 'It is late, please go home']])
{'input_ids': [[92, 778, 3, 1773, 879, 32, 8, 265, 431, 84, 32, 1450, 3, 709, 100, 540, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
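
To encode the two sentences as two independent examples instead (the form used in the preprocessing step later), pass a flat list of strings; a minimal sketch:

tokenizer.batch_encode_plus(['hello, everyone today is a good day',
                             'It is late, please go home'])
# returns two separate input_ids lists, each ending with the </s> (eos) id 0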

3. Load the dataset

from datasets import load_dataset
dataset = load_dataset(path='wmt16', name='ro-en')
dataset
DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
})

4. Subsample the dataset

# shuffle with a fixed seed and keep a small subset to speed up training
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))
dataset['train'][0]
{'translation': {'en': 'For these reasons I voted in favour of the proposal for a new regulation that aims for greater clarity and transparency in the GSP system.',
  'ro': 'Din aceste motive am votat în favoarea propunerii de nou regulament care își propune o mai mare claritate și transparență în sistemul SPG.'}}

5. Preprocessing: en -> input_ids, ro -> labels

def preprocess_function(data, tokenizer):
    en = [ex['en'] for ex in data['translation']]  # English source sentences
    ro = [ex['ro'] for ex in data['translation']]  # Romanian target sentences
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)
    # tokenize the targets in target-language mode so they become the labels
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(ro, max_length=128, truncation=True)['input_ids']
    return data
dataset = dataset.map(preprocess_function,
           batched=True,
           batch_size=1000,
           num_proc=1,
           remove_columns=['translation'],
           fn_kwargs={'tokenizer': tokenizer})
print(dataset['train'][0])
{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
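
To sanity-check the preprocessing, the ids can be decoded back to text; a minimal sketch:

print(tokenizer.decode(dataset['train'][0]['input_ids']))  # should reproduce the English source
print(tokenizer.decode(dataset['train'][0]['labels']))     # should reproduce the Romanian target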

6. Custom collate_fn: batch the samples and pad input_ids and labels

def collate_fn(data):
    # pad labels to the longest label in the batch with -100 (ignored by the loss)
    max_length = max([len(i['labels']) for i in data])
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads
    # pad input_ids / attention_mask and convert everything to tensors
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt')
    # decoder_input_ids: the labels shifted one position to the right,
    # starting with <pad>, which Marian uses as the decoder start token
    data['decoder_input_ids'] = torch.full_like(data['labels'],
                                                tokenizer.get_vocab()['<pad>'],
                                                dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = tokenizer.get_vocab()['<pad>']
    return data
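
A quick way to inspect what the collator produces is to call it on a couple of samples directly; a minimal sketch (the slicing is just for readability):

import torch
batch = collate_fn([dict(dataset['train'][i]) for i in range(2)])
print(batch['labels'][0][:10])             # target ids, padded with -100
print(batch['decoder_input_ids'][0][:10])  # same ids shifted right by one, starting with <pad>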

7. DataLoader

import torch
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True)
for data in loader:
    break
data
{'input_ids': tensor([[   12,  1107,    30,    37,     4,  2194,   476,    63,   123,    47,
           116,    15, 27384,  1036,     3,    18,    66,     8,  9911,  1591,
           141,     2,     0, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542],
...
        [  172,  2515,  1297,     3,    74,    64,  5023,   133, 23076,    18,
          9000,    11, 17351, 21120,     2,     0, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
...
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[ 1939,    70,    39,  2149,  3042,   701,    19,   224,    27,  6461,
          5968,  9188,    31,    29,   916, 11537,    49,  9803,    71,     2,
             0,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100],
...
        [  127,  5742,   343,     3,    76,    79, 27209, 40989,    46, 24725,
           181,    43, 34119, 32121,     2,     0,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100]]), 'decoder_input_ids': tensor([[34426,  1939,    70,    39,  2149,  3042,   701,    19,   224,    27,
          6461,  5968,  9188,    31,    29,   916, 11537,    49,  9803,    71,
             2,     0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542],
...
        [34426,   127,  5742,   343,     3,    76,    79, 27209, 40989,    46,
         24725,   181,    43, 34119, 32121,     2,     0, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542,
         59542, 59542, 59542]])}
         
for k, v in data.items():
    print(k, v.shape)
input_ids torch.Size([8, 89])
attention_mask torch.Size([8, 89])
labels torch.Size([8, 103])
decoder_input_ids torch.Size([8, 103])

8. Define the downstream model

from transformers import AutoModelForSeq2SeqLM, MarianModel
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # encoder-decoder backbone without the language-model head
        self.pretrained = MarianModel.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
        # register the output bias as a (non-trainable) buffer, mirroring the original model
        self.register_buffer('final_logits_bias', torch.zeros(1, tokenizer.vocab_size))
        # output projection from the 512-dim hidden states to the vocabulary
        self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)
        # initialize the projection with the pretrained lm_head weights
        parameters = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
        self.fc.load_state_dict(parameters.lm_head.state_dict())
        self.criterion = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100, so padded labels are skipped
    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        out = self.pretrained(input_ids=input_ids,
                              attention_mask=attention_mask,
                              decoder_input_ids=decoder_input_ids)
        logits = self.fc(out.last_hidden_state) + self.final_logits_bias
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
        return {'loss': loss, 'logits': logits}
# [b, lens] -> embedding -> encoder/decoder -> [b, lens, 512] -> fc[512, vocab_size] -> [b, lens, vocab_size]
model = Model()
print(sum(i.numel() for i in model.parameters()))
105634816

out = model(**data)
out['loss'], out['logits'].shape
(1.4804006814956665, torch.Size([8, 103, 59543]))
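
For comparison, the stock AutoModelForSeq2SeqLM already bundles this backbone with the lm_head and can decode autoregressively via generate; a minimal sketch using the same tokenizer:

from transformers import AutoModelForSeq2SeqLM
hf_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')
enc = tokenizer(['It is late, please go home'], return_tensors='pt')
gen = hf_model.generate(**enc, max_length=64)
print(tokenizer.batch_decode(gen, skip_special_tokens=True))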

9. Test 1

def test(model):
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True)
    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        with torch.no_grad():
            out = model(**data)            
        # teacher-forced predictions: argmax over the vocabulary at every target position
        pred = tokenizer.batch_decode(out['logits'].argmax(dim=2))
        label = tokenizer.batch_decode(data['decoder_input_ids'])
        predictions.append(pred)
        references.append(label)        
        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])
            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])            
        if i == 10:
            break            
    # wrap each reference in a list, the shape BLEU-style metrics expect; the score itself is not computed here (see the sketch after the sample output below)
    references = [[j] for j in references]
test(model)
0
input_ids= The▁only name that▁was not▁mentioned by▁any of the▁participants in the▁negotiations of recent▁days is that of the▁former▁head of the▁branch,▁Mayor Gheorghe Nichita.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Singurul nume care nu a fost men de niciunul dintre participanții la negocierile din ultimele zile este cel al fostului șef al filialei, primarul Gheorghe Nichita.</s> DEaa al al Nicol Nicol Nicol N N În În În În În În În În În În În În În În În În În În Singur Singur Singur Singur Singur Singur
label= pad Singurul nume care nu a fost menționat de niciunul din participanții la negocierile din ultimele zile este cel al fostului lider al filialei, primarul Gheorghe Nichita.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
...
10
input_ids= ▁That▁would▁point to a▁stock▁market▁drop▁if the Fed▁raises the rate,▁unless▁policymakers▁were to▁soften the▁blow by▁promising that▁another▁increase▁would be a▁ways▁off.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Acest ar indica o scădere a piața, federaled- rata,, cu cazul în care factorii care elaborează politici ar reduce lovitura,mițând că o exista o peste de la o creștere a</s>,, Acest În În În În În În În În În În Acest Acest Acest
label= pad Aceasta ar indica o scădere pe bursă dacă Fed crește rata dobânzii, exceptând cazul în care cei care elaborează politicile ar atenua lovitura promițând că ar trece mult timp până la următoarea creștere.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
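
The predictions and references collected in test() are shaped for a BLEU-style metric but never scored. A minimal sketch of how the score could be computed with the evaluate library, assuming test() is changed to return the two per-batch lists (the flattening below is part of the sketch, not the original code):

import evaluate  # pip install evaluate sacrebleu

metric = evaluate.load('sacrebleu')
flat_preds = [p for batch in predictions for p in batch]   # one string per sentence
flat_refs = [[r] for batch in references for r in batch]   # one list of references per sentence
print(metric.compute(predictions=flat_preds, references=flat_refs)['score'])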

10. Training

from transformers import AdamW  # deprecated in recent transformers releases; torch.optim.AdamW is a drop-in alternative
from transformers.optimization import get_scheduler
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda', index=0)

def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                             num_warmup_steps=0,
                             num_training_steps=len(loader),
                             optimizer=optimizer)
    model.to(device)
    model.train()   
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)
        out = model(**data)
        loss = out['loss']        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()       
        optimizer.zero_grad()
        model.zero_grad()       
        if i % 50 == 0:
            # rough token-level accuracy against decoder_input_ids; padding positions are
            # counted in the denominator, so the absolute value is low and only useful as a trend
            out = out['logits'].argmax(dim=2)
            correct = (data['decoder_input_ids'] == out).sum().item()
            total = data['decoder_input_ids'].shape[1] * 8
            accuracy = correct / total
            predictions = []
            references = []           
            for j in range(8):
                pred = tokenizer.decode(out[j])
                label = tokenizer.decode(data['decoder_input_ids'][j])
                predictions.append(pred)
                references.append(label)                
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)
train()
0 2.2159132957458496 0.0 1.9992e-05
...
2400 0.6920648217201233 0.006696428571428571 7.920000000000001e-07
2450 0.8634450435638428 0.004032258064516129 3.92e-07

11. Save the model

torch.save(model, '../data/翻译.model')

12. Load the model

model2 = torch.load('../data/翻译.model', map_location='cpu')
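
torch.save(model, ...) pickles the whole Model object, so the Model class definition must be importable when the file is loaded. A common alternative is to save only the weights; a minimal sketch (the file name is illustrative):

torch.save(model.state_dict(), '../data/translation_state_dict.pt')  # save parameters only

model2 = Model()                                                      # rebuild the architecture
model2.load_state_dict(torch.load('../data/translation_state_dict.pt', map_location='cpu'))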

13. Test 2

test(model2)
0
input_ids= ▁Last▁month▁saw▁lowest▁growth▁rise▁since▁records▁began in 2000</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Luna ultima lună,-a înregistrat o mai mică creşte de 2000. în prezent.</s> țiululululul -: - De De De De De De De De De De De De De De De De Luna De De De De De De Luna De Luna Luna Luna Luna De Luna Luna De Luna Luna Luna Luna De Luna De Luna Luna Luna Luna Luna Luna Luna Luna
label= pad În ultima lună s-a înregistrat cea mai lentă creștere din 2000 până în prezent.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
...
10
input_ids= Corneliu Vadim Tudor▁was▁born on▁November 28,1949, in▁Bucharest. He▁was a▁writer,▁politician and▁journalist.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
pred= Corneliu Vadim Tudor s-a născutscut in 28 noiembrie 1949, la Bucuresti. a scriitor, politician si jurnalist.</s> </s> al alul al al,  La al A A A I La Cor A Cor A Cor Cor Cor La Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Cor Corneliu Cor Cor Cor Cor Cor Cor Cor Cor
label= pad Corneliu Vadim Tudor s-a nascut în 28 noiembrie 1949, în Bucuresti, era scriitor, politician și jurnalist.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
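
The tests above decode with teacher forcing (the gold decoder_input_ids are fed in). For real inference on a new sentence, the decoder has to be run step by step; a minimal greedy-decoding sketch built on the components of the fine-tuned model (translate is a hypothetical helper, not part of the original code):

def translate(model, text, max_new_tokens=64):
    model.eval()
    enc = tokenizer([text], return_tensors='pt')
    # Marian starts decoding from the pad token
    decoder_input_ids = torch.full((1, 1), tokenizer.pad_token_id, dtype=torch.long)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            out = model.pretrained(input_ids=enc['input_ids'],
                                   attention_mask=enc['attention_mask'],
                                   decoder_input_ids=decoder_input_ids)
            logits = model.fc(out.last_hidden_state) + model.final_logits_bias
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy: most likely next token
        decoder_input_ids = torch.cat([decoder_input_ids, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

print(translate(model2, 'It is late, please go home'))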