import os
import logging
import sys
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
BertTokenizerFast,
BertForSequenceClassification,
Trainer,
TrainingArguments,
DataCollatorWithPadding,
get_linear_schedule_with_warmup
)
from torch.utils.data import Dataset
import time
import math
import random
# Set random seeds so results are reproducible
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
set_seed()
# Logging setup with a console handler that writes through the (UTF-8 reconfigured) sys.stdout
class UTF8StreamHandler(logging.StreamHandler):
def emit(self, record):
try:
msg = self.format(record)
stream = self.stream
stream.write(msg + self.terminator)
self.flush()
except Exception:
self.handleError(record)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('news_classifier.log', encoding='utf-8'),
UTF8StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
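# NewsClassifier bundles the full pipeline: data loading, tokenization, BERT fine-tuning and inference.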
class NewsClassifier:
def __init__(self):
self.tokenizer = None
self.model = None
self.label_mapping = {}
self.id2label = {}
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"使用设备: {self.device}")
def load_data(self, filepath):
"""加载并预处理数据"""
try:
start_time = time.time()
logger.info(f"开始加载数据: {filepath}")
with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
lines = f.readlines()
data = []
error_count = 0
for i, line in enumerate(lines, 1):
try:
line = line.strip()
if not line:
continue
if '\t' in line:
parts = line.rsplit('\t', 1)
else:
parts = line.rsplit(' ', 1)
if len(parts) != 2:
error_count += 1
logger.warning(f"第{i}行格式错误,已跳过: {line[:50]}...")
continue
text, label = parts
text, label = text.strip(), label.strip()
if not text or not label:
error_count += 1
logger.warning(f"第{i}行内容为空,已跳过: {line[:50]}...")
continue
data.append({'text': text, 'label': label})
except Exception as e:
error_count += 1
logger.warning(f"第{i}行处理失败: {str(e)} - 内容: {line[:50]}...")
if error_count > 0:
logger.warning(f"共跳过{error_count}条错误数据")
if not data:
raise ValueError("没有有效数据可加载")
df = pd.DataFrame(data)
unique_labels = df['label'].unique()
self.label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
self.id2label = {idx: label for label, idx in self.label_mapping.items()}
            logger.info("类别分布:\n" + df['label'].value_counts().to_string())  # log the distribution with readable label names before mapping to ids
            df['label'] = df['label'].map(self.label_mapping)
            logger.info(f"成功加载 {len(df)} 条有效数据,共 {len(unique_labels)} 个类别")
logger.info(f"数据加载完成,耗时: {time.time() - start_time:.2f}秒")
return df[['text', 'label']]
except Exception as e:
logger.error(f"数据加载失败: {str(e)}")
raise
def preprocess_data(self, data, max_length=128):
"""将数据转换为BERT输入格式"""
try:
start_time = time.time()
logger.info(f"开始预处理数据,共 {len(data)} 条,最大长度: {max_length}")
            # Tokenize in chunks so large datasets do not exhaust memory
batch_size = 10000
all_encodings = None
all_labels = []
for i in range(0, len(data), batch_size):
batch_texts = data['text'].iloc[i:i + batch_size].tolist()
batch_labels = data['label'].iloc[i:i + batch_size].tolist()
encodings = self.tokenizer(
batch_texts,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors="pt"
)
if all_encodings is None:
all_encodings = encodings
else:
for key in all_encodings.keys():
all_encodings[key] = torch.cat([all_encodings[key], encodings[key]])
all_labels.extend(batch_labels)
if (i // batch_size + 1) % 10 == 0:
logger.info(f"已处理 {min(i + batch_size, len(data))}/{len(data)} 条数据")
all_labels = torch.tensor(all_labels)
logger.info(f"数据预处理完成,耗时: {time.time() - start_time:.2f}秒")
return all_encodings, all_labels
except Exception as e:
logger.error(f"预处理失败: {str(e)}")
raise
def load_model(self, model_path='bert-base-chinese'):
"""加载预训练模型,优先使用国内镜像源"""
try:
start_time = time.time()
logger.info(f"正在加载模型: {model_path}")
            # Local model directory (derived from the model name)
local_path = f"./{model_path.replace('/', '-')}"
            # Reuse a previously downloaded local copy if one exists
if os.path.exists(local_path):
logger.info(f"使用本地模型: {local_path}")
self.tokenizer = BertTokenizerFast.from_pretrained(local_path)
self.model = BertForSequenceClassification.from_pretrained(
local_path,
num_labels=len(self.label_mapping),
id2label=self.id2label,
label2id=self.label_mapping
)
self.model.to(self.device)
logger.info(f"本地模型加载成功,耗时: {time.time() - start_time:.2f}秒")
return
            # Otherwise try downloading from a domestic mirror first
logger.info("尝试从国内镜像源下载模型...")
try:
                # Point the Hugging Face Hub at a domestic mirror via the HF_ENDPOINT environment
                # variable (transformers has no set_huggingface_hub_url API). Note: some versions of
                # huggingface_hub only read HF_ENDPOINT at import time, so setting it at the very top
                # of the script is more reliable.
                os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
                # Download and cache the model locally
self._download_and_save_model(model_path, local_path)
logger.info(f"国内镜像源下载成功,保存至: {local_path}")
except Exception as mirror_err:
logger.warning(f"国内镜像源下载失败: {mirror_err},尝试禁用SSL验证下载...")
                # Fall back to disabling SSL verification
self._disable_ssl_verification()
try:
self._download_and_save_model(model_path, local_path)
logger.info(f"禁用SSL验证后下载成功,保存至: {local_path}")
except Exception as ssl_err:
logger.error(f"所有下载方式失败: {ssl_err}")
raise ValueError("无法下载模型,请检查网络或手动下载")
            # Load the freshly downloaded model
self.tokenizer = BertTokenizerFast.from_pretrained(local_path)
self.model = BertForSequenceClassification.from_pretrained(
local_path,
num_labels=len(self.label_mapping),
id2label=self.id2label,
label2id=self.label_mapping
)
self.model.to(self.device)
logger.info(f"模型加载成功,耗时: {time.time() - start_time:.2f}秒")
except Exception as e:
logger.error(f"模型加载失败: {str(e)}")
raise
def _download_and_save_model(self, model_name, save_path):
"""下载模型并保存到本地"""
logger.info(f"下载模型: {model_name},保存至: {save_path}")
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(
model_name,
num_labels=len(self.label_mapping),
id2label=self.id2label,
label2id=self.label_mapping
)
tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)
def _disable_ssl_verification(self):
"""禁用SSL证书验证(仅在必要时使用)"""
logger.info("禁用SSL证书验证...")
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['REQUESTS_CA_BUNDLE'] = ''
def train(self, train_encodings, train_labels, val_encodings, val_labels):
"""训练分类模型"""
try:
start_time = time.time()
logger.info("开始准备训练数据...")
class NewsDataset(Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: val[idx] for key, val in self.encodings.items()}
item['labels'] = self.labels[idx]
return item
def __len__(self):
return len(self.labels)
train_dataset = NewsDataset(train_encodings, train_labels)
val_dataset = NewsDataset(val_encodings, val_labels)
            # Derive training hyperparameters
train_batch_size = 16
eval_batch_size = 64
num_epochs = 3
device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
effective_batch_size = train_batch_size * device_count
warmup_steps = min(500, len(train_dataset) // effective_batch_size)
weight_decay = 0.01
logger.info(f"检测到 {device_count} 个设备,有效批大小: {effective_batch_size}")
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=eval_batch_size,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
logging_dir='./logs',
                logging_steps=max(1, min(10, len(train_dataset) // effective_batch_size // 10)),  # never 0, even for tiny datasets
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
fp16=torch.cuda.is_available(),
gradient_accumulation_steps=max(1, 32 // effective_batch_size),
dataloader_num_workers=min(4, os.cpu_count() or 1),
report_to="none",
save_total_limit=3,
remove_unused_columns=False,
)
            # Optimizer steps must account for gradient accumulation, otherwise the scheduler outlives the actual run
            steps_per_epoch = math.ceil(len(train_dataset) / (effective_batch_size * training_args.gradient_accumulation_steps))
            num_train_steps = steps_per_epoch * num_epochs
logger.info(f"训练总步数: {num_train_steps},热身步数: {warmup_steps}")
def compute_metrics(p):
preds = np.argmax(p.predictions, axis=1)
return {
'accuracy': accuracy_score(p.label_ids, preds),
'f1': f1_score(p.label_ids, preds, average='weighted')
}
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=5e-5,
weight_decay=weight_decay
)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=num_train_steps
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=DataCollatorWithPadding(tokenizer=self.tokenizer),
compute_metrics=compute_metrics,
optimizers=(optimizer, scheduler)
)
logger.info("开始训练模型...")
trainer.train()
logger.info("模型训练完成")
eval_results = trainer.evaluate()
logger.info(
f"验证集评估结果: 准确率={eval_results['eval_accuracy']:.4f}, F1分数={eval_results['eval_f1']:.4f}"
)
logger.info(f"总训练时间: {time.time() - start_time:.2f}秒")
return trainer
except Exception as e:
logger.error(f"训练失败: {str(e)}")
raise
def predict(self, texts, trainer=None):
"""对新文本进行分类预测"""
try:
model = trainer.model if trainer else self.model
model.eval()
batch_size = 64
all_results = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
inputs = self.tokenizer(
batch_texts,
truncation=True,
padding=True,
max_length=128,
return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits.cpu()
preds = torch.argmax(logits, dim=1).numpy()
confs = torch.softmax(logits, dim=1).max(dim=1).values.numpy()
for text, pred, conf in zip(batch_texts, preds, confs):
all_results.append({
'text': text,
'label': self.id2label[pred],
'confidence': float(conf)
})
return all_results
except Exception as e:
logger.error(f"预测失败: {str(e)}")
raise
def main():
try:
start_time = time.time()
logger.info("=== 新闻分类器程序开始运行 ===")
classifier = NewsClassifier()
        # 1. Load the data
        data_path = "./train.txt"
        logger.info(f"正在加载数据文件: {data_path}")
        data = classifier.load_data(data_path)
        # 2. Split into training and validation sets
train_data, val_data = train_test_split(
data,
test_size=0.2,
random_state=42,
stratify=data['label']
)
logger.info(f"数据集划分完成 - 训练集: {len(train_data)}条, 验证集: {len(val_data)}条")
        # 3. Load the model (falls back to a domestic mirror automatically)
classifier.load_model()
        # 4. Preprocess the data
logger.info("正在预处理数据...")
train_encodings, train_labels = classifier.preprocess_data(train_data)
val_encodings, val_labels = classifier.preprocess_data(val_data)
        # 5. Train the model
trainer = classifier.train(train_encodings, train_labels, val_encodings, val_labels)
        # 6. Save the model
save_path = "./saved_model"
trainer.save_model(save_path)
classifier.tokenizer.save_pretrained(save_path)
logger.info(f"模型已保存到: {save_path}")
        # 7. Run sample predictions
test_texts = [
"网民市民集体幻想中奖后如果你中了9000万怎么办",
"PVC期货有望5月挂牌",
"午时三刻新作《幻神录―宿命情缘》",
"欧司朗LLFY网络提供一站式照明解决方案",
"试探北京楼市向何方:排不完的队 涨不够的价"
]
logger.info("\n测试预测结果:")
results = classifier.predict(test_texts, trainer)
for r in results:
print(f"文本: {r['text']}")
print(f"类别: {r['label']}")
print(f"置信度: {r['confidence']:.2%}\n")
total_time = time.time() - start_time
logger.info(f"=== 程序运行完成,总耗时: {total_time:.2f}秒 ({total_time / 60:.2f}分钟) ===")
except Exception as e:
logger.error(f"程序运行出错: {str(e)}")
import traceback
logger.error(traceback.format_exc())
finally:
input("按Enter键退出...")
if __name__ == "__main__":
if sys.version_info[0] == 3 and sys.version_info[1] >= 7:
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
    main()


I. Experiment Objectives
1. Understand how the Transformer architecture is applied in natural language processing, focusing on BERT's pre-training mechanism and fine-tuning method.
2. Master the complete workflow for news text classification based on PyTorch and the Hugging Face Transformers library.
3. Learn data preprocessing, model training, and evaluation techniques for text classification tasks.
4. Compare the performance of the Transformer model with a recurrent neural network (e.g., GRU) on text classification; a minimal GRU baseline sketch follows this list.
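For the comparison in item 4, the following is a minimal GRU baseline sketch. It is an illustration under stated assumptions, not part of the script above: the class name GRUClassifier and all hyperparameters are hypothetical, and it reuses the BERT tokenizer's vocabulary purely for convenience.

import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    """Hypothetical bidirectional-GRU text classifier, used only as a comparison baseline."""
    def __init__(self, vocab_size, num_labels, embed_dim=128, hidden_dim=256, pad_id=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)                 # (batch, seq_len, embed_dim)
        _, hidden = self.gru(embedded)                       # hidden: (2, batch, hidden_dim)
        pooled = torch.cat([hidden[0], hidden[1]], dim=-1)   # concat forward/backward final states
        return self.classifier(pooled)                       # (batch, num_labels)

# Usage sketch, assuming the tokenizer, encodings and labels produced by the script above:
# model = GRUClassifier(vocab_size=classifier.tokenizer.vocab_size,
#                       num_labels=len(classifier.label_mapping),
#                       pad_id=classifier.tokenizer.pad_token_id)
# logits = model(train_encodings["input_ids"][:8])
# loss = nn.CrossEntropyLoss()(logits, train_labels[:8])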