import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import collections
import os
import time
import matplotlib.pyplot as plt
# Hyperparameters
start_token = 'G'      # marks the start of a poem
end_token = 'E'        # marks the end of a poem
batch_size = 64        # training batch size
embedding_dim = 128    # embedding vector dimension
hidden_dim = 256       # LSTM hidden state dimension
learning_rate = 0.001  # learning rate
num_epochs = 50        # number of training epochs
# Device configuration (use GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Word embedding layer
class WordEmbedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0, 1.0)  # uniform initialization

    def forward(self, x):
        return self.embedding(x)
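# Illustrative shape check for the embedding layer (hypothetical sizes, not
# part of the pipeline): a (batch, seq_len) tensor of indices should come
# back as (batch, seq_len, embedding_dim).
def _demo_word_embedding():
    emb = WordEmbedding(vocab_size=10, embedding_dim=4)
    out = emb(torch.randint(0, 10, (2, 5)))
    assert out.shape == (2, 5, 4)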
# RNN model: embedding -> multi-layer LSTM -> linear projection to the vocabulary
class RNN_Model(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2, dropout=0.3):
        super(RNN_Model, self).__init__()
        self.embedding = WordEmbedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.init_weights()

    # Weight initialization
    def init_weights(self):
        # Fully connected layer
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)
        # LSTM weights
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
                # Set the forget-gate bias to 1 (helps mitigate vanishing gradients)
                n = param.size(0)
                param.data[n // 4:n // 2].fill_(1.0)
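    # Note on the slice above: PyTorch packs LSTM gate parameters in the
    # order (input, forget, cell, output), so a bias vector of length
    # n = 4 * hidden_size has its forget-gate section at
    # [hidden_size : 2 * hidden_size], which is exactly [n // 4 : n // 2].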
    def forward(self, x, hidden=None):
        # x: (batch_size, seq_len)
        batch_size = x.size(0)
        # Embedding
        embeds = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        # LSTM (initialize the hidden state lazily, on the input's device)
        if hidden is None:
            h0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size, device=x.device)
            c0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size, device=x.device)
            hidden = (h0, c0)
        lstm_out, hidden = self.lstm(embeds, hidden)  # (batch_size, seq_len, hidden_dim)
        # Dropout to reduce overfitting
        lstm_out = self.dropout(lstm_out)
        # Project to vocabulary logits
        output = self.fc(lstm_out)  # (batch_size, seq_len, vocab_size)
        # Rearrange to (batch_size, vocab_size, seq_len) for nn.CrossEntropyLoss
        output = output.permute(0, 2, 1)
        return output, hidden
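# Illustrative end-to-end shape check (hypothetical sizes, not part of the
# pipeline). The permute in forward() means logits come out as
# (batch, vocab_size, seq_len), the layout nn.CrossEntropyLoss expects when
# targets are (batch, seq_len).
def _demo_model_shapes():
    model = RNN_Model(vocab_size=10, embedding_dim=8, hidden_dim=16).to(device)
    x = torch.randint(0, 10, (2, 7), device=device)
    logits, (h, c) = model(x)
    assert logits.shape == (2, 10, 7)  # (batch, vocab_size, seq_len)
    assert h.shape == (2, 2, 16)       # (num_layers, batch, hidden_dim)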
# Data processing
def process_poems(file_path):
    poems = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                if ':' in line:
                    _, content = line.split(':', 1)  # split title from body
                else:
                    content = line
                # Clean up the body
                content = content.replace(' ', '')
                # Skip lines containing markup or the reserved tokens
                if any(char in content for char in ['_', '(', '(', '《', '[', start_token, end_token]):
                    continue
                # Length filter
                if len(content) < 5 or len(content) > 80:
                    continue
                # Wrap with the start and end tokens
                content = start_token + content + end_token
                poems.append(content)
            except Exception as e:
                print(f"Error processing line: {line} - {str(e)}")
                continue
    # Count character frequencies
    all_chars = [char for poem in poems for char in poem]
    counter = collections.Counter(all_chars)
    # Build the vocabulary, most frequent characters first
    sorted_chars = sorted(counter.items(), key=lambda x: -x[1])
    chars = [char for char, _ in sorted_chars]
    char_to_idx = {char: i + 1 for i, char in enumerate(chars)}  # index 0 is reserved for padding
    char_to_idx['<PAD>'] = 0
    idx_to_char = {i: char for char, i in char_to_idx.items()}
    # Encode each poem as a sequence of indices
    poems_idx = []
    for poem in poems:
        poem_idx = [char_to_idx.get(char, 0) for char in poem]  # unseen characters fall back to the padding index
        poems_idx.append(poem_idx)
    return poems_idx, char_to_idx, idx_to_char
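# Illustrative only: what process_poems produces for a tiny made-up file
# (the file name and its one-line contents here are hypothetical).
def _demo_process_poems(tmp_path='./_demo_poems.txt'):
    with open(tmp_path, 'w', encoding='utf-8') as f:
        f.write('静夜思:床前明月光,疑是地上霜。\n')
    poems_idx, char_to_idx, idx_to_char = process_poems(tmp_path)
    # Each poem is wrapped as G...E and then index-encoded; index 0 is
    # reserved for <PAD>, so every real character maps to an index >= 1.
    assert idx_to_char[char_to_idx['G']] == 'G'
    assert all(idx >= 1 for poem in poems_idx for idx in poem)
    os.remove(tmp_path)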
# Batch generation
def create_batches(poems_idx, batch_size, max_seq_len=50):
    # Sort by length so each batch needs less padding
    sorted_poems = sorted(poems_idx, key=len)
    batches = []
    num_batches = len(sorted_poems) // batch_size
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = (i + 1) * batch_size
        batch = sorted_poems[start_idx:end_idx]
        # Longest poem in this batch, capped at max_seq_len
        max_len = min(max(len(poem) for poem in batch), max_seq_len)
        # Pad (or truncate) every poem to max_len
        padded_batch = []
        for poem in batch:
            if len(poem) > max_len:
                padded_poem = poem[:max_len]  # truncate
            else:
                padded_poem = poem + [0] * (max_len - len(poem))  # pad with 0 (<PAD>)
            padded_batch.append(padded_poem)
        # Input and target sequences
        inputs = [poem[:-1] for poem in padded_batch]   # inputs drop the last character
        targets = [poem[1:] for poem in padded_batch]   # targets drop the first character
        batches.append((inputs, targets))
    return batches
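# Illustrative only: how padding and the one-step shift interact on a toy
# batch (hypothetical index sequences).
def _demo_create_batches():
    toy = [[5, 6, 7, 8], [5, 6, 7]]  # two "poems" as index lists
    batches = create_batches(toy, batch_size=2)
    inputs, targets = batches[0]
    # Poems are sorted by length, the shorter one is padded with 0 (<PAD>)
    # to the batch maximum, and targets are inputs shifted left by one step:
    assert inputs == [[5, 6, 7], [5, 6, 7]]
    assert targets == [[6, 7, 0], [6, 7, 8]]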
# Training
def train_model():
    # Prepare the data
    poems_idx, char_to_idx, idx_to_char = process_poems('./poems.txt')
    vocab_size = len(char_to_idx)
    print(f"Vocabulary size: {vocab_size}")
    print(f"Number of poems: {len(poems_idx)}")
    # Build batches
    batches = create_batches(poems_idx, batch_size)
    print(f"Number of batches: {len(batches)}")
    # Initialize the model
    model = RNN_Model(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_layers=2
    ).to(device)
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore the padding index
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the learning rate when the loss plateaus for 3 epochs
    # (the deprecated verbose flag is omitted for newer PyTorch versions)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    # Training bookkeeping
    train_losses = []
    best_loss = float('inf')
    print("Starting training...")
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(batches):
            # Convert to tensors
            inputs_tensor = torch.tensor(inputs, dtype=torch.long).to(device)
            targets_tensor = torch.tensor(targets, dtype=torch.long).to(device)
            # Forward pass
            optimizer.zero_grad()
            output, _ = model(inputs_tensor)
            # Compute the loss
            loss = criterion(output, targets_tensor)
            # Backward pass
            loss.backward()
            # Gradient clipping (guards against exploding gradients)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Parameter update
            optimizer.step()
            # Accumulate the loss
            epoch_loss += loss.item()
            # Progress report
            if batch_idx % 20 == 0:
                avg_loss = epoch_loss / (batch_idx + 1)
                elapsed = time.time() - start_time
                print(f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{len(batches)}], "
                      f"Loss: {avg_loss:.4f}, Time: {elapsed:.2f}s")
        # Average loss for this epoch
        epoch_loss /= len(batches)
        train_losses.append(epoch_loss)
        scheduler.step(epoch_loss)
        # Epoch summary
        print(f"Epoch [{epoch + 1}/{num_epochs}] completed, Avg Loss: {epoch_loss:.4f}")
        # Save the best model so far
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            os.makedirs('./models', exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'char_to_idx': char_to_idx,
                'idx_to_char': idx_to_char
            }, './models/best_poetry_model.pth')
            print(f"Saved best model with loss: {best_loss:.4f}")
    # Plot the loss curve
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./models/training_loss.png')
    plt.close()
    print("Training completed!")
# Poem generation
def generate_poem(model, idx_to_char, char_to_idx, start_word, max_length=50):
    model.eval()
    if start_word not in char_to_idx:
        raise ValueError(f"'{start_word}' is not in the training vocabulary")
    poem = []
    # Seed the input with the start token followed by the chosen first character
    input_seq = torch.tensor([[char_to_idx[start_token], char_to_idx[start_word]]],
                             dtype=torch.long).to(device)
    poem.extend([start_token, start_word])
    # The hidden state starts as None; the model zero-initializes it
    hidden = None
    with torch.no_grad():  # no gradients needed during generation
        for _ in range(max_length):
            # Predict the next character
            output, hidden = model(input_seq, hidden)
            # Logits for the last position; output is (batch, vocab_size, seq_len)
            last_output = output[:, :, -1]  # (batch_size, vocab_size)
            # Multinomial sampling for diversity (effective temperature 1.0)
            probabilities = torch.softmax(last_output, dim=1).squeeze()
            next_idx = torch.multinomial(probabilities, 1).item()
            # Stop at the end token
            if next_idx == char_to_idx[end_token]:
                break
            # Append to the poem
            next_char = idx_to_char.get(next_idx, '<UNK>')
            poem.append(next_char)
            # Feed the sampled character back in as the next input
            input_seq = torch.tensor([[next_idx]], dtype=torch.long).to(device)
    return ''.join(poem)
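# An explicit temperature knob for sampling; this is an assumed extension,
# not part of the original script. generate_poem above samples at an
# effective temperature of 1.0; dividing the logits by a temperature below
# 1.0 sharpens the distribution (more conservative text), above 1.0 flattens it.
def sample_with_temperature(logits, temperature=0.8):
    # logits: a 1-D tensor of unnormalized scores over the vocabulary
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()
# Inside the generation loop, the two sampling lines could become:
#   next_idx = sample_with_temperature(last_output.squeeze(0))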
# Pretty-print a generated poem
def pretty_print_poem(poem):
    # Strip the start/end tokens and split into sentences
    clean_poem = poem.replace(start_token, '').replace(end_token, '')
    sentences = clean_poem.split('。')
    # Print the non-empty sentences
    for s in sentences:
        if s.strip():
            print(s.strip() + '。')
# Entry point
if __name__ == '__main__':
    train_model()
    # Load the best model and generate poems
    print("\nGenerating poems with best model...")
    checkpoint = torch.load('./models/best_poetry_model.pth', map_location=device)
    char_to_idx = checkpoint['char_to_idx']
    idx_to_char = checkpoint['idx_to_char']
    vocab_size = len(char_to_idx)
    model = RNN_Model(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        num_layers=2
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_words = ["日", "红", "山", "夜", "湖", "君"]
    for word in start_words:
        poem = generate_poem(model, idx_to_char, char_to_idx, word)
        print(f"\n--- Poem starting with '{word}' ---")
        pretty_print_poem(poem)
# Algorithm overview: poems.txt is cleaned and each poem wrapped in the
# start/end tokens; a character-level vocabulary is built by frequency
# (index 0 = <PAD>) and poems are encoded as index sequences. Training
# batches pair each sequence with its one-step-shifted copy, so the LSTM
# learns next-character prediction via teacher forcing, with padding
# ignored through ignore_index=0. Generation then seeds the model with 'G'
# plus a chosen first character and samples one character at a time until
# the end token 'E' appears.