RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
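This assert almost always means an index was out of range for a lookup table. In this script the cause is the -100 used to mask padding in the labels: those ids were fed back into the decoder embedding during training, and nn.Embedding cannot index row -100. A minimal sketch that reproduces the failure on CPU (where it surfaces as a readable IndexError) and shows how to force synchronous kernel launches while debugging:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set before CUDA initializes so the stack trace points at the real call

import torch
import torch.nn as nn

emb = nn.Embedding(30522, 128, padding_idx=0)  # bert-base-uncased vocab size
bad_ids = torch.tensor([[101, 2023, -100]])    # -100 (the loss ignore_index) is not a valid row
try:
    emb(bad_ids)
except IndexError as e:
    print("CPU equivalent of the device-side assert:", e)

The fixed code below therefore keeps two views of each target sequence: raw token ids for the decoder input, and a -100-masked copy used only by the loss.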
!pip install transformers datasets torch rouge-score matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
import time
import numpy as np
from datasets import load_dataset
from rouge_score import rouge_scorer
import matplotlib.pyplot as plt
from IPython.display import clear_output
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Data preprocessing (filter out degenerate samples)
class SummaryDataset(Dataset):
    def __init__(self, dataset_split, tokenizer, max_article_len=384, max_summary_len=96, subset_size=0.01):
        self.tokenizer = tokenizer
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len
        self.subset = dataset_split.select(range(int(len(dataset_split) * subset_size)))
        # Keep only non-trivial article/summary pairs. A per-token vocab check
        # is unnecessary: WordPiece maps out-of-vocabulary text to [UNK], so
        # every id the tokenizer produces is already inside the vocab.
        self.articles = []
        self.summaries = []
        for item in self.subset:
            article = item['article'].strip()
            summary = item['highlights'].strip()
            if len(article) > 20 and len(summary) > 10:
                self.articles.append(article)
                self.summaries.append(summary)
        self.pad_token_id = tokenizer.pad_token_id
        self.unk_token_id = tokenizer.unk_token_id
def __len__(self):
return len(self.articles)
def __getitem__(self, idx):
src = self.tokenizer(
self.articles[idx],
max_length=self.max_article_len,
truncation=True,
padding='max_length',
return_tensors='pt',
add_special_tokens=True
)
tgt = self.tokenizer(
self.summaries[idx],
max_length=self.max_summary_len,
truncation=True,
padding='max_length',
return_tensors='pt',
add_special_tokens=True
)
        # Keep two views of the target: raw ids feed the decoder (an embedding
        # layer cannot index -100, which is exactly what triggered the
        # device-side assert), while the labels mask padding for the loss.
        tgt_ids = tgt['input_ids'].squeeze(0)
        labels = tgt_ids.clone()
        labels[labels == self.pad_token_id] = -100  # ignored by CrossEntropyLoss
        return {
            'input_ids': src['input_ids'].squeeze(0),
            'attention_mask': src['attention_mask'].squeeze(0),
            'decoder_input_ids': tgt_ids,
            'labels': labels
        }
# Basic Seq2Seq model
class BasicEncoder(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.gru = nn.GRU(emb_dim, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
def forward(self, src):
embedded = self.embedding(src)
outputs, hidden = self.gru(embedded)
        # Take the top layer's bidirectional hidden state
        forward_hidden = hidden[-2, :, :]   # top layer, forward direction
        backward_hidden = hidden[-1, :, :]  # top layer, backward direction
hidden = torch.cat([forward_hidden, backward_hidden], dim=1) # (batch, 2*hidden_dim)
hidden = self.fc_hidden(hidden).unsqueeze(0) # (1, batch, hidden_dim)
return hidden
class BasicDecoder(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.gru = nn.GRU(emb_dim + hidden_dim, hidden_dim, num_layers=1, batch_first=True)
self.fc = nn.Linear(hidden_dim * 2 + emb_dim, vocab_size)
def forward(self, input_ids, hidden, context):
input_embedded = self.embedding(input_ids.unsqueeze(1)) # (batch, 1, emb_dim)
input_combined = torch.cat([input_embedded, context.unsqueeze(1)], dim=2) # (batch, 1, emb_dim+hidden_dim)
output, hidden = self.gru(input_combined, hidden) # (batch, 1, hidden_dim)
output = output.squeeze(1) # (batch, hidden_dim)
combined = torch.cat([output, context, input_embedded.squeeze(1)], dim=1) # (batch, 2*hidden_dim+emb_dim)
logits = self.fc(combined)
return logits, hidden
class BasicSeq2Seq(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.encoder = BasicEncoder(vocab_size, emb_dim, hidden_dim)
self.decoder = BasicDecoder(vocab_size, emb_dim, hidden_dim)
self.device = device
self.sos_token_id = 101 # [CLS]
self.eos_token_id = 102 # [SEP]
self.unk_token_id = 100 # [UNK]
    def forward(self, src, tgt):
        # Teacher forcing: the decoder consumes tgt[:, t] and its step-t output
        # is trained to predict the next token (labels[:, 1:] in train_model).
        hidden = self.encoder(src)
        context = hidden.squeeze(0)
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len, self.decoder.fc.out_features, device=device)
        for t in range(tgt_len):
            logits, hidden = self.decoder(tgt[:, t], hidden, context)
            outputs[:, t] = logits
        return outputs
    def generate(self, src, max_length=80):
        src = src.to(device)
        hidden = self.encoder(src)
        context = hidden.squeeze(0)
        # Greedy decoding, starting every sequence from [CLS]
        generated = torch.full((src.size(0), 1), self.sos_token_id, device=device)
        # Token ids to suppress early in a summary; uses the module-level
        # `tokenizer` built in the main program
        punct_ids = torch.tensor(
            tokenizer.convert_tokens_to_ids([',', '.', ';', ':', '!', '?', "'", '"', '`', '~']),
            device=device
        )
        the_id = tokenizer.convert_tokens_to_ids('the')
        for _ in range(max_length - 1):
            logits, hidden = self.decoder(generated[:, -1], hidden, context)
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            # Discourage punctuation in the first few positions; torch.isin
            # handles the whole batch at once
            if generated.size(1) < 5:
                next_token = torch.where(torch.isin(next_token, punct_ids),
                                         torch.full_like(next_token, the_id),
                                         next_token)
            generated = torch.cat([generated, next_token], dim=1)
            if (next_token == self.eos_token_id).all():
                break
        return generated
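# A tiny illustration (a sketch) of the teacher-forcing alignment used by the
# forward passes above and by train_model below: with decoder input
# tgt[:, :-1], the step-t prediction is scored against labels[:, 1:][:, t].
_demo_ids = torch.tensor([[101, 2023, 2003, 102]])  # [CLS] this is [SEP]
print("decoder input:", _demo_ids[:, :-1].tolist(), "-> labels:", _demo_ids[:, 1:].tolist())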
# Attention Seq2Seq model
class Attention(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
        # decoder state (hidden_dim) is concatenated with bidirectional encoder
        # outputs (2*hidden_dim), so the input feature size is 3*hidden_dim
        self.W = nn.Linear(3 * hidden_dim, hidden_dim)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, hidden, encoder_outputs):
src_len = encoder_outputs.size(1)
hidden = hidden.unsqueeze(1).repeat(1, src_len, 1) # (batch, src_len, hidden_dim)
        combined = torch.cat([hidden, encoder_outputs], dim=2)  # (batch, src_len, 3*hidden_dim)
energy = self.v(torch.tanh(self.W(combined))).squeeze(2) # (batch, src_len)
return torch.softmax(energy, dim=1)
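# Minimal shape check for the attention module above (a sketch): with a
# decoder state of size hidden_dim and bidirectional encoder outputs of size
# 2*hidden_dim, the weights come out as (batch, src_len) and each row sums
# to 1 over the source positions.
_attn = Attention(hidden_dim=256)
_w = _attn(torch.randn(2, 256), torch.randn(2, 10, 512))
print(_w.shape, _w.sum(dim=1))  # torch.Size([2, 10]) and values of ~1.0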
class AttnEncoder(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=2, batch_first=True, bidirectional=True, dropout=0.1)
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)  # merge the two directions
self.fc_cell = nn.Linear(hidden_dim * 2, hidden_dim)
def forward(self, src):
embedded = self.embedding(src)
outputs, (hidden, cell) = self.lstm(embedded) # outputs: (batch, src_len, 2*hidden_dim)
        # Take the top layer's bidirectional hidden and cell states
hidden = torch.cat([hidden[-2, :, :], hidden[-1, :, :]], dim=1) # (batch, 2*hidden_dim)
cell = torch.cat([cell[-2, :, :], cell[-1, :, :]], dim=1)
hidden = self.fc_hidden(hidden).unsqueeze(0) # (1, batch, hidden_dim)
cell = self.fc_cell(cell).unsqueeze(0)
return outputs, (hidden, cell)
class AttnDecoder(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
self.attention = Attention(hidden_dim)
self.lstm = nn.LSTM(emb_dim + 2 * hidden_dim, hidden_dim, num_layers=1, batch_first=True)
self.fc = nn.Linear(hidden_dim + emb_dim, vocab_size)
def forward(self, input_ids, hidden, cell, encoder_outputs):
input_embedded = self.embedding(input_ids.unsqueeze(1)) # (batch, 1, emb_dim)
attn_weights = self.attention(hidden.squeeze(0), encoder_outputs) # (batch, src_len)
context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs) # (batch, 1, 2*hidden_dim)
lstm_input = torch.cat([input_embedded, context], dim=2) # (batch, 1, emb_dim+2*hidden_dim)
output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell)) # output: (batch, 1, hidden_dim)
logits = self.fc(torch.cat([output.squeeze(1), input_embedded.squeeze(1)], dim=1)) # (batch, vocab_size)
return logits, hidden, cell
class AttnSeq2Seq(nn.Module):
def __init__(self, vocab_size, emb_dim=128, hidden_dim=256):
super().__init__()
self.encoder = AttnEncoder(vocab_size, emb_dim, hidden_dim)
self.decoder = AttnDecoder(vocab_size, emb_dim, hidden_dim)
self.device = device
self.sos_token_id = 101 # [CLS]
self.eos_token_id = 102 # [SEP]
self.unk_token_id = 100 # [UNK]
    def forward(self, src, tgt):
        # Teacher forcing with the same step alignment as BasicSeq2Seq
        encoder_outputs, (hidden, cell) = self.encoder(src)
        batch_size, tgt_len = tgt.size()
        outputs = torch.zeros(batch_size, tgt_len, self.decoder.fc.out_features, device=device)
        for t in range(tgt_len):
            logits, hidden, cell = self.decoder(tgt[:, t], hidden, cell, encoder_outputs)
            outputs[:, t] = logits
        return outputs
    def generate(self, src, max_length=80):
        src = src.to(device)
        encoder_outputs, (hidden, cell) = self.encoder(src)
        # Greedy decoding, starting every sequence from [CLS]
        generated = torch.full((src.size(0), 1), self.sos_token_id, device=device)
        punct_ids = torch.tensor(
            tokenizer.convert_tokens_to_ids([',', '.', ';', ':', '!', '?', "'", '"', '`', '~']),
            device=device
        )
        the_id = tokenizer.convert_tokens_to_ids('the')
        for _ in range(max_length - 1):
            logits, hidden, cell = self.decoder(generated[:, -1], hidden, cell, encoder_outputs)
            next_token = torch.argmax(logits, dim=1, keepdim=True)
            # Discourage punctuation in the first few positions, batch-wise
            # as in BasicSeq2Seq.generate
            if generated.size(1) < 5:
                next_token = torch.where(torch.isin(next_token, punct_ids),
                                         torch.full_like(next_token, the_id),
                                         next_token)
            generated = torch.cat([generated, next_token], dim=1)
            if (next_token == self.eos_token_id).all():
                break
        return generated
# Transformer model
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
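# Shape check for the positional encoding (a sketch): the registered buffer is
# (1, max_len, d_model) and broadcasts over the batch dimension.
_pe = PositionalEncoding(d_model=128)
print(_pe(torch.zeros(2, 10, 128)).shape)  # torch.Size([2, 10, 128])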
class TransformerModel(nn.Module):
def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=3, dim_feedforward=512, max_len=5000):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
self.pos_encoder = PositionalEncoding(d_model, max_len)
        # Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # Decoder
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout=0.1)
self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
self.fc = nn.Linear(d_model, vocab_size)
self.d_model = d_model
self.sos_token_id = 101 # [CLS]
self.eos_token_id = 102 # [SEP]
def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
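    # For a length-3 target the mask above looks like
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    # so position i can attend to positions <= i only.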
def forward(self, src, tgt):
src_mask = None
tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(device)
src_key_padding_mask = (src == 0)
tgt_key_padding_mask = (tgt == 0)
src = self.embedding(src) * np.sqrt(self.d_model)
src = self.pos_encoder(src)
tgt = self.embedding(tgt) * np.sqrt(self.d_model)
tgt = self.pos_encoder(tgt)
        memory = self.transformer_encoder(
            src.transpose(0, 1),  # layers run sequence-first: (src_len, batch, d_model)
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask
        )
        output = self.transformer_decoder(
            tgt.transpose(0, 1),
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )
output = self.fc(output.transpose(0, 1))
return output
def generate(self, src, max_length=80):
src_mask = None
src_key_padding_mask = (src == 0)
src = self.embedding(src) * np.sqrt(self.d_model)
src = self.pos_encoder(src)
        memory = self.transformer_encoder(src.transpose(0, 1), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
batch_size = src.size(0)
generated = torch.full((batch_size, 1), self.sos_token_id, device=device)
for i in range(max_length-1):
tgt_mask = self._generate_square_subsequent_mask(generated.size(1)).to(device)
tgt_key_padding_mask = (generated == 0)
tgt = self.embedding(generated) * np.sqrt(self.d_model)
tgt = self.pos_encoder(tgt)
            output = self.transformer_decoder(
                tgt.transpose(0, 1),
                memory,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask
            )
output = self.fc(output.transpose(0, 1)[:, -1, :])
next_token = torch.argmax(output, dim=1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if (next_token == self.eos_token_id).all():
break
return generated
# Training function
def train_model(model, train_loader, criterion, epochs=3):
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=1, factor=0.5)
    start_time = time.time()
    for epoch in range(epochs):
        total_loss = 0
        model.train()
        for i, batch in enumerate(train_loader):
            src = batch['input_ids'].to(device)
            decoder_input = batch['decoder_input_ids'].to(device)  # raw ids, safe to embed
            labels = batch['labels'].to(device)                    # -100 on padding, loss only
            optimizer.zero_grad()
            outputs = model(src, decoder_input[:, :-1])
            # Guard against invalid model outputs
            if torch.isnan(outputs).any():
                print("Warning: model output contains NaN, skipping this batch")
                continue
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # gradient clipping
            optimizer.step()
            total_loss += loss.item()
            if (i+1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs} | Batch {i+1}/{len(train_loader)} | Loss: {loss.item():.4f}")
        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)
        print(f"Epoch {epoch+1} | Average loss: {avg_loss:.4f}")
        torch.cuda.empty_cache()
    total_time = time.time() - start_time
    print(f"Training complete! Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
    return model, total_time
# Evaluation function
def evaluate_model(model, val_loader, tokenizer, num_examples=2):
model.eval()
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
valid_count = 0
with torch.no_grad():
for i, batch in enumerate(val_loader):
src = batch['input_ids'].to(device)
tgt = batch['labels'].to(device)
generated = model.generate(src)
for s, p, t in zip(src, generated, tgt):
src_txt = tokenizer.decode(s, skip_special_tokens=True)
pred_txt = tokenizer.decode(p, skip_special_tokens=True)
true_txt = tokenizer.decode(t[t != -100], skip_special_tokens=True)
if len(pred_txt.split()) > 3 and len(true_txt.split()) > 3:
valid_count += 1
if valid_count <= num_examples:
print(f"\n原文: {src_txt[:100]}...")
print(f"生成: {pred_txt}")
print(f"参考: {true_txt[:80]}...")
print("-"*60)
if true_txt and pred_txt:
scores = scorer.score(true_txt, pred_txt)
for key in rouge_scores:
rouge_scores[key].append(scores[key].fmeasure)
if valid_count > 0:
avg_scores = {key: sum(rouge_scores[key])/len(rouge_scores[key]) for key in rouge_scores}
print(f"\n评估结果 (基于{valid_count}个样本):")
print(f"ROUGE-1: {avg_scores['rouge1']*100:.2f}%")
print(f"ROUGE-2: {avg_scores['rouge2']*100:.2f}%")
print(f"ROUGE-L: {avg_scores['rougeL']*100:.2f}%")
else:
print("警告:未生成有效摘要")
avg_scores = {key: 0.0 for key in rouge_scores}
return avg_scores
# Visualize model performance
def visualize_model_performance(model_names, train_times, rouge_scores):
plt.figure(figsize=(15, 6))
    # Training-time comparison
plt.subplot(1, 2, 1)
bars = plt.bar(model_names, train_times)
    plt.title('Model Training Time Comparison')
    plt.ylabel('Time (minutes)')
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2., height, f'{height:.1f} min', ha='center', va='bottom')
    # ROUGE score comparison
plt.subplot(1, 2, 2)
x = np.arange(len(model_names))
width = 0.25
plt.bar(x - width, [scores['rouge1'] for scores in rouge_scores], width, label='ROUGE-1')
plt.bar(x, [scores['rouge2'] for scores in rouge_scores], width, label='ROUGE-2')
plt.bar(x + width, [scores['rougeL'] for scores in rouge_scores], width, label='ROUGE-L')
    plt.title('Model ROUGE Score Comparison')
    plt.ylabel('F1 score')
plt.xticks(x, model_names)
plt.legend()
plt.tight_layout()
plt.savefig('performance_comparison.png')
plt.show()
print("性能对比图已保存为 performance_comparison.png")
# Interactive text summarization
def interactive_summarization(models, tokenizer, model_names, max_length=80):
while True:
print("\n" + "="*60)
print("文本摘要交互式测试 (输入 'q' 退出)")
print("="*60)
        input_text = input("Enter the text to summarize:\n")
if input_text.lower() == 'q':
break
        if len(input_text) < 50:
            print("Please enter a longer text (at least 50 characters)")
            continue
        # Generate summaries
inputs = tokenizer(
input_text,
max_length=384,
truncation=True,
padding='max_length',
return_tensors='pt'
).to(device)
print("\n生成摘要中...")
all_summaries = []
for i, model in enumerate(models):
model.eval()
with torch.no_grad():
generated = model.generate(inputs["input_ids"])
summary = tokenizer.decode(generated[0], skip_special_tokens=True)
all_summaries.append(summary)
            # Print the result
            print(f"\n{model_names[i]} summary:")
print("-"*50)
print(summary)
print("-"*50)
print("\n所有模型摘要对比:")
for i, (name, summary) in enumerate(zip(model_names, all_summaries)):
print(f"{i+1}. {name}: {summary}")
# Main program
print("Loading dataset...")
dataset = load_dataset("cnn_dailymail", "3.0.0")
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
vocab_size = len(tokenizer.vocab)
# Prepare training data
print("Preparing training data...")
train_ds = SummaryDataset(dataset['train'], tokenizer, subset_size=0.01)  # use 1% of the data
val_ds = SummaryDataset(dataset['validation'], tokenizer, subset_size=0.01)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0)
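# Sanity check (a sketch): decoder inputs must contain only valid vocab ids,
# while -100 may appear in the labels alone, where the loss masks it out.
_sample = train_ds[0]
assert int(_sample['decoder_input_ids'].min()) >= 0
assert int(_sample['labels'].max()) < vocab_size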
# Define the loss function
criterion = nn.CrossEntropyLoss(ignore_index=-100)
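# Optional dry run (a sketch): overfit one batch to confirm the input/label
# alignment before full training; the loss should drop quickly.
_batch = next(iter(train_loader))
_src = _batch['input_ids'].to(device)
_dec_in = _batch['decoder_input_ids'].to(device)
_labels = _batch['labels'].to(device)
_probe = BasicSeq2Seq(vocab_size).to(device)
_opt = optim.Adam(_probe.parameters(), lr=1e-3)
for _ in range(20):
    _opt.zero_grad()
    _out = _probe(_src, _dec_in[:, :-1])
    _loss = criterion(_out.reshape(-1, _out.size(-1)), _labels[:, 1:].reshape(-1))
    _loss.backward()
    _opt.step()
print(f"One-batch dry-run loss after 20 steps: {_loss.item():.4f}")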
# Train the basic Seq2Seq
print("\n" + "="*60)
print("Training the basic Seq2Seq model")
print("="*60)
basic_model = BasicSeq2Seq(vocab_size).to(device)
trained_basic, basic_time = train_model(basic_model, train_loader, criterion, epochs=3)
basic_rouge = evaluate_model(trained_basic, val_loader, tokenizer)
# Train the attention Seq2Seq
print("\n" + "="*60)
print("Training the attention Seq2Seq model")
print("="*60)
attn_model = AttnSeq2Seq(vocab_size).to(device)
trained_attn, attn_time = train_model(attn_model, train_loader, criterion, epochs=3)
attn_rouge = evaluate_model(trained_attn, val_loader, tokenizer)
# Train the Transformer
print("\n" + "="*60)
print("Training the Transformer model")
print("="*60)
transformer_model = TransformerModel(vocab_size).to(device)
trained_transformer, transformer_time = train_model(transformer_model, train_loader, criterion, epochs=3)
transformer_rouge = evaluate_model(trained_transformer, val_loader, tokenizer)
# Visualize model performance
print("\n" + "="*60)
print("Model performance comparison")
print("="*60)
model_names = ['Basic Seq2Seq', 'Attention Seq2Seq', 'Transformer']
train_times = [basic_time/60, attn_time/60, transformer_time/60]
rouge_scores = [basic_rouge, attn_rouge, transformer_rouge]
visualize_model_performance(model_names, train_times, rouge_scores)
# Interactive test
print("\n" + "="*60)
print("Interactive text summarization test")
print("="*60)
print("Tip: enter a passage of text and all three models will generate summaries")
interactive_summarization(
[trained_basic, trained_attn, trained_transformer],
tokenizer,
model_names
)
After fixing the errors, send me the complete code.