import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
logging as hf_logging
)
import json
import logging
import os
import re
from typing import Dict, Any, Union
from datetime import datetime
# 关闭HF冗长日志
hf_logging.set_verbosity_error()
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 确保输出目录存在
output_dir = r"D:\paper"
os.makedirs(output_dir, exist_ok=True)
class DroneReportGenerator:
def __init__(self, model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
try:
# 初始化tokenizer(支持中文)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer.padding_side = "left" # 更适合生成任务
# 加载模型(优化内存使用)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
low_cpu_mem_usage=True
).eval()
# 无人机专属生成配置
self.generation_config = {
"max_new_tokens": 1024,
"temperature": 0.3,
"top_p": 0.92,
"repetition_penalty": 1.15,
"do_sample": True,
"early_stopping": True,
"pad_token_id": self.tokenizer.eos_token_id
}
# 无人机报告模板
self.system_prompt = """你是一个专业的无人机监测报告生成系统,必须严格遵循以下要求:
1. 使用标准JSON格式输出
2. 所有字段都必须填写完整
3. 内容必须基于提供的监测数据
4. 使用正式的技术术语
5. 时间戳使用ISO 8601格式
报告必须包含以下字段:
{
"监测任务": "任务描述",
"监测位置": {
"经度": float,
"纬度": float,
"区域描述": str
},
"异常类型": "具体异常现象",
"严重程度": "低/中/高/紧急",
"环境参数": {
"能见度": "描述",
"风速": "单位m/s",
"光照": "描述"
},
"设备状态": {
"电池": "百分比",
"信号强度": "百分比",
"传感器": "正常/异常描述"
},
"详细描述": "技术性分析",
"建议措施": "具体操作建议",
"时间戳": "ISO格式时间"
}"""
self._log_device_map()
except Exception as e:
logger.error(f"模型初始化失败: {str(e)}")
raise
def _log_device_map(self):
"""记录模型设备分配情况"""
try:
if hasattr(self.model, "hf_device_map"):
device_map = self.model.hf_device_map
logger.info(f"模型设备分配: {json.dumps(device_map, indent=2)}")
else:
logger.info(f"模型主设备: {self.model.device}")
except Exception as e:
logger.warning(f"设备映射记录失败: {str(e)}")
def generate_report(self, task_info: Dict[str, Any]) -> Dict[str, Any]:
"""生成标准化无人机监测报告"""
try:
current_time = datetime.now().isoformat()
prompt = self._build_prompt(task_info, current_time)
raw_output = self._generate_text(prompt)
output_text = self._process_output(raw_output)
report = self._validate_and_fix_json(output_text, task_info, current_time)
# 保存报告到D:\paper
timestamp_str = current_time.replace(":", "-").split(".")[0]
report_path = os.path.join(output_dir, f"drone_report_{timestamp_str}.json")
with open(report_path, "w", encoding="utf-8") as f:
json.dump(report, f, indent=2, ensure_ascii=False)
logger.info(f"报告已保存至: {report_path}")
return report
except Exception as e:
logger.error(f"报告生成失败: {str(e)}")
return self._emergency_report(task_info, str(e))
def _build_prompt(self, task_info: Dict, timestamp: str) -> str:
"""构建结构化提示词"""
try:
location = task_info.get("location", {})
loc_str = f"经度 {location.get('longitude', '未知')}, 纬度 {location.get('latitude', '未知')}"
if "description" in location:
loc_str += f" ({location['description']})"
prompt = f"""
{self.system_prompt}
监测数据:
- 监测任务: {task_info.get("task", "常规巡检")}
- 位置信息: {loc_str}
- 异常类型: {task_info.get("anomaly", "无")}
- 严重程度: {task_info.get("severity", "低")}
环境参数:
- 能见度: {task_info.get("visibility", "良好")}
- 风速: {task_info.get("wind_speed", "0")} m/s
- 光照: {task_info.get("light", "白天正常")}
设备状态:
- 电池: {task_info.get("battery", "100")}%
- 信号强度: {task_info.get("signal", "100")}%
- 传感器: {task_info.get("sensors", "正常")}
当前时间: {timestamp}
请生成完整的监测报告(严格JSON格式):
```json
"""
return prompt
except Exception as e:
logger.error(f"提示构建失败: {str(e)}")
raise
def _generate_text(self, prompt: str) -> str:
"""直接使用model.generate生成文本"""
try:
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.model.device)
outputs = self.model.generate(**inputs, **self.generation_config)
# 只解码生成的部分
generated_ids = outputs[0][len(inputs["input_ids"][0]):]
return self.tokenizer.decode(generated_ids, skip_special_tokens=True)
except Exception as e:
logger.error(f"文本生成失败: {str(e)}")
raise
#
# def _process_output(self, raw_output: str) -> str:
# """处理生成的文本,提取JSON格式内容"""
# try:
# # 清理首尾空白字符
# text = raw_output.strip()
# # 提取"```json“和”```“之间的代码块内容
# if "```json " in text and "```" in text:
# # 查找最后一个”```json“出现的位置
# last_json_pos = text.rfind("```json")
# if last_json_pos != -1:
# code_block = text[last_json_pos + len("```json"):]
# end_pos = code_block.find("```")
# if end_pos != -1:
# text = code_block[:end_pos].strip()
# # 如果没有找到标记代码块,尝试直接提取JSON
# if not text.startswith("{") and "{" in text:
# brace_pos = text.find("{")
# text = text[brace_pos:]
#
# # 确保返回的内容至少包含一个JSON对象
# if text and text[0] == "{":
# return text
# return f'{{"error": "无效的JSON格式"}}'
#
# except Exception as e:
# logger.error(f"输出处理失败: {str(e)}
def _process_output(self, raw_output: str) -> str:
"""处理生成的文本"""
try:
# 清理可能的格式问题
text = raw_output.strip()
# 提取```json和```之间的内容
if "```json" in text and "```" in text:
text = text.split("```json")[-1].split("```")[0].strip()
elif "{" in text: # 如果没有代码块标记,尝试从第一个{开始提取
text = text[text.find("{"):]
return text
except Exception as e:
logger.error(f"输出处理失败: {str(e)}")
return f'{{"error": "输出处理失败: {str(e)}"}}'
def _validate_and_fix_json(self, text: str, task_info: Dict, timestamp: str) -> Dict:
"""严格的JSON验证和修复"""
try:
# 预处理:转义未转义的引号
text = re.sub(r'(?<!\\)"', r'\"', text)
text = re.sub(r'\\"', r'"', text) # 修复双重转义
# 尝试直接解析
try:
report = json.loads(text)
if isinstance(report, dict):
return report
except:
pass
# 高级修复:重建JSON结构
text = self._advanced_json_repair(text)
try:
report = json.loads(text)
except json.JSONDecodeError as e:
logger.warning(f"高级修复失败: {str(e)}")
# 最终回退方案
report = self._create_default_report(task_info, timestamp)
report["_parsing_error"] = str(e)
report["_raw_output"] = text[:100] + "..."
return report
# 验证必需字段
default_report = self._create_default_report(task_info, timestamp)
for key, value in default_report.items():
if key not in report or not report[key]:
if isinstance(value, dict):
report[key] = {k: v for k, v in value.items() if k not in report.get(key, {})}
else:
report[key] = value
return report
except Exception as e:
logger.error(f"JSON验证失败: {str(e)}")
return {"error": str(e), "raw_output": text[:200] + "..."}
def _advanced_json_repair(self, text: str) -> str:
"""高级JSON修复算法"""
try:
# 1. 标准化引号
text = re.sub(r'(?<!\\)"', r'\"', text)
# 2. 修复缺失的引号(针对中文冒号情况)
def replace_colon(match):
key = match.group(1).strip().strip('"\'')
return f'"{key}":'
text = re.sub(r'([^\s:]+)\s*:', replace_colon, text)
# 3. 修复缺失的逗号
lines = [line.strip() for line in text.split('\n') if line.strip()]
fixed_lines = []
for i in range(len(lines)):
if i > 0 and lines[i].startswith(('"', "'")) and not lines[i - 1].endswith(('{', '[', ',', ':')):
fixed_lines.append(',')
fixed_lines.append(lines[i])
text = ' '.join(fixed_lines)
# 4. 确保完整的JSON结构
if not text.strip().startswith("{"):
text = "{" + text.strip()
if not text.strip().endswith("}"):
text = text.strip() + "}"
# 5. 修复转义字符
text = text.replace('\\"', '"').replace("\\'", "'")
return text
except Exception as e:
logger.warning(f"高级修复过程中出错: {str(e)}")
return f'{{"error": "修复失败", "detail": "{str(e)}"}}'
def _create_default_report(self, task_info: Dict, timestamp: str) -> Dict:
"""创建默认报告结构"""
location = task_info.get("location", {})
return {
"监测任务": task_info.get("task", "常规巡检"),
"监测位置": {
"经度": location.get("longitude", 0),
"纬度": location.get("latitude", 0),
"区域描述": location.get("description", "未知区域")
},
"异常类型": task_info.get("anomaly", "无"),
"严重程度": task_info.get("severity", "低"),
"环境参数": {
"能见度": task_info.get("visibility", "良好"),
"风速": task_info.get("wind_speed", "0"),
"光照": task_info.get("light", "白天正常")
},
"设备状态": {
"电池": task_info.get("battery", "100"),
"信号强度": task_info.get("signal", "100"),
"传感器": task_info.get("sensors", "正常")
},
"详细描述": "自动生成失败,请人工检查",
"建议措施": "立即进行人工复查",
"时间戳": timestamp
}
def _emergency_report(self, task_info: Dict, error: str) -> Dict:
"""紧急模式报告"""
logger.warning("启动紧急报告模式")
default = self._create_default_report(task_info, datetime.now().isoformat())
default.update({
"系统状态": "紧急模式",
"错误信息": error,
"详细描述": "自动报告生成系统故障",
"建议措施": "检查系统日志并联系技术支持"
})
return default
# 测试代码
if __name__ == "__main__":
try:
print("=" * 50)
print("无人机报告生成系统 - 环境检查")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"当前GPU: {torch.cuda.get_device_name(0)}")
print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.2f}GB")
print("=" * 50)
generator = DroneReportGenerator()
test_cases = [
{
"task": "油田管道巡检",
"location": {"longitude": 117.1215, "latitude": 36.6811, "description": "济南历城油田区域"},
"anomaly": "油管压力异常",
"severity": "高",
"visibility": "中等(夜间)",
"wind_speed": "3.2",
"light": "夜间飞行",
"battery": "78",
"signal": "92",
"sensors": "红外正常,视觉受限"
},
{
"task": "风电场叶片检测",
"location": {"longitude": 121.4737, "latitude": 31.2304, "description": "上海东海风电场"},
"anomaly": "叶片异常震动",
"severity": "中",
"visibility": "良好",
"wind_speed": "8.5",
"light": "白天正常",
"battery": "65",
"signal": "88",
"sensors": "所有传感器正常"
}
]
for i, case in enumerate(test_cases, 1):
print(f"\n=== 测试案例 {i}: {case['task']} ===")
report = generator.generate_report(case)
print("生成的报告:")
print(json.dumps(report, indent=2, ensure_ascii=False))
except Exception as e:
logger.error(f"测试失败: {str(e)}")
运行上述代码出现2025-07-05 19:34:25,919 - ERROR - 报告生成失败: ‘DroneReportGenerator’ object has no attribute ‘_process_output’
2025-07-05 19:34:25,919 - ERROR - 测试失败: ‘DroneReportGenerator’ object has no attribute ‘_emergency_report’ 运行代码后出现上述错误,请修改代码并生成这些问题,请修改代码并生成