# E:\AI_System\agent\model_manager.py
import os
import sys
import logging
import json
import hashlib
import gc
import time
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List

# Add the project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))

# Import utility functions
try:
    from utils.path_utils import normalize_path, is_valid_hf_id, resolve_shortcut
except ImportError:
    # Fall back to basic implementations if the import fails
    def normalize_path(path):
        """Normalize a path."""
        return os.path.abspath(os.path.expanduser(path))

    def is_valid_hf_id(model_id):
        """Check whether a string is a valid HuggingFace model ID."""
        return isinstance(model_id, str) and "/" in model_id and len(model_id.split("/")) == 2

    def resolve_shortcut(shortcut_path):
        """Resolve a Windows shortcut (.lnk) to its target path."""
        try:
            import win32com.client
            shell = win32com.client.Dispatch("WScript.Shell")
            shortcut = shell.CreateShortCut(shortcut_path)
            return shortcut.Targetpath
        except Exception:
            return None
class ModelManager:
    """AI model manager - fully patched version (star-architecture edition)."""

    MODEL_REGISTRY_FILE = "model_registry.json"
    DEFAULT_MODEL_PATHS = {
        "TEXT_BASE": "Qwen2-7B",
        "TEXT_CHAT": "deepseek-7b-chat",
        "IMAGE_MODEL": "sdxl",
        "MULTIMODAL": "deepseek-vl2",
        "YI_VL": "yi-vl"
    }

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        cache_dir: str = "model_cache",
        use_gpu: bool = True,
        max_models_in_memory: int = 3
    ):
        # Configure logging
        self.logger = logging.getLogger("ModelManager")
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.info("🚀 Initializing model manager...")

        # Initialization parameters
        self.config = config or {}
        self.cache_dir = normalize_path(cache_dir)
        self.use_gpu = use_gpu
        self.max_models_in_memory = max_models_in_memory

        # Ensure the cache directory exists
        os.makedirs(self.cache_dir, exist_ok=True)

        # Resolve the model root directory
        self.models_root = self._resolve_models_root()
        self.logger.info(f"✅ Model root directory set to: {self.models_root}")

        # Load or create the registry
        self._persistent_registry = self._load_or_create_registry()

        # Models currently loaded in memory
        self.loaded_models: Dict[str, Any] = {}

        # Auto-register the default models
        self._register_default_models()
        self.logger.info(f"✅ Model manager initialized (GPU: {'enabled' if use_gpu else 'disabled'})")
        self.logger.info(f"Registered models: {list(self._persistent_registry.keys())}")

        # Reference to the star-architecture orchestrator
        self.orchestrator = None

    def set_orchestrator(self, orchestrator):
        """Set the orchestrator reference."""
        self.orchestrator = orchestrator
        self.logger.info("✅ Orchestrator reference set")
    def is_healthy(self) -> bool:
        """Check the manager's health."""
        try:
            # Check that the registry file is accessible
            registry_path = Path(normalize_path(self.MODEL_REGISTRY_FILE))
            if not registry_path.exists():
                self.logger.warning("⚠️ Model registry file does not exist")
                return False
            # Check that the model root directory is accessible
            if not os.path.exists(self.models_root):
                self.logger.error(f"❌ Model root directory does not exist: {self.models_root}")
                return False
            return True
        except Exception as e:
            self.logger.error(f"❌ Health check failed: {str(e)}")
            return False

    def get_status(self) -> Dict[str, Any]:
        """Return the manager's status."""
        loaded_count = len(self.loaded_models)
        total_count = len(self._persistent_registry)
        return {
            "status": "healthy" if self.is_healthy() else "unhealthy",
            "models_loaded": loaded_count,
            "models_registered": total_count,
            "models_root": self.models_root,
            "cache_dir": self.cache_dir,
            "use_gpu": self.use_gpu,
            "max_models_in_memory": self.max_models_in_memory,
            "loaded_models": list(self.loaded_models.keys())
        }
    def _resolve_models_root(self) -> str:
        """Resolve the model storage root directory."""
        # Try several possible shortcut paths
        shortcut_paths = [
            r"E:\AI_Workspace\01_模型存储\主模型.lnk",
            r"E:\AI_Workspace\01_模型存储\模型.lnk",
            r"E:\AI_Workspace\模型存储\主模型.lnk"
        ]
        for shortcut_path in shortcut_paths:
            try:
                target_path = resolve_shortcut(shortcut_path)
                if target_path and os.path.exists(target_path):
                    normalized_path = normalize_path(target_path)
                    self.logger.info(f"🔗 Shortcut resolved: {shortcut_path} -> {normalized_path}")
                    return normalized_path
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to resolve shortcut: {shortcut_path} - {e}")

        # Fall back to known candidate model directories
        possible_paths = [
            r"E:\AI_Models",
            r"E:\AI_System\AI_Models",
            r"E:\AI_Workspace\01_模型存储",
            r"E:\AI_Workspace\模型存储"
        ]
        for path in possible_paths:
            if os.path.exists(path):
                normalized_path = normalize_path(path)
                self.logger.info(f"📁 Using existing model directory: {normalized_path}")
                return normalized_path

        # Create the default directory as a last resort
        fallback = r"E:\AI_Models"
        os.makedirs(fallback, exist_ok=True)
        self.logger.warning(f"⚠️ Created default model path: {fallback}")
        return normalize_path(fallback)
    def _load_or_create_registry(self) -> Dict[str, dict]:
        """Load the model registry, creating it if it does not exist."""
        try:
            registry_path = Path(normalize_path(self.MODEL_REGISTRY_FILE))
            if registry_path.exists():
                with open(registry_path, 'r', encoding='utf-8') as f:
                    registry = json.load(f)
                self.logger.info(f"📋 Model registry loaded: {registry_path}")
                return registry
            self.logger.warning(f"⚠️ Model registry missing, creating new file: {registry_path}")
            with open(registry_path, 'w', encoding='utf-8') as f:
                json.dump({}, f, indent=2)
            return {}
        except Exception as e:
            self.logger.error(f"❌ Failed to process model registry: {str(e)}")
            return {}

    def _save_registry(self):
        """Persist the model registry to disk."""
        try:
            registry_path = Path(normalize_path(self.MODEL_REGISTRY_FILE))
            with open(registry_path, 'w', encoding='utf-8') as f:
                json.dump(self._persistent_registry, f, indent=2, ensure_ascii=False)
            self.logger.info("✅ Model registry saved")
        except Exception as e:
            self.logger.error(f"❌ Failed to save model registry: {str(e)}")
    def _register_default_models(self):
        """Register the default models declared in the config."""
        model_settings = self.config.get("model_settings", {})
        for model_name, rel_path in self.DEFAULT_MODEL_PATHS.items():
            # Build the full path: model root + relative path
            abs_path = os.path.join(self.models_root, rel_path)
            if model_name not in self._persistent_registry:
                model_type = model_settings.get(model_name, {}).get("type", "text")
                # The registry stores rel_path (relative to the model root),
                # but existence is checked against the absolute path.
                exists, is_local = self._check_model_exists(abs_path)
                if not exists:
                    self.logger.error(f"❌ Default model path is not accessible: {abs_path}")
                    continue
                self._persistent_registry[model_name] = {
                    "path": rel_path,  # stored relative to the model root
                    "type": model_type,
                    "status": "unloaded",
                    "checksum": "unknown",  # skip checksums for default models (can be computed later)
                    "last_accessed": time.time(),
                    "adapter": None,
                    "is_local": is_local
                }
                self.logger.info(f"✅ Default model registered: {model_name} ({model_type})")
        self._save_registry()
    def register_model(
        self,
        model_name: str,
        model_path: str,
        model_type: str = "text",
        adapter_config: Optional[dict] = None
    ) -> bool:
        """Register a new model.

        model_path may be an absolute path, a relative path, or a HuggingFace
        model ID. Local paths are stored relative to the model root directory;
        HuggingFace IDs are stored verbatim.
        """
        # Verify the model exists
        exists, is_local = self._check_model_exists(model_path)
        if not exists:
            self.logger.error(f"❌ Model path is not accessible: {model_path}")
            return False

        # For local paths, store a path relative to the model root
        if is_local:
            abs_path = normalize_path(model_path)
            try:
                # If abs_path lies outside models_root the relative path still
                # resolves, but that layout should be avoided.
                rel_path = os.path.relpath(abs_path, self.models_root)
            except ValueError:
                # relpath fails across drives on Windows; fall back to the absolute path
                rel_path = abs_path
                self.logger.warning(f"⚠️ Model path is on a different drive than the model root; storing absolute path: {abs_path}")
            stored_path = rel_path
        else:
            # HuggingFace model ID: store it verbatim
            stored_path = model_path

        # Compute a checksum for local files
        checksum = "unknown"
        if is_local:
            try:
                checksum = self._calculate_checksum(abs_path)
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to compute checksum: {str(e)}")
                checksum = "error"

        # Add to the registry
        self._persistent_registry[model_name] = {
            "path": stored_path,  # relative path or HF ID
            "type": model_type,
            "status": "unloaded",
            "checksum": checksum,
            "last_accessed": time.time(),
            "adapter": adapter_config,
            "is_local": is_local
        }
        self.logger.info(f"✅ Model registered: {model_name} ({model_type})")
        self._save_registry()
        return True
    def _check_model_exists(self, model_path: str) -> Tuple[bool, bool]:
        """Check whether a model path is valid. Returns (exists, is_local)."""
        # HuggingFace model ID?
        if is_valid_hf_id(model_path):
            self.logger.info(f"🔍 Detected HuggingFace model ID: {model_path}")
            return True, False
        # Local path (normalized)
        abs_path = normalize_path(model_path)
        if os.path.exists(abs_path):
            return True, True
        # Path relative to the current working directory
        if os.path.exists(model_path):
            return True, True
        return False, False

    def _calculate_checksum(self, model_path: str) -> str:
        """Compute a model checksum (operates on the absolute path)."""
        abs_path = normalize_path(model_path)
        if os.path.isdir(abs_path):
            sha256 = hashlib.sha256()
            key_files = ["pytorch_model.bin", "model.safetensors", "config.json"]
            for root, _, files in os.walk(abs_path):
                for file in files:
                    if file in key_files:
                        file_path = os.path.join(root, file)
                        with open(file_path, 'rb') as f:
                            while chunk := f.read(8192):
                                sha256.update(chunk)
            return sha256.hexdigest()
        # Single model file
        with open(abs_path, 'rb') as f:
            return hashlib.sha256(f.read()).hexdigest()
    def load_model(self, model_name: str, force_reload: bool = False) -> Tuple[bool, Any]:
        """Load a model into memory."""
        if model_name not in self._persistent_registry:
            self.logger.error(f"❌ Model not registered: {model_name}")
            return False, None

        model_info = self._persistent_registry[model_name]
        stored_path = model_info["path"]
        is_local = model_info.get("is_local", True)

        # Build the absolute path: local paths are joined with the model root;
        # HF IDs are used as-is.
        if is_local:
            # Relative paths are joined with the model root; absolute paths
            # (stored when registration crossed drives) are used directly.
            if os.path.isabs(stored_path):
                abs_path = stored_path
            else:
                abs_path = os.path.join(self.models_root, stored_path)
            abs_path = normalize_path(abs_path)
        else:
            abs_path = stored_path  # HF ID

        # Model already loaded and no forced reload requested
        if model_name in self.loaded_models and not force_reload:
            self.logger.info(f"📦 Model already in memory: {model_name}")
            model_info["last_accessed"] = time.time()
            return True, self.loaded_models[model_name]

        # Enforce the in-memory model limit
        if len(self.loaded_models) >= self.max_models_in_memory:
            self._unload_least_recently_used()

        # Perform the actual load
        try:
            self.logger.info(f"🔄 Loading model: {model_name} ({model_info['type']})")
            model_type = model_info["type"]
            if model_type == "text":
                model = self._load_text_model(model_info, abs_path)
            elif model_type == "image":
                model = self._load_image_model(model_info, abs_path)
            elif model_type == "multimodal":
                model = self._load_multimodal_model(model_info, abs_path)
            else:
                self.logger.error(f"❌ Unsupported model type: {model_type}")
                return False, None

            # Update state
            self.loaded_models[model_name] = model
            model_info["status"] = "loaded"
            model_info["last_accessed"] = time.time()
            self._save_registry()
            self.logger.info(f"✅ Model loaded: {model_name}")
            return True, model
        except ImportError as e:
            self.logger.error(f"❌ Missing dependency: {str(e)}")
            return False, None
        except Exception as e:
            self.logger.error(f"❌ Failed to load model: {model_name}, path: {abs_path}, error: {str(e)}", exc_info=True)
            model_info["status"] = "error"
            return False, None
    def _load_text_model(self, model_info: dict, model_path: str) -> Any:
        """Load a text model (supports sharded checkpoints)."""
        self.logger.info(f"📝 Loading text model: {model_path}")
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            import torch

            device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu"

            # Detect sharded checkpoints (only meaningful for local directories;
            # HF IDs would make os.listdir fail)
            if os.path.isdir(model_path):
                is_sharded = any(f.endswith('.index.json') for f in os.listdir(model_path))
                if is_sharded:
                    self.logger.info("🔗 Sharded checkpoint detected; from_pretrained handles it")

            # Load tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                cache_dir=self.cache_dir,
                trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                cache_dir=self.cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                device_map="auto" if device == "cuda" else None,
                trust_remote_code=True
            )
            # Note: with device_map="auto" the weights are already placed on the
            # GPU, so an explicit model.to(device) is unnecessary and would
            # conflict with accelerate's dispatch.
            return {
                "model": model,
                "tokenizer": tokenizer,
                "info": model_info,
                "device": device
            }
        except ImportError:
            self.logger.error("❌ transformers is not installed; cannot load text models")
            raise
        except Exception as e:
            self.logger.error(f"❌ Failed to load text model: {str(e)}")
            raise
    def _load_image_model(self, model_info: dict, model_path: str) -> Any:
        """Load an image model (supports the Diffusers layout)."""
        self.logger.info(f"🖼️ Loading image model: {model_path}")
        try:
            # Diffusers layout?
            if self._is_diffusers_format(model_path):
                return self._load_diffusers_model(model_path, model_info)

            # Classic Transformers layout
            from transformers import AutoProcessor, AutoModelForVision2Seq
            import torch

            device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu"
            processor = AutoProcessor.from_pretrained(model_path, cache_dir=self.cache_dir)
            model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                cache_dir=self.cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
            return {
                "model": model,
                "processor": processor,
                "info": model_info
            }
        except ImportError:
            self.logger.error("❌ Required libraries are not installed; cannot load image models")
            raise
        except Exception as e:
            self.logger.error(f"❌ Failed to load image model: {str(e)}")
            raise

    def _load_multimodal_model(self, model_info: dict, model_path: str) -> Any:
        """Load a multimodal model."""
        self.logger.info(f"🎭 Loading multimodal model: {model_path}")
        try:
            from transformers import AutoProcessor, AutoModelForVision2Seq
            import torch

            device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu"
            processor = AutoProcessor.from_pretrained(model_path, cache_dir=self.cache_dir)
            model = AutoModelForVision2Seq.from_pretrained(
                model_path,
                cache_dir=self.cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
            return {
                "model": model,
                "processor": processor,
                "info": model_info
            }
        except ImportError:
            self.logger.error("❌ transformers is not installed; cannot load multimodal models")
            raise
        except Exception as e:
            self.logger.error(f"❌ Failed to load multimodal model: {str(e)}")
            raise
    def _is_diffusers_format(self, model_path: str) -> bool:
        """Check whether the path contains a Diffusers-layout model."""
        if not os.path.isdir(model_path):
            return False
        # Diffusers models usually contain these subdirectories
        diffusers_dirs = {"unet", "vae", "text_encoder", "scheduler", "tokenizer"}
        subdirs = {d for d in os.listdir(model_path) if os.path.isdir(os.path.join(model_path, d))}
        return any(d in subdirs for d in diffusers_dirs)

    def _load_diffusers_model(self, model_path: str, model_info: dict) -> Any:
        """Load a Diffusers-layout model."""
        try:
            from diffusers import StableDiffusionPipeline, DiffusionPipeline
            import torch

            device = "cuda" if self.use_gpu and torch.cuda.is_available() else "cpu"
            # Try SDXL or another Stable Diffusion pipeline first
            try:
                pipeline = StableDiffusionPipeline.from_pretrained(
                    model_path,
                    cache_dir=self.cache_dir,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
                ).to(device)
            except Exception:
                # Fall back to the generic DiffusionPipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    model_path,
                    cache_dir=self.cache_dir,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32
                ).to(device)
            return {
                "pipeline": pipeline,
                "info": model_info,
                "type": "diffusers"
            }
        except Exception as e:
            self.logger.error(f"❌ Failed to load Diffusers model: {str(e)}")
            raise
    def unload_model(self, model_name: Optional[str] = None):
        """Unload one model, or all models if no name is given."""
        if model_name:
            # Unload the named model
            if model_name in self.loaded_models:
                self._unload_single_model(model_name)
                self.logger.info(f"🗑️ Model unloaded: {model_name}")
            else:
                self.logger.warning(f"⚠️ Model is not loaded: {model_name}")
        else:
            # Unload all models
            for name in list(self.loaded_models.keys()):
                self._unload_single_model(name)
            self.logger.info("🗑️ All models unloaded")

    def _unload_single_model(self, model_name: str):
        """Unload a single model."""
        try:
            # Update the registry status
            if model_name in self._persistent_registry:
                self._persistent_registry[model_name]["status"] = "unloaded"
            # Drop the in-memory reference
            if model_name in self.loaded_models:
                del self.loaded_models[model_name]
            # Force garbage collection
            gc.collect()
            # Clear the CUDA cache when running on GPU
            if self.use_gpu:
                try:
                    import torch
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except ImportError:
                    pass
        except Exception as e:
            self.logger.error(f"❌ Failed to unload model: {model_name}, error: {str(e)}")

    def _unload_least_recently_used(self):
        """Unload the least recently used model."""
        if not self.loaded_models:
            return
        # Find the least recently used model
        lru_model = None
        lru_time = float('inf')
        for model_name in self.loaded_models:
            if model_name in self._persistent_registry:
                last_accessed = self._persistent_registry[model_name].get("last_accessed", 0)
                if last_accessed < lru_time:
                    lru_time = last_accessed
                    lru_model = model_name
        # Unload it
        if lru_model:
            self.logger.info(f"🗑️ Unloading least recently used model: {lru_model}")
            self._unload_single_model(lru_model)

    def shutdown(self):
        """Shut down the model manager."""
        self.logger.info("🛑 Shutting down model manager")
        self.unload_model()  # unload all models
        self.logger.info("Model manager shut down")
# Variant of ModelManager.load_model with asynchronous loading support
import threading

class ModelManager:
    def load_model(self, model_name: str) -> bool:
        """Load a model (with optional asynchronous loading)."""
        if model_name not in self._persistent_registry:
            self.logger.error(f"❌ Model not registered: {model_name}")
            return False
        # Already in memory?
        if model_name in self.loaded_models:
            self.logger.info(f"ℹ️ Model already in memory: {model_name}")
            return True

        model_info = self._persistent_registry[model_name]
        model_path = model_info['path']
        # Asynchronous-loading flag
        async_load = self.config.get("async_load", False)
        try:
            if async_load:
                # Load in the background without blocking the main thread
                threading.Thread(
                    target=self._load_model_thread,
                    args=(model_name, model_path, model_info['type']),
                    daemon=True
                ).start()
                return True
            else:
                return self._load_model_thread(model_name, model_path, model_info['type'])
        except Exception as e:
            self.logger.error(f"❌ Failed to load model: {model_name} - {str(e)}", exc_info=True)
            return False

    def _load_model_thread(self, model_name, model_path, model_type):
        """Thread body that performs the actual load."""
        self.logger.info(f"🔄 Loading model in background: {model_name} ({model_type})")
        # Actual loading logic...
        # [original loading code]
        self.logger.info(f"✅ Background load finished: {model_name}")
        return True
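Usage sketch for the async variant (an assumption-level example; it presumes this method is merged into the full ModelManager above, and the model name is illustrative). With async_load set, load_model returns before the weights are in memory, so callers must poll loaded_models or the registry status.

manager = ModelManager(config={"async_load": True})
manager.load_model("my-llm")  # returns True immediately; loading continues in a daemon thread
# Poll until the background thread finishes
while "my-llm" not in manager.loaded_models:
    time.sleep(0.5)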