The SingleThreadModel Interface

This article describes the purpose of the SingleThreadModel interface and how servlet containers implement it. The interface guarantees that a single servlet instance is never accessed by multiple request threads at the same time. Two implementation strategies are covered: queuing requests so that only one is processed at a time, and creating multiple servlet instances. Which strategy is used depends on the servlet container.

The SingleThreadModel interface (javax.servlet.SingleThreadModel) guarantees that no two request threads execute the service method of a single servlet instance concurrently. It is a marker interface with no methods: a servlet opts in simply by implementing it.
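As a minimal sketch (the class name and counter field here are hypothetical), a servlet that opts into the single-thread model looks like this:

```java
import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.SingleThreadModel;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

// SingleThreadModel declares no methods. Implementing it tells the
// container that service() must never be entered by two threads on
// the same instance at the same time.
public class CounterServlet extends HttpServlet implements SingleThreadModel {

    // Instance state that would be a data race in an ordinary servlet.
    private int count = 0;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp)
            throws ServletException, IOException {
        count++;  // safe here: no concurrent access to this instance
        resp.getWriter().println("Requests served by this instance: " + count);
    }
}
```

Note that the guarantee only covers instance state. Static fields and shared objects such as session attributes can still be accessed concurrently from other instances or other servlets, which is one reason the interface was deprecated in Servlet API 2.4.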

A container can satisfy this guarantee in two ways:

1. Queue all incoming requests and process them one at a time on a single instance (see the first sketch after this list).

2. Create multiple servlet instances and give each concurrent request its own instance (see the pool sketch further below).
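The first approach amounts to serializing calls into the servlet. A minimal sketch, assuming a hypothetical dispatcher class (containers do not literally expose such an API): synchronizing on the instance makes waiting request threads queue up on its monitor.

```java
import javax.servlet.Servlet;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

// Approach 1 (sketch): serialize all requests to one STM servlet
// instance by locking on it. Requests effectively queue on the monitor.
public class SerializingDispatcher {
    private final Servlet instance;

    public SerializingDispatcher(Servlet instance) {
        this.instance = instance;
    }

    public void dispatch(ServletRequest req, ServletResponse resp) throws Exception {
        synchronized (instance) {          // one thread at a time
            instance.service(req, resp);   // others block here, forming a queue
        }
    }
}
```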


Which approach is used depends on the servlet container. Tomcat 6.0.24 uses the second approach: it maintains a pool of servlet instances.
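Below is a rough sketch of that instance-pool strategy. The class is hypothetical and heavily simplified; in Tomcat this logic lives in org.apache.catalina.core.StandardWrapper, which additionally caps the pool size and drives the servlet lifecycle (init/destroy).

```java
import java.util.ArrayDeque;
import java.util.Deque;
import javax.servlet.Servlet;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

// Approach 2 (sketch): keep a pool of servlet instances and give each
// request exclusive use of one instance for the duration of service().
public class PooledDispatcher {
    private final Deque<Servlet> idle = new ArrayDeque<>();
    private final Class<? extends Servlet> servletClass;

    public PooledDispatcher(Class<? extends Servlet> servletClass) {
        this.servletClass = servletClass;
    }

    public void dispatch(ServletRequest req, ServletResponse resp) throws Exception {
        Servlet instance = acquire();
        try {
            instance.service(req, resp);  // exclusive: no other thread holds it
        } finally {
            release(instance);            // make it available for reuse
        }
    }

    private synchronized Servlet acquire() throws Exception {
        Servlet s = idle.pollFirst();
        // All instances busy: create another rather than block the request.
        return (s != null) ? s : servletClass.getDeclaredConstructor().newInstance();
    }

    private synchronized void release(Servlet s) {
        idle.addFirst(s);
    }
}
```

Pooling avoids head-of-line blocking at the cost of duplicating per-instance state, which is why instance variables cannot be used to share data across requests under this model.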
