import sys
import os
import re
import jieba
import time
from PyQt5.QtWidgets import (
QApplication, QWidget, QLineEdit, QPushButton, QTextEdit,
QVBoxLayout, QHBoxLayout, QFileDialog, QFormLayout, QMessageBox,
QProgressBar, QCheckBox, QDialog, QListWidget, QLabel, QVBoxLayout
)
from PyQt5.QtCore import QThread, pyqtSignal
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import numpy as np
class DocumentSelectorDialog(QDialog):
    """Dialog that lists document labels and lets the user pick some of them.

    The list is in multi-selection mode. The dialog is accepted through its
    single confirm button; the caller then reads the picked labels with
    ``get_selected``.
    """

    def __init__(self, doc_labels, parent=None):
        """Build the dialog.

        Args:
            doc_labels: Labels shown as list entries, in the given order.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.selected_items = []
        self.setWindowTitle("选择两个文档进行相似度评估")

        # Selectable list holding one entry per label.
        self.list_widget = QListWidget()
        self.list_widget.addItems(doc_labels)
        self.list_widget.setSelectionMode(QListWidget.MultiSelection)

        # The confirm button closes the dialog with the "accepted" result.
        ok_button = QPushButton("确定")
        ok_button.clicked.connect(self.accept)

        # Stack prompt, list and button vertically.
        column = QVBoxLayout(self)
        prompt = QLabel("请选择两个训练文档标签进行比较:")
        for widget in (prompt, self.list_widget, ok_button):
            column.addWidget(widget)

    def get_selected(self):
        """Return the text of every currently selected list entry."""
        return [entry.text() for entry in self.list_widget.selectedItems()]
class LoadingMessageBox(QMessageBox):
    """Button-less information box that shows itself as soon as it is created."""

    def __init__(self, parent=None):
        """Configure the box and bring it to the front.

        Args:
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setIcon(QMessageBox.Information)
        self.setText("正在加载,请稍候...")
        # No standard buttons: the box is closed by the code that created it.
        self.setStandardButtons(QMessageBox.NoButton)
        # Show first, then activate and raise the window.
        for bring_forward in (self.show, self.activateWindow, self.raise_):
            bring_forward()
class LoadingModelThread(QThread):
    """Worker thread that loads a Doc2Vec model from disk.

    Signals:
        loaded_signal(object): Emitted once per run with a 2-tuple.
            On success it is ``(model, seconds_taken)``; on failure it is
            ``(exception, 0)``.
    """

    loaded_signal = pyqtSignal(object)

    def __init__(self, model_path, parent=None):
        """Remember which model file to load.

        Args:
            model_path: Path of the saved model file.
            parent: Optional parent QObject.
        """
        super().__init__(parent)
        self.model_path = model_path

    def run(self):
        """Load the model, time the load and emit the outcome."""
        started_at = time.time()
        try:
            loaded_model = Doc2Vec.load(self.model_path)
        except Exception as exc:
            # Failures are reported through the same signal as successes.
            outcome = (exc, 0)
        else:
            outcome = (loaded_model, time.time() - started_at)
        self.loaded_signal.emit(outcome)
class TrainingThread(QThread):
    """Worker thread that reads corpus files, trains a Doc2Vec model and saves it.

    Each line (``.txt``) or paragraph (``.docx`` / ``.doc``) that still
    contains words after preprocessing becomes one ``TaggedDocument`` whose
    tag is its running index rendered as a string.

    Signals:
        log_signal(str): Human-readable status and error messages.
        progress_signal(int): Training progress as a percentage (0-100).
        finished_signal(): Emitted exactly once when ``run`` ends, whether
            training succeeded or failed.

    After ``finished_signal`` the outcome can be read from two attributes:
    ``saved_model_path`` (where the model was written, or ``None``) and
    ``error`` (failure description, or ``None`` on success).
    """

    log_signal = pyqtSignal(str)
    finished_signal = pyqtSignal()
    progress_signal = pyqtSignal(int)

    # File used when the model is not written back over the resumed model.
    DEFAULT_MODEL_PATH = "doc2vec_model.model"

    # Everything outside this CJK range is replaced by a space before
    # segmentation.
    _NON_CHINESE = re.compile(r'[^\u4e00-\u9fa5]')

    # (suffix, reader method, label used in read-error messages,
    #  message shown when the reader's optional dependency is missing)
    _READERS = (
        ('.txt', '_iter_txt_lines', "TXT", None),
        ('.docx', '_iter_docx_lines', "Word (.docx)",
         "错误:未安装 python-docx,请使用 pip install python-docx"),
        ('.doc', '_read_doc_lines', "Word (.doc)",
         "错误:未安装 pywin32,请使用 pip install pywin32"),
    )

    def __init__(self, file_paths, vector_size, window, min_count, epochs,
                 resume_model_path=None, overwrite_model=False, parent=None,
                 save_path=None):
        """Store the training configuration.

        Args:
            file_paths: Corpus files to read (``.txt``, ``.docx``, ``.doc``).
            vector_size: Dimensionality of the document vectors.
            window: Context window size.
            min_count: Minimum word frequency kept in the vocabulary.
            epochs: Number of training passes over the corpus.
            resume_model_path: Existing model to continue training, if any.
            overwrite_model: When true and ``resume_model_path`` is set, the
                trained model is saved back to ``resume_model_path``.
            parent: Optional parent QObject.
            save_path: Destination used when the model is not written over
                the resumed one. Defaults to ``DEFAULT_MODEL_PATH``.
        """
        super().__init__(parent)
        self.file_paths = file_paths
        self.vector_size = vector_size
        self.window = window
        self.min_count = min_count
        self.epochs = epochs
        self.resume_model_path = resume_model_path
        self.overwrite_model = overwrite_model
        self.save_path = save_path or self.DEFAULT_MODEL_PATH
        self.saved_model_path = None
        self.error = None

    def preprocess_line(self, line):
        """Replace non-Chinese characters by spaces and segment with jieba.

        Args:
            line: Raw text of one line or paragraph.

        Returns:
            The list of non-blank tokens; empty when nothing usable remains.
        """
        cleaned = self._NON_CHINESE.sub(' ', line)
        return [word for word in jieba.lcut(cleaned.strip()) if word.strip()]

    @staticmethod
    def _iter_txt_lines(file_path):
        """Yield the raw lines of a UTF-8 text file."""
        with open(file_path, 'r', encoding='utf-8') as handle:
            yield from handle

    @staticmethod
    def _iter_docx_lines(file_path):
        """Yield the stripped paragraph texts of a ``.docx`` file."""
        from docx import Document
        for paragraph in Document(file_path).paragraphs:
            yield paragraph.text.strip()

    @staticmethod
    def _read_doc_lines(file_path):
        """Return the stripped paragraphs of a legacy ``.doc`` file via Word."""
        import pythoncom
        import win32com.client

        # COM has to be initialised in every thread that uses it; this
        # method is called from the worker thread, not the main thread.
        pythoncom.CoInitialize()
        word = None
        try:
            word = win32com.client.Dispatch("Word.Application")
            word.Visible = 0
            document = word.Documents.Open(os.path.abspath(file_path))
            try:
                text = document.Content.Text
            finally:
                document.Close()
        finally:
            # Always shut Word down so a failed read does not leave the
            # application object behind.
            if word is not None:
                word.Quit()
            pythoncom.CoUninitialize()
        return [line.strip() for line in text.split('\r')]

    def sentence_generator(self):
        """Yield one token list per line/paragraph of every input file.

        Unsupported, unreadable or dependency-less files are reported through
        ``log_signal`` and skipped; the remaining files are still processed.
        Token lists may be empty; the caller filters those out.
        """
        for file_path in self.file_paths:
            lowered = file_path.lower()  # accept ".TXT", ".Docx", ...
            for suffix, reader_name, label, missing_dep_msg in self._READERS:
                if lowered.endswith(suffix):
                    break
            else:
                self.log_signal.emit(f"不支持的文件格式: {file_path}")
                continue

            try:
                for raw_line in getattr(self, reader_name)(file_path):
                    yield self.preprocess_line(raw_line)
            except Exception as exc:
                if isinstance(exc, ImportError) and missing_dep_msg:
                    self.log_signal.emit(missing_dep_msg)
                else:
                    self.log_signal.emit(f"读取 {label} 文件时出错: {exc}")

    def _train_and_save(self):
        """Build the corpus, train the model and write it to disk.

        Raises:
            ValueError: If no usable document could be read.
            Exception: Anything raised while loading, training or saving.
        """
        self.log_signal.emit("开始逐行加载语料...")
        tagged_data = [
            TaggedDocument(words, [str(index)])
            for index, words in enumerate(self.sentence_generator())
            if words
        ]
        if not tagged_data:
            raise ValueError("语料为空,没有可用于训练的文本")

        if self.resume_model_path and os.path.exists(self.resume_model_path):
            self.log_signal.emit(f"正在加载模型进行增量训练:{self.resume_model_path}")
            model = Doc2Vec.load(self.resume_model_path)
            # NOTE(review): tags restart at "0" here, so new documents share
            # tags with those of the earlier run; confirm this is intended.
            model.build_vocab(tagged_data, update=True)
        else:
            self.log_signal.emit("未检测到已有模型,开始从头训练...")
            model = Doc2Vec(
                vector_size=self.vector_size,
                window=self.window,
                min_count=self.min_count,
                workers=4,
                epochs=self.epochs,
                dm=0,
            )
            model.build_vocab(tagged_data)

        self.log_signal.emit("开始训练模型...")
        total_epochs = self.epochs
        if total_epochs > 0:
            from gensim.models.callbacks import CallbackAny2Vec

            emit_progress = self.progress_signal.emit

            class _EpochProgress(CallbackAny2Vec):
                """Report the percentage of finished epochs."""

                def __init__(self):
                    self.finished_epochs = 0

                def on_epoch_end(self, model):
                    self.finished_epochs += 1
                    emit_progress(int(self.finished_epochs / total_epochs * 100))

            # One train() call covers all epochs so the learning rate decays
            # once over the whole run; calling train(epochs=1) in a loop
            # restarts that decay on every pass.
            # NOTE(review): gensim 4's train() appears to store the epochs
            # argument on the model, which the old loop would leave at 1;
            # confirm against the installed gensim version.
            model.train(
                tagged_data,
                total_examples=model.corpus_count,
                epochs=total_epochs,
                callbacks=[_EpochProgress()],
            )
        # Make sure the progress bar ends at 100%.
        self.progress_signal.emit(100)

        # No file dialog here: Qt widgets must not be used from a worker
        # thread, so the destination is decided before the thread starts.
        if self.overwrite_model and self.resume_model_path:
            target_path = self.resume_model_path
            model.save(target_path)
            self.log_signal.emit(f"模型已更新并保存到原路径:{target_path}")
        else:
            target_path = self.save_path
            model.save(target_path)
            self.log_signal.emit(f"模型训练完成并已保存为 {target_path}")
        self.saved_model_path = target_path

    def run(self):
        """Thread entry point: train, record the outcome, always signal the end."""
        self.saved_model_path = None
        self.error = None
        try:
            self._train_and_save()
        except Exception as exc:
            # Report instead of letting the thread die silently, which would
            # leave listeners waiting for finished_signal forever.
            self.error = str(exc)
            self.log_signal.emit(f"训练失败: {exc}")
        finally:
            self.finished_signal.emit()
class Doc2VecTrainerGUI(QWidget):
    """Main window for training, loading, inspecting and evaluating Doc2Vec models.

    Attributes:
        model: Currently loaded Doc2Vec model, or ``None``.
        thread: Running training worker, or ``None``.
        selected_files: Corpus files chosen for training.
        resume_model_path: Model file chosen for incremental training.
    """

    # Lower bound on inference passes for the two evaluation texts; short
    # texts inferred with very few passes give unstable vectors.
    MIN_INFER_EPOCHS = 50

    def __init__(self):
        """Create the window and build its widgets."""
        super().__init__()
        self.setWindowTitle("Doc2Vec 中文模型训练与评估器(支持增量训练)")
        self.resize(800, 800)
        self.model = None
        self.thread = None
        self.selected_files = []
        self.resume_model_path = None
        self.init_ui()

    def init_ui(self):
        """Create all widgets, wire their signals and lay them out vertically."""
        layout = QVBoxLayout()

        # Corpus file selection
        self.file_path = QLineEdit()
        self.select_file_btn = QPushButton("选择语料文件")
        self.select_file_btn.clicked.connect(self.select_file)
        file_layout = QHBoxLayout()
        file_layout.addWidget(self.file_path)
        file_layout.addWidget(self.select_file_btn)
        layout.addLayout(file_layout)

        self.select_folder_btn = QPushButton("选择语料文件夹")
        self.select_folder_btn.clicked.connect(self.select_folder)
        layout.addWidget(self.select_folder_btn)

        # Choose an existing model to continue training
        self.load_model_for_resume_btn = QPushButton("选择模型进行增量训练")
        self.load_model_for_resume_btn.clicked.connect(self.load_model_for_resume)
        layout.addWidget(self.load_model_for_resume_btn)

        # Whether incremental training overwrites the original model
        self.overwrite_checkbox = QCheckBox("是否覆盖原模型(增量训练时)")
        layout.addWidget(self.overwrite_checkbox)

        # Hyper-parameters
        form_layout = QFormLayout()
        self.vector_size = QLineEdit("50")
        self.window = QLineEdit("5")
        self.min_count = QLineEdit("1")
        self.epochs = QLineEdit("20")
        form_layout.addRow("向量维度 (vector_size):", self.vector_size)
        form_layout.addRow("上下文窗口大小 (window):", self.window)
        form_layout.addRow("最小词频 (min_count):", self.min_count)
        form_layout.addRow("训练轮数 (epochs):", self.epochs)
        layout.addLayout(form_layout)

        # Start-training button
        self.train_btn = QPushButton("开始训练")
        self.train_btn.clicked.connect(self.start_training)
        layout.addWidget(self.train_btn)

        # Progress bar
        self.progress_bar = QProgressBar()
        layout.addWidget(self.progress_bar)

        # Load-model button
        self.load_model_btn = QPushButton("加载模型")
        self.load_model_btn.clicked.connect(self.load_model)
        layout.addWidget(self.load_model_btn)

        # Model-info button
        self.show_info_btn = QPushButton("查看模型信息")
        self.show_info_btn.clicked.connect(self.show_model_info)
        layout.addWidget(self.show_info_btn)

        # Vocabulary button
        self.vocab_btn = QPushButton("查看词汇表")
        self.vocab_btn.clicked.connect(self.show_vocab)
        layout.addWidget(self.vocab_btn)

        # Word-vector export button
        self.export_vec_btn = QPushButton("导出词向量")
        self.export_vec_btn.clicked.connect(self.export_vectors)
        layout.addWidget(self.export_vec_btn)

        # Evaluation inputs
        self.text1 = QLineEdit()
        self.text2 = QLineEdit()
        eval_layout = QFormLayout()
        eval_layout.addRow("文本1(评估):", self.text1)
        eval_layout.addRow("文本2(评估):", self.text2)
        self.eval_btn = QPushButton("评估相似度")
        self.eval_btn.clicked.connect(self.evaluate_similarity)
        layout.addLayout(eval_layout)
        layout.addWidget(self.eval_btn)

        # Log output
        self.log_area = QTextEdit()
        self.log_area.setReadOnly(True)
        layout.addWidget(self.log_area)

        # Vocabulary output
        self.vocab_area = QTextEdit()
        self.vocab_area.setReadOnly(True)
        self.vocab_area.setFixedHeight(150)
        layout.addWidget(self.vocab_area)

        self.setLayout(layout)

    def select_file(self):
        """Let the user pick one corpus file and remember it."""
        file_name, _ = QFileDialog.getOpenFileName(
            self,
            "选择语料文件",
            "",
            "All Supported Files (*.txt *.docx *.doc);;"
            "Text Files (*.txt);;"
            "Word Documents (*.docx *.doc);;"
            "All Files (*)"
        )
        if file_name:
            self.file_path.setText(file_name)
            self.selected_files = [file_name]
            self.log(f"已选择语料文件: {file_name}")

    def select_folder(self):
        """Let the user pick a folder and remember every ``.txt`` file in it."""
        folder_path = QFileDialog.getExistingDirectory(self, "选择语料文件夹")
        if not folder_path:
            return
        try:
            # Sorted so the document order does not depend on the listing
            # order returned by the file system.
            names = sorted(os.listdir(folder_path))
        except OSError as e:
            self.log(f"无法读取文件夹:{e}")
            QMessageBox.critical(self, "错误", f"无法读取文件夹:{e}")
            return
        txt_files = [os.path.join(folder_path, f) for f in names if f.endswith(".txt")]
        if txt_files:
            self.file_path.setText(f"多个TXT文件(共{len(txt_files)}个)")
            self.selected_files = txt_files
            self.log(f"已选择文件夹中的 {len(txt_files)} 个 .txt 文件")
        else:
            self.log("该文件夹中没有 .txt 文件。")
            QMessageBox.information(self, "提示", "该文件夹中没有 .txt 文件。")

    def load_model_for_resume(self):
        """Pick a model file for incremental training and load it."""
        model_path, _ = QFileDialog.getOpenFileName(
            self, "选择模型文件", "", "Model Files (*.model)"
        )
        if not model_path:
            return
        try:
            loaded = Doc2Vec.load(model_path)
        except Exception as e:
            self.log(f"加载模型失败:{e}")
            QMessageBox.critical(self, "错误", f"加载模型失败:{e}")
            return
        # Record the path only after the file proved loadable, so a broken
        # file is never used as the resume target.
        self.model = loaded
        self.resume_model_path = model_path
        self.log(f"已选择模型用于增量训练:{model_path}")

    def start_training(self):
        """Validate the inputs and start a background training thread."""
        if not self.selected_files:
            self.log("请先选择语料文件或文件夹。")
            QMessageBox.warning(self, "警告", "请先选择语料文件或文件夹。")
            return
        if not all(os.path.isfile(fp) for fp in self.selected_files):
            self.log("某些文件路径无效,请重新选择。")
            QMessageBox.critical(self, "错误", "某些文件路径无效,请重新选择。")
            return
        try:
            vector_size = int(self.vector_size.text())
            window = int(self.window.text())
            min_count = int(self.min_count.text())
            epochs = int(self.epochs.text())
        except ValueError:
            self.log("请输入合法的参数(整数)")
            QMessageBox.warning(self, "警告", "请输入合法的参数(整数)")
            return
        # Reject values that would only fail later inside the worker thread.
        if vector_size <= 0 or window <= 0 or epochs <= 0 or min_count < 0:
            self.log("参数必须为正整数(最小词频不能为负数)")
            QMessageBox.warning(self, "警告", "参数必须为正整数(最小词频不能为负数)")
            return
        if self.thread is not None and self.thread.isRunning():
            self.log("当前已有训练任务在运行。")
            QMessageBox.information(self, "提示", "当前已有训练任务在运行。")
            return

        self.train_btn.setEnabled(False)
        self.progress_bar.setValue(0)
        self.log("开始训练任务...")
        self.thread = TrainingThread(
            self.selected_files,
            vector_size,
            window,
            min_count,
            epochs,
            resume_model_path=self.resume_model_path,
            overwrite_model=self.overwrite_checkbox.isChecked(),
            parent=self
        )
        self.thread.log_signal.connect(self.log)
        self.thread.progress_signal.connect(self.progress_bar.setValue)
        self.thread.finished_signal.connect(self.on_training_finished)
        self.thread.start()

    def load_model(self):
        """Pick a model file and load it in a background thread."""
        model_path, _ = QFileDialog.getOpenFileName(
            self, "选择模型文件", "", "Model Files (*.model)"
        )
        if model_path:
            self.loading_msg = LoadingMessageBox(self)
            self.loading_thread = LoadingModelThread(model_path, self)
            self.loading_thread.loaded_signal.connect(self.on_model_loaded)
            self.loading_thread.start()

    def on_model_loaded(self, result):
        """Handle the loader thread's result.

        Args:
            result: ``(model, seconds)`` on success, ``(exception, 0)`` on
                failure.
        """
        self.loading_msg.accept()
        self.loading_thread.quit()
        self.loading_thread.wait()
        data, elapsed = result
        if isinstance(data, Exception):
            self.log(f"加载模型失败:{data}")
            QMessageBox.critical(self, "错误", f"加载模型失败:{data}")
        else:
            self.model = data
            self.log(f"模型加载成功,耗时 {elapsed:.2f} 秒")
            QMessageBox.information(self, "提示", f"模型加载成功,耗时 {elapsed:.2f} 秒")

    def show_model_info(self):
        """Log vector size, vocabulary size, epochs and min_count of the model."""
        if self.model is None:
            self.log("请先加载模型!")
            return
        info = (
            f"模型维度: {self.model.vector_size}\n"
            f"词汇量: {len(self.model.wv)}\n"
            f"训练轮数: {self.model.epochs}\n"
            f"最小词频: {self.model.min_count}"
        )
        self.log("模型信息:\n" + info)

    def show_vocab(self):
        """Show the first 500 vocabulary entries in the vocabulary pane."""
        if self.model is None:
            self.log("请先加载模型!")
            return
        vocab = list(self.model.wv.key_to_index.keys())
        self.vocab_area.setPlainText("\n".join(vocab[:500]))
        self.log("已显示模型词汇表前500个词。")

    def export_vectors(self):
        """Export the word vectors in word2vec text format to a chosen file."""
        if self.model is None:
            self.log("请先加载模型!")
            return
        save_path, _ = QFileDialog.getSaveFileName(self, "导出向量", "", "Text Files (*.txt)")
        if not save_path:
            return
        try:
            self.model.wv.save_word2vec_format(save_path, binary=False)
        except Exception as e:
            # Typically an unwritable destination; report it instead of
            # letting the exception escape the slot.
            self.log(f"导出词向量失败:{e}")
            QMessageBox.critical(self, "错误", f"导出词向量失败:{e}")
            return
        self.log(f"词向量已导出至 {save_path}")

    @staticmethod
    def _tokenize(text):
        """Tokenize *text* the same way the training corpus is preprocessed."""
        cleaned = re.sub(r'[^\u4e00-\u9fa5]', ' ', text)
        return [word for word in jieba.lcut(cleaned.strip()) if word.strip()]

    @staticmethod
    def _cosine_similarity(vec_a, vec_b):
        """Return the cosine similarity of two vectors (0.0 if either is zero)."""
        denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        if denominator == 0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denominator)

    def _compare_texts(self, text1, text2):
        """Infer a vector for each text and log their cosine similarity."""
        tokens1 = self._tokenize(text1)
        tokens2 = self._tokenize(text2)
        if not tokens1 or not tokens2:
            self.log("输入文本中没有可用于推断的中文词语,已跳过文本相似度计算。")
            return
        infer_epochs = max(int(self.model.epochs or 0), self.MIN_INFER_EPOCHS)
        try:
            vec1 = self.model.infer_vector(tokens1, epochs=infer_epochs)
            vec2 = self.model.infer_vector(tokens2, epochs=infer_epochs)
        except Exception as e:
            self.log(f"推断文本向量失败:{e}")
            QMessageBox.critical(self, "错误", f"推断文本向量失败:{e}")
            return
        similarity = self._cosine_similarity(vec1, vec2)
        self.log(f"文本1: {text1}")
        self.log(f"文本2: {text2}")
        self.log(f"文本余弦相似度(推断向量): {similarity:.4f}")

    def _compare_training_docs(self):
        """Let the user pick two trained document labels and log their similarity."""
        try:
            doc_labels = self.model.dv.index_to_key
        except Exception:
            self.log("错误:模型中没有可用的文档向量。")
            QMessageBox.critical(self, "错误", "模型中没有可用的文档向量。")
            return
        if len(doc_labels) < 2:
            self.log("需要至少两个训练文档才能进行相似度评估。")
            QMessageBox.warning(self, "警告", "需要至少两个训练文档。")
            return

        dialog = DocumentSelectorDialog(doc_labels, self)
        if dialog.exec_() != QDialog.Accepted:
            self.log("取消选择文档。")
            return
        selected = dialog.get_selected()
        if len(selected) != 2:
            self.log("请选择两个文档进行比较。")
            QMessageBox.warning(self, "警告", "请选择两个文档进行比较。")
            return
        doc_id1, doc_id2 = selected
        try:
            vec1 = self.model.dv[doc_id1]
            vec2 = self.model.dv[doc_id2]
        except KeyError as e:
            self.log(f"找不到文档向量: {e}")
            QMessageBox.critical(self, "错误", f"找不到文档向量: {e}")
            return
        similarity = self._cosine_similarity(vec1, vec2)
        self.log(f"文档标签1: {doc_id1}")
        self.log(f"文档标签2: {doc_id2}")
        self.log(f"余弦相似度: {similarity:.4f}")

    def evaluate_similarity(self):
        """Evaluate similarity with the loaded model.

        First the two texts typed into the evaluation fields are compared
        through inferred vectors. Then the user may pick two trained document
        labels, whose stored vectors are compared as well.
        """
        if self.model is None:
            self.log("请先加载模型!")
            QMessageBox.warning(self, "警告", "请先加载模型!")
            return
        text1 = self.text1.text()
        text2 = self.text2.text()
        if not text1 or not text2:
            self.log("请输入两段文本用于评估。")
            QMessageBox.warning(self, "警告", "请输入两段文本用于评估。")
            return
        # The entered texts are required above, so they must actually be
        # evaluated rather than only validated.
        self._compare_texts(text1, text2)
        self._compare_training_docs()

    def log(self, message):
        """Append *message* to the log pane."""
        self.log_area.append(message)

    def on_training_finished(self):
        """Re-enable the UI after training and try to load the trained model."""
        worker, self.thread = self.thread, None
        self.train_btn.setEnabled(True)

        # A worker may expose an `error` attribute describing a failed run.
        failure = getattr(worker, "error", None)
        if failure:
            self.log(f"训练未成功完成:{failure}")
            QMessageBox.critical(self, "错误", f"训练未成功完成:{failure}")
            return

        self.log("训练已完成。")
        self.progress_bar.setValue(100)
        QMessageBox.information(self, "提示", "模型训练已完成!")

        # Prefer the path reported by the worker, if it reports one;
        # otherwise fall back to the conventional locations.
        model_path = getattr(worker, "saved_model_path", None)
        if not model_path:
            if self.resume_model_path and self.overwrite_checkbox.isChecked():
                model_path = self.resume_model_path
            else:
                model_path = "doc2vec_model.model"

        if not os.path.exists(model_path):
            # Keep whatever model is currently loaded instead of dropping it.
            self.log("未找到可自动加载的模型文件,请手动加载模型。")
            return
        try:
            self.model = Doc2Vec.load(model_path)
            self.log("模型已自动加载。")
        except Exception as e:
            self.log(f"自动加载模型失败:{e}")
if __name__ == '__main__':
    # Script entry point: create the Qt application, show the main window
    # and hand the event loop's return code to the interpreter on exit.
    app = QApplication(sys.argv)
    window = Doc2VecTrainerGUI()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
文本1: 壮热, 神情烦躁,面色红赤, 口渴, 声音高粗, 大便干结,小便短赤, 舌红绛,苔黄黑起刺, 脉洪大滑数,有力
文本2: 寒冷, 神情萎靡, 面色暗淡, 口不渴, 声音低微, 大便溏薄,小便清长, 舌淡胖嫩,苔白滑, 脉沉细迟弱,无力
余弦相似度: 0.9846,为什么?