# E:\AI_System\agent\model_manager.py
import os
import logging
import json
import hashlib
import gc
import time
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List


class ModelManager:
    """AI model manager - enhanced version (keeps the original interface, adds new features)."""

    MODEL_REGISTRY_FILE = "model_registry.json"

    def __init__(self, model_registry: Optional[Dict[str, str]] = None,
                 cache_dir: str = "model_cache",
                 use_gpu: bool = True,
                 max_models_in_memory: int = 3):
        """
        Initialize the model manager.

        Args:
            model_registry: model registry (dict: model name -> model path or ID)
            cache_dir: directory used as the model cache
            use_gpu: whether to run models on the GPU
            max_models_in_memory: maximum number of models kept in memory at once
        """
        self.logger = logging.getLogger("ModelManager")
        self.logger.info("🚀 Initializing model manager...")

        # Constructor parameters
        self.model_registry = model_registry or {}
        self.cache_dir = cache_dir
        self.use_gpu = use_gpu
        self.max_models_in_memory = max_models_in_memory

        # Make sure the cache directory exists
        os.makedirs(self.cache_dir, exist_ok=True)

        # Load the persisted registry from disk
        self._persistent_registry = self._load_registry()

        # Merge the registry passed in with the persisted one
        for name, path in self.model_registry.items():
            self._persistent_registry.setdefault(name, {
                "path": path,
                "type": "text",  # default to a text model
                "status": "unloaded"
            })

        # Models currently held in memory
        self.loaded_models: Dict[str, Any] = {}

        self.logger.info(f"✅ Model manager initialized (GPU: {'enabled' if use_gpu else 'disabled'})")
        self.logger.info(f"Registered models: {list(self._persistent_registry.keys())}")

    def _load_registry(self) -> Dict[str, dict]:
        """Load the model registry from disk."""
        try:
            registry_path = Path(self.MODEL_REGISTRY_FILE)
            if registry_path.exists():
                with open(registry_path, 'r', encoding='utf-8') as f:
                    registry = json.load(f)
                self.logger.info(f"📋 Loaded model registry: {self.MODEL_REGISTRY_FILE}")
                return registry
        except Exception as e:
            self.logger.error(f"❌ Failed to load model registry: {str(e)}")
        return {}

    def _save_registry(self) -> bool:
        """Persist the model registry to disk."""
        try:
            with open(self.MODEL_REGISTRY_FILE, 'w', encoding='utf-8') as f:
                json.dump(self._persistent_registry, f, indent=2, ensure_ascii=False)
            self.logger.debug(f"💾 Model registry saved: {self.MODEL_REGISTRY_FILE}")
            return True
        except Exception as e:
            self.logger.error(f"❌ Failed to save model registry: {str(e)}")
            return False

    def register_model(self, model_name: str, model_path: str, model_type: str = "text",
                       adapter_config: Optional[dict] = None) -> bool:
        """
        Register a new model in the registry.

        Args:
            model_name: model name
            model_path: local model path or HuggingFace ID
            model_type: model type (text, image, audio, multimodal)
            adapter_config: optional adapter configuration
        """
        # Make sure the model is reachable
        if not self._check_model_exists(model_path):
            self.logger.error(f"❌ Model path is not accessible: {model_path}")
            return False

        # Compute a checksum for the model files
        try:
            checksum = self._calculate_checksum(model_path)
        except Exception as e:
            self.logger.warning(f"⚠️ Could not compute checksum: {str(e)}")
            checksum = "unknown"

        # Add the entry to the registry
        self._persistent_registry[model_name] = {
            "path": model_path,
            "type": model_type,
            "status": "unloaded",
            "checksum": checksum,
            "last_accessed": time.time(),
            "adapter": adapter_config
        }

        self.logger.info(f"✅ Model registered: {model_name} ({model_type})")
        self._save_registry()
        return True

    def _check_model_exists(self, model_path: str) -> bool:
        """Check that the model path is valid."""
        # Treat a path containing "/" that does not exist locally as a HuggingFace model ID
        if "/" in model_path and not os.path.exists(model_path):
            self.logger.info(f"🔍 Detected HuggingFace model ID: {model_path}")
            return True
        # Otherwise require an existing local path
        return os.path.exists(model_path)

    def _calculate_checksum(self, model_path: str) -> str:
        """Compute a model checksum (for large models, hash only the key files)."""
        # HuggingFace model IDs get a short tag derived from the ID instead of a file hash
        if "/" in model_path and not os.path.exists(model_path):
            return "hf-" + hashlib.md5(model_path.encode()).hexdigest()[:8]

        # Local model - compute a SHA256 over the file contents
        sha256 = hashlib.sha256()
        if os.path.isdir(model_path):
            # Directory: hash only the key weight/config files
            key_files = ["pytorch_model.bin", "model.safetensors", "config.json"]
            for file in key_files:
                file_path = os.path.join(model_path, file)
                if os.path.exists(file_path):
                    with open(file_path, 'rb') as f:
                        while chunk := f.read(8192):
                            sha256.update(chunk)
        else:
            # Single model file
            with open(model_path, 'rb') as f:
                while chunk := f.read(8192):
                    sha256.update(chunk)
        return sha256.hexdigest()

    def load_model(self, model_name: str, force_reload: bool = False) -> Tuple[bool, Any]:
        """
        Load a model into memory.

        Args:
            model_name: name of the model to load
            force_reload: whether to force a reload

        Returns:
            (success flag, model object)
        """
        # The model must be registered first
        if model_name not in self._persistent_registry:
            self.logger.error(f"❌ Model is not registered: {model_name}")
            return False, None

        model_info = self._persistent_registry[model_name]

        # Reuse the in-memory copy unless a reload was forced
        if model_name in self.loaded_models and not force_reload:
            self.logger.info(f"📦 Model already in memory: {model_name}")
            model_info["status"] = "loaded"
            model_info["last_accessed"] = time.time()
            return True, self.loaded_models[model_name]

        # Respect the memory budget: evict the least recently used model if needed
        if len(self.loaded_models) >= self.max_models_in_memory:
            self._unload_least_recently_used()

        # Actually load the model
        try:
            self.logger.info(f"🔄 Loading model: {model_name} ({model_info['type']})")

            # Dispatch to the loader for this model type
            model_type = model_info["type"]
            if model_type == "text":
                model = self._load_text_model(model_info)
            elif model_type == "image":
                model = self._load_image_model(model_info)
            elif model_type == "audio":
                model = self._load_audio_model(model_info)
            else:
                self.logger.error(f"❌ Unsupported model type: {model_type}")
                return False, None

            # Update bookkeeping
            self.loaded_models[model_name] = model
            model_info["status"] = "loaded"
            model_info["last_accessed"] = time.time()
            self._save_registry()

            self.logger.info(f"✅ Model loaded: {model_name}")
            return True, model
        except Exception as e:
            self.logger.error(f"❌ Failed to load model: {model_name}, error: {str(e)}", exc_info=True)
            model_info["status"] = "error"
            return False, None

    def _load_text_model(self, model_info: dict) -> Any:
        """Load a text model (LLM)."""
        model_path = model_info["path"]

        # Import transformers lazily so the dependency is only needed for text models
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            self.logger.error("❌ The transformers library is not installed")
            raise RuntimeError("transformers not installed")

        self.logger.debug(f"🔧 Loading text model: {model_path}")
        device = "cuda" if self.use_gpu else "cpu"

        # Load the tokenizer and the model
        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=self.cache_dir)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=self.cache_dir,
            device_map=device if self.use_gpu else None
        )
        return {
            "model": model,
            "tokenizer": tokenizer,
            "info": model_info
        }
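
    # A minimal usage sketch (not part of the class API): assuming load_model()
    # returned the bundle above, text generation with the standard transformers
    # calls could look roughly like this ("my-llm" is a hypothetical model name):
    #
    #     bundle = manager.get_model("my-llm")
    #     inputs = bundle["tokenizer"]("Hello", return_tensors="pt").to(bundle["model"].device)
    #     output_ids = bundle["model"].generate(**inputs, max_new_tokens=32)
    #     text = bundle["tokenizer"].decode(output_ids[0], skip_special_tokens=True)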

    def _load_image_model(self, model_info: dict) -> Any:
        """Load an image model."""
        model_path = model_info["path"]
        try:
            # Import diffusers or torchvision lazily, depending on the path
            if "diffusion" in model_path.lower():
                from diffusers import StableDiffusionPipeline
                device = "cuda" if self.use_gpu else "cpu"
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_path,
                    cache_dir=self.cache_dir
                ).to(device)
                return {"pipeline": pipe}
            else:
                import torchvision.models as models
                # Simplified implementation - a real loader would need more logic here
                model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
                if self.use_gpu:
                    model = model.cuda()
                return {"model": model}
        except ImportError:
            self.logger.error("❌ Required image-model libraries are not installed")
            raise

    def _load_audio_model(self, model_info: dict) -> Any:
        """Load an audio model."""
        # Same lazy-loading pattern as the text models
        try:
            # Example: a speech-recognition model
            from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
            processor = Wav2Vec2Processor.from_pretrained(model_info["path"], cache_dir=self.cache_dir)
            model = Wav2Vec2ForCTC.from_pretrained(model_info["path"], cache_dir=self.cache_dir)
            if self.use_gpu:
                model = model.to("cuda")
            return {
                "processor": processor,
                "model": model
            }
        except ImportError:
            self.logger.error("❌ The transformers library is not installed")
            raise
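
    # A minimal usage sketch (assumption: `audio` is 16 kHz mono float audio as a
    # NumPy array, and inputs are moved to the model's device if it is on GPU).
    # CTC decoding with the returned processor/model pair could look roughly like:
    #
    #     bundle = manager.get_model("my-asr")          # "my-asr" is a hypothetical name
    #     inputs = bundle["processor"](audio, sampling_rate=16000, return_tensors="pt")
    #     logits = bundle["model"](**inputs).logits
    #     ids = logits.argmax(dim=-1)
    #     text = bundle["processor"].batch_decode(ids)[0]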

    def unload_model(self, model_name: Optional[str] = None) -> bool:
        """
        Unload a model from memory.

        Args:
            model_name: name of the model to unload (None unloads all models)
        """
        if model_name is None:
            self.logger.info("Unloading all models")
            for name in list(self.loaded_models.keys()):
                self._unload_single_model(name)
            return True
        return self._unload_single_model(model_name)

    def _unload_single_model(self, model_name: str) -> bool:
        """Unload a single model."""
        if model_name not in self.loaded_models:
            self.logger.warning(f"⚠️ Model is not loaded: {model_name}")
            return False
        try:
            # Drop the model reference and trigger garbage collection
            del self.loaded_models[model_name]
            gc.collect()

            # Update the registry status
            if model_name in self._persistent_registry:
                self._persistent_registry[model_name]["status"] = "unloaded"
                self._save_registry()

            self.logger.info(f"🗑️ Model unloaded: {model_name}")
            return True
        except Exception as e:
            self.logger.error(f"❌ Failed to unload model: {model_name}, error: {str(e)}")
            return False

    def _unload_least_recently_used(self):
        """Unload the least recently used model."""
        if not self.loaded_models:
            return

        # Find the loaded model with the oldest access time
        lru_model = None
        lru_time = time.time()
        for model_name in self.loaded_models:
            if model_name in self._persistent_registry:
                access_time = self._persistent_registry[model_name].get("last_accessed", 0)
                if access_time < lru_time:
                    lru_time = access_time
                    lru_model = model_name

        # Evict it
        if lru_model:
            self.unload_model(lru_model)
            self.logger.info(f"🔁 Memory reclaimed: unloaded model {lru_model}")

    def get_model(self, model_name: str) -> Optional[Any]:
        """
        Get an already loaded model.

        Args:
            model_name: model name

        Returns:
            The model object, or None if it is not loaded.
        """
        if model_name in self.loaded_models:
            # Refresh the last-accessed timestamp
            if model_name in self._persistent_registry:
                self._persistent_registry[model_name]["last_accessed"] = time.time()
                self._save_registry()
            return self.loaded_models[model_name]
        return None

    def get_model_info(self, model_name: str) -> Dict[str, Any]:
        """Get a model's registry entry."""
        if model_name in self._persistent_registry:
            return self._persistent_registry[model_name]
        return {"status": "unknown"}

    def list_models(self) -> List[str]:
        """List all registered model names."""
        return list(self._persistent_registry.keys())

    def shutdown(self):
        """Shut down the model manager, unloading all models."""
        self.logger.info("🛑 Shutting down model manager...")
        self.unload_model()  # unload all models
        self.logger.info("✅ Model manager shut down")
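

# A minimal end-to-end sketch of the register/load/inspect/shutdown flow,
# assuming a model exists at the hypothetical path below; the model name and
# path are placeholders, not part of the module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelManager(use_gpu=False, max_models_in_memory=2)
    manager.register_model("demo-llm", "models/demo-llm", model_type="text")  # hypothetical path

    ok, bundle = manager.load_model("demo-llm")
    if ok:
        print(manager.get_model_info("demo-llm"))
    print("Registered:", manager.list_models())

    manager.shutdown()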