import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import re
import os
from collections import Counter
# Global device configuration (printed only once)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
# -------------------------- 1. Enhanced data loading and preprocessing --------------------------
class ForumDataset(Dataset):
def __init__(self, job_path, academy_path, max_len=40, min_freq=2):
self.max_len = max_len
self.min_freq = min_freq
        # Load data and balance the two classes
self.job_titles = self.load_txt(job_path)
self.academy_titles = self.load_txt(academy_path)
        # Trim both classes to the same size
min_count = min(len(self.job_titles), len(self.academy_titles))
self.job_titles = self.job_titles[:min_count]
self.academy_titles = self.academy_titles[:min_count]
        # Build the character vocabulary by frequency (drop rare characters)
all_chars = []
for title in self.job_titles + self.academy_titles:
all_chars.extend(list(title))
char_counts = Counter(all_chars)
self.char2idx = {'<PAD>': 0, '<UNK>': 1}
for char, count in char_counts.items():
if count >= self.min_freq:
self.char2idx[char] = len(self.char2idx)
self.vocab_size = len(self.char2idx)
        # Pre-encode all samples
self.data = []
for title in self.academy_titles:
self.data.append(self.process_title(title, 0))
for title in self.job_titles:
self.data.append(self.process_title(title, 1))
def load_txt(self, path):
with open(path, 'r', encoding='utf8') as f:
lines = [line.strip() for line in f if line.strip()]
        # Keep meaningful punctuation to aid classification
pattern = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9,。!?;:()()]')
return [pattern.sub('', line) for line in lines]
def process_title(self, title, label):
char_ids = [self.char2idx.get(char, self.char2idx['<UNK>'])
for char in title[:self.max_len]]
pad_len = self.max_len - len(char_ids)
if pad_len > 0:
char_ids.extend([self.char2idx['<PAD>']] * pad_len)
return torch.tensor(char_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
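# A minimal sketch of what process_title emits (hypothetical title and index
# values, with max_len=5 for brevity): each character maps through char2idx,
# unseen or rare characters fall back to <UNK> (1), and the tail is padded
# with <PAD> (0):
#   "招聘会" -> (tensor([12, 34, 7, 0, 0]), tensor(1))  # label 1 = job title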
# -------------------------- 2. Enhanced RNN model --------------------------
class EnhancedCharRNN(nn.Module):
def __init__(self, vocab_size, embedding_dim=96, hidden_dim=192, num_classes=2):
super().__init__()
        self.hidden_dim = hidden_dim
self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
self.embedding_dropout = nn.Dropout(0.2)
        # Two stacked bidirectional LSTMs for richer feature extraction
self.lstm1 = nn.LSTM(
embedding_dim, hidden_dim, num_layers=1,
batch_first=True, bidirectional=True
)
self.lstm2 = nn.LSTM(
hidden_dim * 2, hidden_dim, num_layers=1,
batch_first=True, bidirectional=True
)
        # Pooling layers to capture salient features across time steps
self.max_pool = nn.AdaptiveMaxPool1d(1)
self.avg_pool = nn.AdaptiveAvgPool1d(1)
        # Classification head over the fused features
        # (final hidden states + max pool + avg pool = 6 * hidden_dim)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 6, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes)
        )
def forward(self, x):
embed = self.embedding(x)
embed = self.embedding_dropout(embed)
out1, _ = self.lstm1(embed)
out2, (hidden, _) = self.lstm2(out1)
        # Fuse the final hidden states with pooled features over all time steps
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # forward + backward: (B, 2 * hidden_dim)
        out2 = out2.permute(0, 2, 1)  # (B, 2 * hidden_dim, seq_len) for 1d pooling
        max_pool = self.max_pool(out2).squeeze(-1)
        avg_pool = self.avg_pool(out2).squeeze(-1)
        combined = torch.cat([hidden, max_pool, avg_pool], dim=1)  # (B, 6 * hidden_dim)
        return self.fc(combined)
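# Shape sketch for a batch of B titles, assuming the default hyperparameters
# (max_len=40, embedding_dim=96, hidden_dim=192):
#   x: (B, 40) -> embed: (B, 40, 96) -> out1: (B, 40, 384) -> out2: (B, 40, 384)
#   hidden cat: (B, 384); max_pool and avg_pool: (B, 384) each
#   combined: (B, 1152) = (B, 6 * hidden_dim), matching the fc input size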
# -------------------------- 3. Optimized training strategy --------------------------
def train_model(dataset, epochs=10, batch_size=96, lr=3e-4):
    # Split into train/validation sets
train_size = int(0.85 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
dataset, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
    # Data loaders (note: min(os.cpu_count(), 1) always evaluates to 1, so cap at 4 instead)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=min(os.cpu_count() or 1, 4), pin_memory=False
    )
val_loader = DataLoader(
val_dataset, batch_size=batch_size*2, shuffle=False,
num_workers=0
)
model = EnhancedCharRNN(vocab_size=dataset.vocab_size).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing to curb overfitting
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
best_val_acc = 0.0
    print(f"Starting training ({epochs} epochs)...")
for epoch in range(epochs):
        # Training phase
model.train()
train_correct, train_total = 0, 0
for batch_x, batch_y in train_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
optimizer.zero_grad()
logits = model(batch_x)
loss = criterion(logits, batch_y)
loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
optimizer.step()
preds = torch.argmax(logits, dim=1)
train_correct += (preds == batch_y).sum().item()
train_total += batch_x.size(0)
        # Validation phase
model.eval()
val_correct, val_total = 0, 0
with torch.no_grad():
for batch_x, batch_y in val_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
logits = model(batch_x)
preds = torch.argmax(logits, dim=1)
val_correct += (preds == batch_y).sum().item()
val_total += batch_x.size(0)
        # Report metrics
train_acc = train_correct / train_total
val_acc = val_correct / val_total
scheduler.step()
        print(f"Epoch {epoch+1} | train acc: {train_acc:.4f} | val acc: {val_acc:.4f}")
        # Save the best checkpoint
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save({
'model_state_dict': model.state_dict(),
'char2idx': dataset.char2idx,
'max_len': dataset.max_len
}, 'best_model.pth')
    print(f"Training done | best val acc: {best_val_acc:.4f}")
return model
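# A quick smoke run before a full session, using the same files as the main
# block below (epochs/batch_size here are illustrative, not tuned values):
#   ds = ForumDataset('job_titles.txt', 'academy_titles.txt')
#   train_model(ds, epochs=2, batch_size=32)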
# -------------------------- 4. Prediction function --------------------------
def predict_text(text, model_path='best_model.pth'):
checkpoint = torch.load(model_path, map_location=device)
char2idx = checkpoint['char2idx']
max_len = checkpoint['max_len']
model = EnhancedCharRNN(vocab_size=len(char2idx)).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
    # Apply the same preprocessing as training
    pattern = re.compile(r'[^\u4e00-\u9fa5a-zA-Z0-9,。!?;:()()]')
    text = pattern.sub('', text)
    # Encode characters, truncate to max_len, then pad
    char_ids = [char2idx.get(char, char2idx['<UNK>']) for char in text[:max_len]]
    char_ids += [char2idx['<PAD>']] * (max_len - len(char_ids))
input_tensor = torch.tensor(char_ids, dtype=torch.long, device=device).unsqueeze(0)
with torch.no_grad():
logits = model(input_tensor)
pred = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1)[0][pred].item()  # prediction confidence
    return f'Grad-school admissions ({confidence:.2f})' if pred == 0 else f'Job posting ({confidence:.2f})'
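# Example call (assumes best_model.pth was written by a prior training run; the
# sample title is made up). Note that predict_text reloads the checkpoint on
# every call, which is fine for an interactive loop but slow for batch scoring;
# load the model once and reuse it there:
#   print(predict_text('2024年秋季校园招聘公告'))  # e.g. 'Job posting (0.95)'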
# -------------------------- 5. Main --------------------------
if __name__ == '__main__':
    # 1. Initialize the dataset
dataset = ForumDataset(
job_path='job_titles.txt',
academy_path='academy_titles.txt'
)
    print(f'Samples: {len(dataset)} | vocab size: {dataset.vocab_size}')
    # 2. Train the model
train_model(dataset)
    # 3. Interactive prediction
    print('\n===== Classification =====')
while True:
        user_input = input('Enter text ("quit" to exit): ')
        if user_input == 'quit':
            break
        if not user_input.strip():
            print('Please enter non-empty text!')
            continue
        print(f'Class: {predict_text(user_input)}\n')