解决mmdet 训练提示gt_mask = instance[‘mask‘] KeyError: ‘mask‘

博客介绍了通过coco数据集“annotations”里的“segmentation”计算mask的方法,还提到要检查coco数据集中“segmentation”是否为空,并且可将“bbox”传入“segmentation”,同时给出了参考链接。

该字段为通过coco数据集”annotations“中的“segmentation”来计算mask,检查一下coco数据集中的“segmentation”是否为空。可以将“bbox”传入“segmentation”

with open ('valid.json',encoding='utf-8') as f:
    json_info = json.load(f)

for i in json_info["annotations"]:
    [x,y,w,h] = i["bbox"]
    segmentation = [[x,y,(x+w),y,(x+w),(y+h),x,(y+h)]]
    i["segmentation"] = segmentation
with open ('valid_1.json','w',encoding='utf-8') as f:
    json.dump(json_info, f, indent=1)

参考链接:KeyError: 'mask' --- gt_mask = instance['mask'] · Issue #10961 · open-mmlab/mmdetection (github.com)

# E:\AI_System\.env # ======================== # AI 系统环境变量配置 # ======================== # 环境类型 (dev, test, prod) ENV=dev # 目录配置 (使用双下划线表示层级) AI_SYSTEM_DIRECTORIES__PROJECT_ROOT=E:\AI_System AI_SYSTEM_DIRECTORIES__AGENT_DIR=E:\AI_System\agent AI_SYSTEM_DIRECTORIES__WEB_UI_DIR=E:\AI_System\web_ui AI_SYSTEM_DIRECTORIES__DEFAULT_MODEL=E:\AI_Models\Qwen2-7B # 日志配置 AI_SYSTEM_ENVIRONMENT__LOG_LEVEL=DEBUG # 数据库配置 AI_SYSTEM_DATABASE__DB_HOST=localhost AI_SYSTEM_DATABASE__DB_PORT=5432 AI_SYSTEM_DATABASE__DB_NAME=ai_system AI_SYSTEM_DATABASE__DB_USER=ai_user AI_SYSTEM_DATABASE__DB_PASSWORD=your_secure_password_here # 安全配置 AI_SYSTEM_SECURITY__SECRET_KEY=your_generated_secret_key_here # 模型配置 AI_SYSTEM_MODEL_PATHS__TEXT_BASE=E:\AI_Models\Qwen2-7B AI_SYSTEM_MODEL_PATHS__TEXT_CHAT=E:\AI_Models\deepseek-7b-chat AI_SYSTEM_MODEL_PATHS__MULTIMODAL=E:\AI_Models\deepseek-vl2 AI_SYSTEM_MODEL_PATHS__IMAGE_GEN=E:\AI_Models\sdxl AI_SYSTEM_MODEL_PATHS__YI_VL=E:\AI_Models\yi-vl AI_SYSTEM_MODEL_PATHS__STABLE_DIFFUSION=E:\AI_Models\stable-diffusion-xl-base-1.0 # 网络配置 AI_SYSTEM_NETWORK__HOST=0.0.0.0 AI_SYSTEM_NETWORK__FLASK_PORT=8000 AI_SYSTEM_NETWORK__GRADIO_PORT=7860 “# 数据库凭证 DB_PASSWORD=your_actual_db_password_here # 应用密钥 SECRET_KEY=generated-secret-key-here # 其他环境变量 MODEL_BASE_PATH=E:/AI_Models PROJECT_ROOT=E:/AI_System ” # E:\AI_System\core\config.py import os import sys import json import logging from pathlib import Path from dotenv import load_dotenv from prettytable import PrettyTable # 设置日志 logger = logging.getLogger('CoreConfig') logger.setLevel(logging.INFO) # 确保有基本日志处理器 if not logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) class CoreConfig: """核心配置系统 - 支持环境变量、配置文件和默认值的优先级加载""" _instance = None _config = {} @classmethod def get_instance(cls): """获取单例实例""" if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): """初始化配置系统""" # 设置基础目录 self.base_dir = Path(__file__).resolve().parent.parent self.env_prefix = "AI_SYSTEM" # 敏感字段列表(在日志和输出中掩码) self.sensitive_fields = ["DB_PASSWORD", "SECRET_KEY", "API_KEY", "ACCESS_TOKEN"] # 路径类型配置键 self.path_keys = [ "LOG_DIR", "CONFIG_DIR", "MODEL_CACHE_DIR", "MODEL_BASE_PATH", "WEB_UI_DIR", "AGENT_DIR", "CORE_DIR", "MODELS_DIR", "LOGS_DIR", "TEXT_BASE", "TEXT_CHAT", "MULTIMODAL", "IMAGE_GEN", "YI_VL", "STABLE_DIFFUSION" ] # 加载配置 self._load_config() # 验证并创建缺失路径 self._validate_paths() # 新增:路径验证 logger.info("✅ 配置系统初始化完成") # === 新增:路径验证和创建功能 === def _validate_paths(self): """验证并创建缺失的关键路径""" # 1. 验证目录路径 dir_paths = [ self.get("LOG_DIR"), self.get("CONFIG_DIR"), self.get("MODEL_CACHE_DIR"), self.get("MODEL_BASE_PATH"), self.get("WEB_UI_DIR"), self.get("AGENT_DIR"), self.get("DIRECTORIES.PROJECT_ROOT") ] for path in dir_paths: if path: try: resolved_path = Path(path) if not resolved_path.exists(): resolved_path.mkdir(parents=True, exist_ok=True) logger.info(f"📁 创建缺失目录: {resolved_path}") except Exception as e: logger.error(f"❌ 创建目录失败: {path} - {str(e)}") # 2. 验证模型路径 model_paths = self.get("MODEL_PATHS", {}) for model_type, path in model_paths.items(): if path: try: resolved_path = Path(path) if not resolved_path.exists(): resolved_path.mkdir(parents=True, exist_ok=True) logger.info(f"📁 创建缺失模型路径: {resolved_path}") except Exception as e: logger.error(f"❌ 创建模型路径失败: {path} - {str(e)}") # === 以下保留原有功能不变 === def __getattr__(self, name): """允许通过属性方式访问配置项""" if name in self._config: return self._config[name] # 尝试访问嵌套配置 if '.' in name: return self.get_nested(name) # 提供一些常用配置的默认值 if name == "DEFAULT_MODEL": return self.get("MODEL_PATHS.TEXT_BASE", "") # 记录警告而不是直接抛出异常 logger.warning(f"访问未定义的配置项: {name}") return None def __getitem__(self, key): """通过键访问配置值""" return self.get(key) def __contains__(self, key): """检查键是否存在""" return self.get(key) is not None def _mask_sensitive_value(self, key, value): """对敏感信息进行掩码处理""" if value and any(sensitive_key in key for sensitive_key in self.sensitive_fields): return "******" return value def _log_sensitive_value(self, key, value): """在日志中安全地记录敏感信息""" if any(sensitive_key in key for sensitive_key in self.sensitive_fields): logger.info(f"🔄 环境变量覆盖: {key}=******") else: logger.info(f"🔄 环境变量覆盖: {key}={value}") def _set_defaults(self): """设置默认配置值""" # 系统路径配置 defaults = { # 顶级配置项 "LOG_DIR": str(self.base_dir / "logs"), "CONFIG_DIR": str(self.base_dir / "config"), "MODEL_CACHE_DIR": str(self.base_dir / "model_cache"), "AGENT_NAME": "小蓝", "DEFAULT_USER": "管理员", "MAX_WORKERS": 4, "AGENT_RESPONSE_TIMEOUT": 30.0, "MODEL_BASE_PATH": "E:/AI_Models", # 嵌套配置结构 "MODEL_PATHS": { "TEXT_BASE": "E:/AI_Models/Qwen2-7B", "TEXT_CHAT": "E:/AI_Models/deepseek-7b-chat", "MULTIMODAL": "E:/AI_Models/deepseek-vl2", "IMAGE_GEN": "E:/AI_Models/sdxl", "YI_VL": "E:/AI_Models/yi-vl", "STABLE_DIFFUSION": "E:/AI_Models/stable-diffusion-xl-base-1.0" }, "NETWORK": { "HOST": "0.0.0.0", "FLASK_PORT": 8000, "GRADIO_PORT": 7860 }, "DATABASE": { "DB_HOST": "localhost", "DB_PORT": 5432, "DB_NAME": "ai_system", "DB_USER": "ai_user", "DB_PASSWORD": "" }, "SECURITY": { "SECRET_KEY": "default-secret-key-change-in-production" }, "ENVIRONMENT": { "ENV": "dev", "LOG_LEVEL": "INFO", "USE_GPU": True }, "DIRECTORIES": { "DEFAULT_MODEL": "E:/AI_Models/Qwen2-7B", "WEB_UI_DIR": str(self.base_dir / "web_ui"), "AGENT_DIR": str(self.base_dir / "agent"), "PROJECT_ROOT": str(self.base_dir) } } # 初始化配置字典 self._config = defaults logger.debug("设置默认配置值") def _load_config_files(self): """加载配置文件""" # 确保配置目录存在 config_dir = Path(self._config.get("CONFIG_DIR", self.base_dir / "config")) config_dir.mkdir(exist_ok=True, parents=True) # 配置加载顺序 config_files = [ config_dir / 'default.json', config_dir / 'local.json' ] for config_file in config_files: if config_file.exists(): try: with open(config_file, 'r', encoding='utf-8') as f: config_data = json.load(f) # 递归合并配置 self._merge_config(self._config, config_data) # 掩码敏感信息用于日志 masked_data = self._mask_config(config_data) logger.info(f"📂 从 {config_file} 加载配置: {json.dumps(masked_data, indent=2)}") except Exception as e: logger.error(f"❌ 加载配置文件 {config_file} 错误: {str(e)}") else: logger.info(f"ℹ️ 配置文件不存在: {config_file},跳过") def _merge_config(self, base, new): """递归合并配置""" for key, value in new.items(): if isinstance(value, dict) and key in base and isinstance(base[key], dict): # 递归合并嵌套字典 self._merge_config(base[key], value) else: # 设置或覆盖值 base[key] = value # 处理路径配置 if key in self.path_keys and isinstance(value, str): # 确保路径使用正斜杠 base[key] = value.replace('\\', '/') def _mask_config(self, config_data): """递归掩码敏感配置""" if isinstance(config_data, dict): return {k: self._mask_config(v) for k, v in config_data.items()} elif isinstance(config_data, str): return self._mask_sensitive_value("", config_data) return config_data def _load_environment(self): """加载环境变量 - 修复了环境变量处理逻辑""" # 加载.env文件 env_file = self.base_dir / '.env' if env_file.exists(): try: load_dotenv(dotenv_path=str(env_file), override=True) logger.info(f"🌐 从 {env_file} 加载环境变量") except Exception as e: logger.error(f"❌ 加载环境变量失败: {str(e)}") # 覆盖环境变量中的配置 for key, value in os.environ.items(): # 检查是否以环境前缀开头 if key.startswith(self.env_prefix + "_"): # 去掉前缀并转换为小写 config_key = key[len(self.env_prefix) + 1:].lower() # 处理特殊值 if value.lower() in ['true', 'false']: value = value.lower() == 'true' elif value.isdigit(): value = int(value) elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: try: value = float(value) except ValueError: pass # 保持字符串 # 设置配置值 self._set_nested_config(config_key, value) self._log_sensitive_value(key, value) elif key in self._config: # 处理非前缀的顶级配置 # 处理特殊值 if value.lower() in ['true', 'false']: value = value.lower() == 'true' elif value.isdigit(): value = int(value) elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: try: value = float(value) except ValueError: pass # 保持字符串 self._config[key] = value logger.info(f"🔄 环境变量覆盖: {key}={self._mask_sensitive_value(key, value)}") def _set_nested_config(self, key_path, value): """设置嵌套配置值 - 支持点分隔和双下划线分隔的键""" # 支持两种分隔符:点号(.)或双下划线(__) if '.' in key_path: keys = key_path.split('.') else: keys = key_path.split('__') # 使用双下划线表示嵌套层级 current = self._config # 遍历键路径 for i, key in enumerate(keys): key = key.upper() # 统一使用大写键名 # 如果是最后一个键,设置值 if i == len(keys) - 1: # 处理路径配置 if key in self.path_keys and isinstance(value, str): value = value.replace('\\', '/') current[key] = value else: # 确保中间层级是字典 if key not in current or not isinstance(current[key], dict): current[key] = {} current = current[key] def validate_model_paths(self): """验证所有模型路径是否存在""" model_keys = [ "TEXT_BASE", "TEXT_CHAT", "MULTIMODAL", "IMAGE_GEN", "YI_VL", "STABLE_DIFFUSION" ] results = {} valid_count = 0 for key in model_keys: # 使用点分路径获取配置 path = self.get(f"MODEL_PATHS.{key}", "") if path: path_obj = Path(path) exists = path_obj.exists() valid = exists if not exists: logger.warning(f"⚠️ 模型路径不存在: {key} = {path}") elif not any(path_obj.iterdir()): logger.warning(f"⚠️ 模型路径为空目录: {key} = {path}") valid = False if valid: valid_count += 1 results[key] = { "path": str(path_obj), "exists": exists, "valid": valid } else: results[key] = { "path": "", "exists": False, "valid": False } logger.warning(f"⚠️ 模型路径未配置: {key}") # 添加总体状态 results["overall"] = { "total_models": len(model_keys), "valid_models": valid_count, "all_valid": valid_count == len(model_keys) } return results def get(self, key_path, default=None, sep="."): """获取配置值(支持点分路径)""" keys = key_path.split(sep) value = self._config try: for key in keys: key = key.upper() # 统一使用大写键名 if isinstance(value, dict) and key in value: value = value[key] else: return default return value except (KeyError, TypeError): return default def to_dict(self, mask_sensitive=True): """返回当前配置的字典表示""" if mask_sensitive: return self._mask_config(self._config.copy()) return self._config.copy() def get_nested(self, key_path, default=None, sep="."): """获取嵌套配置值(别名)""" return self.get(key_path, default, sep) def print_config_summary(self): """打印配置摘要""" table = PrettyTable() table.field_names = ["配置路径", "值"] table.align["配置路径"] = "l" table.align["值"] = "l" # 添加关键配置项 key_items = [ "AGENT_NAME", "DIRECTORIES.PROJECT_ROOT", "LOG_DIR", "AGENT_DIR", "WEB_UI_DIR", "NETWORK.HOST", "NETWORK.FLASK_PORT", "MODEL_PATHS.TEXT_BASE", "ENVIRONMENT.USE_GPU", "ENVIRONMENT.LOG_LEVEL" ] for key_path in key_items: value = self.get(key_path, "未设置") if isinstance(value, str) and len(value) &gt; 50: value = value[:47] + "..." table.add_row([key_path, value]) print("\n系统配置摘要:") print(table) def _load_config(self): """加载所有配置""" # 1. 设置默认值 self._set_defaults() # 2. 加载配置文件 self._load_config_files() # 3. 加载环境变量 self._load_environment() # 4. 验证关键路径 self.validate_model_paths() # 创建全局配置实例 config = CoreConfig.get_instance() # 测试代码 if __name__ == "__main__": # 设置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) print("=" * 50) print("配置系统测试") print("=" * 50) # 获取配置实例 config = CoreConfig.get_instance() # 打印基本配置 print(f"AGENT_NAME: {config.get('AGENT_NAME')}") print(f"PROJECT_ROOT: {config.get('DIRECTORIES.PROJECT_ROOT')}") print(f"LOG_DIR: {config.get('LOG_DIR')}") print(f"AGENT_DIR: {config.get('DIRECTORIES.AGENT_DIR')}") print(f"WEB_UI_DIR: {config.get('DIRECTORIES.WEB_UI_DIR')}") print(f"DB_HOST: {config.get('DATABASE.DB_HOST')}") print(f"DEFAULT_MODEL: {config.get('MODEL_PATHS.TEXT_BASE')}") # 验证模型路径 print("\n模型路径验证结果:") for model, info in config.validate_model_paths().items(): if model == "overall": continue status = "✅ 有效" if info["valid"] else "❌ 无效" print(f"{model:20} {status} ({info['path']})") # 打印配置摘要 config.print_config_summary() print("\n测试完成!") “# 在core/config.py中存在 def _replace_variables(self, value: str) -&gt; str: """替换字符串中的变量引用""" if value.startswith("${ENV:") and value.endswith("}"): # 环境变量引用 var_name = value[6:-1] return os.getenv(var_name, "") elif value.startswith("${") and "}" in value: # 配置变量引用 var_name = value[2:value.index("}")] return self.config.get(var_name, value) return value ”
最新发布
08-29
根据你说的情况我在多提供给你该文件中的代码,在前面部分有提及with_type的设置问题 import csv import glob import os import re import cv2 import matplotlib.pyplot as plt import numpy as np import scipy.io as sio import torch.utils.data import imgaug as ia from imgaug import augmenters as iaa from misc.utils import cropping_center from .augs import ( add_to_brightness, add_to_contrast, add_to_hue, add_to_saturation, gaussian_blur, median_blur, ) #### class FileLoader(torch.utils.data.Dataset): """Data Loader. Loads images from a file list and performs augmentation with the albumentation library. After augmentation, horizontal and vertical maps are generated. Args: file_list: list of filenames to load input_shape: shape of the input [h,w] - defined in config.py mask_shape: shape of the output [h,w] - defined in config.py mode: 'train' or 'valid' """ # TODO: doc string def __init__( self, file_list, with_type=False, input_shape=None, mask_shape=None, mode="train", setup_augmentor=True, target_gen=None, ): assert input_shape is not None and mask_shape is not None self.mode = mode self.info_list = file_list self.with_type = with_type self.mask_shape = mask_shape self.input_shape = input_shape self.id = 0 self.target_gen_func = target_gen[0] self.target_gen_kwargs = target_gen[1] if setup_augmentor: self.setup_augmentor(0, 0) return def setup_augmentor(self, worker_id, seed): self.augmentor = self.__get_augmentation(self.mode, seed) self.shape_augs = iaa.Sequential(self.augmentor[0]) self.input_augs = iaa.Sequential(self.augmentor[1]) self.id = self.id + worker_id return def __len__(self): return len(self.info_list) def __getitem__(self, idx): path = self.info_list[idx] data = np.load(path) # split stacked channel into image and label img = (data[..., :3]).astype("uint8") # RGB images ann = (data[..., 3:]).astype("int32") # instance ID map and type map if self.shape_augs is not None: shape_augs = self.shape_augs.to_deterministic() img = shape_augs.augment_image(img) ann = shape_augs.augment_image(ann) if self.input_augs is not None: input_augs = self.input_augs.to_deterministic() img = input_augs.augment_image(img) img = cropping_center(img, self.input_shape) feed_dict = {"img": img} inst_map = ann[..., 0] # HW1 -&gt; HW if self.with_type: type_map = (ann[..., 1]).copy() type_map = cropping_center(type_map, self.mask_shape) #type_map[type_map == 5] = 1 # merge neoplastic and non-neoplastic feed_dict["tp_map"] = type_map # TODO: document hard coded assumption about #input target_dict = self.target_gen_func( inst_map, self.mask_shape, **self.target_gen_kwargs ) feed_dict.update(target_dict) return feed_dict
08-22
你还是把这个改好了 发完整版给我吧 你的修复我看不懂 弄得我不知道是直接替换 还是让我干什么“# E:\AI_System\core\config.py import os import sys import json import logging from pathlib import Path from dotenv import load_dotenv from prettytable import PrettyTable # 临时添加项目根目录到Python路径 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class CoreConfig: _instance = None @classmethod def get_instance(cls): """获取单例实例""" if cls._instance is None: cls._instance = cls() return cls._instance def __init__(self): """初始化配置系统""" # 设置日志 self.logger = logging.getLogger('CoreConfig') self.logger.setLevel(logging.INFO) # 确保有基本日志处理器 if not self.logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) # 设置基础目录 self.base_dir = Path(__file__).resolve().parent.parent self.env_prefix = "AI_SYSTEM" self.config = {} # 敏感字段列表(在日志和输出中掩码) self.sensitive_fields = ["DB_PASSWORD", "SECRET_KEY", "API_KEY", "ACCESS_TOKEN"] # 加载配置 self._load_config() # 直接调用方法,不要嵌套定义 self.logger.info("✅ 配置系统初始化完成") # 注意:所有方法都在类作用域内,不要嵌套定义 def _load_config(self): """加载所有配置""" # 1. 设置默认值 self._set_defaults() # 2. 加载配置文件 self._load_config_files() # 3. 加载环境变量 self._load_environment() # 4. 验证关键路径 self.validate_model_paths() def _mask_sensitive_value(self, key, value): """对敏感信息进行掩码处理""" if value and key in self.sensitive_fields: return "******" return value # ... 其他方法保持不变 ... def _log_sensitive_value(self, key, value): """在日志中安全地记录敏感信息""" if key in self.sensitive_fields: self.logger.info(f"🔄 环境变量覆盖: {key}=******") else: self.logger.info(f"🔄 环境变量覆盖: {key}={value}") def _set_defaults(self): """设置默认配置值""" # 系统路径配置 defaults = { "LOG_DIR": str(self.base_dir / "logs"), "CONFIG_DIR": str(self.base_dir / "config"), "MODEL_CACHE_DIR": str(self.base_dir / "model_cache"), "AGENT_NAME": "小蓝", "DEFAULT_USER": "管理员", "MAX_WORKERS": 4, "AGENT_RESPONSE_TIMEOUT": 30.0, # 模型路径配置 "MODEL_BASE_PATH": "E:/AI_Models", "TEXT_BASE": "E:/AI_Models/Qwen2-7B", "TEXT_CHAT": "E:/AI_Models/deepseek-7b-chat", "MULTIMODAL": "E:/AI_Models/deepseek-vl2", "IMAGE_GEN": "E:/AI_Models/sdxl", "YI_VL": "E:/AI_Models/yi-vl", "STABLE_DIFFUSION": "E:/AI_Models/stable-diffusion-xl-base-1", # 系统路径配置 "SYSTEM_ROOT": str(self.base_dir), "AGENT_DIR": str(self.base_dir / "agent"), "WEB_UI_DIR": str(self.base_dir / "web_ui"), "CORE_DIR": str(self.base_dir / "core"), "MODELS_DIR": str(self.base_dir / "models"), "LOGS_DIR": str(self.base_dir / "logs"), # 服务器配置 "HOST": "0.0.0.0", "FLASK_PORT": 8000, "GRADIO_PORT": 7860, # 数据库配置 "DB_HOST": "localhost", "DB_PORT": 5432, "DB_NAME": "ai_system", "DB_USER": "ai_user", "DB_PASSWORD": "", # 安全配置 "SECRET_KEY": "default-secret-key" } for key, value in defaults.items(): self.config[key] = value self.logger.debug(f"设置默认值: {key}={self._mask_sensitive_value(key, value)}") def _load_config_files(self): """加载配置文件""" # 确保配置目录存在 config_dir = self.base_dir / "config" config_dir.mkdir(exist_ok=True, parents=True) # 配置加载顺序 config_files = [ config_dir / 'default.json', config_dir / 'local.json' ] for config_file in config_files: if config_file.exists(): try: with open(config_file, 'r', encoding='utf-8') as f: config_data = json.load(f) # 掩码敏感信息 masked_data = {k: self._mask_sensitive_value(k, v) for k, v in config_data.items()} self.config.update(config_data) self.logger.info(f"📂 从 {config_file} 加载配置: {masked_data}") except Exception as e: self.logger.error(f"❌ 加载配置文件 {config_file} 错误: {str(e)}") else: self.logger.info(f"ℹ️ 配置文件不存在: {config_file},跳过") def _load_environment(self): """加载环境变量""" # 加载.env文件 env_file = self.base_dir / '.env' if env_file.exists(): try: # 加载.env文件 load_dotenv(dotenv_path=str(env_file), override=True) self.logger.info(f"🌐 从 {env_file} 加载环境变量") except Exception as e: self.logger.error(f"❌ 加载环境变量失败: {str(e)}") # 覆盖环境变量中的配置 for key in list(self.config.keys()): # 先尝试带前缀的环境变量 prefixed_key = f"{self.env_prefix}_{key}" env_value = os.getenv(prefixed_key) # 如果带前缀的环境变量不存在,尝试直接使用key if env_value is None: env_value = os.getenv(key) if env_value is not None: # 尝试转换数据类型 if env_value.lower() in ['true', 'false']: env_value = env_value.lower() == 'true' elif env_value.isdigit(): env_value = int(env_value) elif env_value.replace('.', '', 1).isdigit(): try: env_value = float(env_value) except ValueError: pass # 保持字符串 self.config[key] = env_value self._log_sensitive_value(key, env_value) def validate_model_paths(self): """验证所有模型路径是否存在""" model_keys = ["TEXT_BASE", "TEXT_CHAT", "MULTIMODAL", "IMAGE_GEN", "YI_VL", "STABLE_DIFFUSION"] results = {} for key in model_keys: path = self.get(key, "") if path: path_obj = Path(path) exists = path_obj.exists() results[key] = { "path": str(path_obj), "exists": exists } if not exists: self.logger.warning(f"⚠️ 模型路径不存在: {key} = {path}") else: results[key] = { "path": "", "exists": False } self.logger.warning(f"⚠️ 模型路径未配置: {key}") return results def get(self, key, default=None): """获取配置值""" return self.config.get(key, default) def __getitem__(self, key): """通过键访问配置值""" return self.config[key] def __contains__(self, key): """检查键是否存在""" return key in self.config def to_dict(self, mask_sensitive=True): """返回当前配置的字典表示""" if mask_sensitive: return {k: self._mask_sensitive_value(k, v) for k, v in self.config.items()} return self.config.copy() # 创建全局配置实例 config = CoreConfig.get_instance() # 测试代码 if __name__ == "__main__": # 设置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) print("=" * 50) print("配置系统测试") print("=" * 50) # 获取配置实例 config = CoreConfig.get_instance() # 打印基本配置 print(f"AGENT_NAME: {config.get('AGENT_NAME')}") print(f"SYSTEM_ROOT: {config.get('SYSTEM_ROOT')}") print(f"LOG_DIR: {config.get('LOG_DIR')}") print(f"AGENT_DIR: {config.get('AGENT_DIR')}") print(f"WEB_UI_DIR: {config.get('WEB_UI_DIR')}") print(f"DB_HOST: {config.get('DB_HOST')}") # 验证模型路径 print("\n模型路径验证结果:") for model, info in config.validate_model_paths().items(): status = "存在 ✅" if info["exists"] else "不存在 ❌" print(f"{model:20} {status} ({info['path']})") # 使用表格显示所有配置(美化输出) print("\n当前所有配置:") table = PrettyTable() table.field_names = ["配置项", "值"] table.align["配置项"] = "l" table.align["值"] = "l" # 获取掩码后的配置 masked_config = config.to_dict(mask_sensitive=True) for key, value in masked_config.items(): table.add_row([key, value]) print(table) print("\n测试完成!") ”还有这个 我不知道怎么添加,你下次最好把需要更改的文件 都弄成能直接复制粘贴的那种 你这么发 我看不明白 更改不明白 懂吗?“{ "ENV": "dev", "LOG_LEVEL": "DEBUG", "USE_GPU": false, "DEFAULT_MODEL": "minimal-model", "HOST": "0.0.0.0", "FLASK_PORT": 8000, "GRADIO_PORT": 7860, "DB_HOST": "localhost", "DB_PORT": 5432, "DB_NAME": "ai_system", "DB_USER": "ai_user" }”
08-24
评论 5
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值