# core/handler.py
import os
from typing import Dict, Any
from Progress.app import get_ai_assistant, get_task_executor, get_tts_engine
from Progress.utils.logger_utils import logger
# Component instances fetched from Progress.app once, at import time, and shared
# by every call to handle_user_input below.
assistant, executor, tts_engine = (
    get_ai_assistant(),
    get_task_executor(),
    get_tts_engine(),
)
def _speak_and_get_url(text):
    """
    Speak ``text`` with the module-level ``tts_engine`` and build a download URL for the audio.

    TTS problems are logged, never raised, so a broken speaker cannot turn an
    otherwise successful command into an error reply.

    :param text: the reply to speak
    :return: ``/api/tts/audio?file=<url-encoded basename>`` when ``tts_engine.speak``
             returns a path (str or PathLike); otherwise None
    """
    from urllib.parse import quote

    try:
        audio_path = tts_engine.speak(text)
    except Exception as e:
        logger.warning(f"TTS 播放失败: {e}")
        return None
    logger.info(f"🔊 已播放语音: {audio_path}")
    # speak() may legitimately return nothing usable; os.path.basename(None) would raise.
    if not isinstance(audio_path, (str, os.PathLike)) or not audio_path:
        return None
    return f"/api/tts/audio?file={quote(os.path.basename(audio_path))}"


def handle_user_input(user_text: str, source: str = "unknown") -> Dict[str, Any]:
    """
    Single entry point for a user utterance; shared by the local voice loop and the API.

    :param user_text: what the user said (recognised speech or typed text)
    :param source: request origin ('local', 'web', 'mobile', 'api', ...);
                   'local', 'raspberry' and 'desktop' also play the reply via TTS
    :return: on empty input: ``success`` False, ``response_to_user``, empty ``details``.
             On the normal path: ``success`` True (the input was processed; the task
             outcome is in ``details``), ``recognized_text``, ``response_to_user``,
             ``details`` (executor result), ``tts_audio_url`` (None unless TTS produced
             a file) and ``source``. On an internal error: ``success`` False,
             ``response_to_user`` and ``error``.
    """
    # None and whitespace-only input are rejected up front (None.strip() would raise).
    if not user_text or not user_text.strip():
        return {
            "success": False,
            "response_to_user": "请输入有效内容",
            "details": {}
        }
    try:
        # AI decision: turn the utterance into a plan, then execute it.
        decision = assistant.process_voice_command(user_text)
        result = executor.execute_task_plan(decision)

        # `or` also covers an explicit None/"" message, so reply is always a non-empty str.
        reply = str(result.get("message") or "操作完成。")
        if not result.get("success") and not reply.startswith("抱歉"):
            reply = f"抱歉,{reply}"

        # Only sources running on this machine get the reply spoken locally.
        # The helper returns None when speak() fails, so no unbound audio_path here.
        should_play_tts = source in ("local", "raspberry", "desktop")
        tts_audio_url = _speak_and_get_url(reply) if should_play_tts else None

        return {
            "success": True,
            "recognized_text": user_text,
            "response_to_user": reply,
            "details": result,
            "tts_audio_url": tts_audio_url,
            "source": source
        }
    except Exception as e:
        logger.exception("处理用户输入时出错")
        return {
            "success": False,
            "response_to_user": "系统内部错误,请稍后再试。",
            "error": str(e)
        }
# main.py
import time
import signal
import threading
from typing import Any
from Progress.utils.logger_config import setup_logger
from Progress.app import get_tts_engine, get_voice_recognizer
from core.handler import handle_user_input
from database.config import config
logger = setup_logger("ai_assistant")
# Flag polled by the main loop; set from the signal handler to request a clean exit.
_shutdown_event = threading.Event()


def signal_handler(signum, frame):
    """
    Handler installed for SIGINT and SIGTERM.

    Logs the signal number and sets ``_shutdown_event`` so the main loop can stop.
    """
    logger.info(f"🛑 收到信号 {signum},准备退出...")
    _shutdown_event.set()


for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, signal_handler)
del _sig  # keep the module namespace identical to registering each signal explicitly
def handle_single_interaction() -> bool:
    """
    Run one listen -> recognise -> handle cycle on the local microphone.

    The reply is spoken by ``handle_user_input`` (source='local'). If the handler's
    ``details`` ask for a follow-up, the next listen uses an 8 s window instead of 3 s.

    :return: False when the loop should stop (shutdown requested, or the handler's
             ``details`` contain ``should_exit``); True to keep listening
    """
    rec = get_voice_recognizer()
    # Use the window chosen by the previous turn. With a hard-coded 3 here, the
    # current_timeout assigned below never had any effect.
    # NOTE(review): assumes the recognizer does not also apply current_timeout internally.
    timeout = getattr(rec, "current_timeout", None) or 3
    text = rec.listen_and_recognize(timeout=timeout)
    if _shutdown_event.is_set():
        return False
    if not text:
        logger.info("🔇 未检测到语音")
        # A follow-up window lasts one turn; fall back to the normal timeout after silence.
        rec.current_timeout = 3
        return True
    logger.info(f"🗣️ 用户说: '{text}'")
    # source='local' makes the handler play the reply via TTS.
    result = handle_user_input(user_text=text, source="local")
    # The error path of handle_user_input has no "details"; the executor may also return None.
    details = result.get("details") or {}
    rec.current_timeout = 8 if details.get("expect_follow_up", False) else 3
    return not details.get("should_exit", False)
def _run_cleanup(description: str, action) -> None:
    """Run one best-effort shutdown step; any exception is logged and swallowed."""
    try:
        action()
    except Exception as e:
        logger.warning(f"⚠️ 清理{description}时出错: {e}")


def main():
    """
    Run the assistant.

    Optionally starts the REST API server. Then runs the local listen/handle loop
    until a shutdown signal arrives or the handler requests exit. Finally it stops
    the API server and releases the recognizer and the TTS engine.
    """
    enable_api_server = config.get("app", "enable_api_server", default=True)
    api_host = config.get("app", "api_host", default="127.0.0.1")
    api_port = config.get("app", "api_port", default=5000)
    logger.info("🚀 AI 助手启动中...")

    api_server = None
    if enable_api_server:
        try:
            from api_server import get_api_server
            api_server = get_api_server()
            # Pass the values read here so the logged address is the one actually bound.
            api_server.start(host=api_host, port=api_port)
            logger.info(f"🌐 API 服务已启动: http://{api_host}:{api_port}")
        except Exception as e:
            logger.warning(f"⚠️ API 服务启动失败: {e}")
            api_server = None
    else:
        logger.debug("🚫 API 服务已禁用 (app.enable_api_server=false)")

    logger.info("👂 助手已就绪,请开始说话...")
    try:
        while not _shutdown_event.is_set():
            try:
                if not handle_single_interaction():
                    break
            except KeyboardInterrupt:
                break
            except Exception:
                logger.exception("🔁 主循环异常")
                # Back off before retrying; wait() returns early if a shutdown signal arrives.
                _shutdown_event.wait(1)
    finally:
        # Release resources even if the loop exits through an unexpected BaseException.
        if api_server is not None:
            _run_cleanup("API 服务", api_server.stop)
        _run_cleanup("语音识别器", lambda: get_voice_recognizer().close())
        _run_cleanup("TTS 引擎", lambda: get_tts_engine().stop())
        logger.info("👋 助手已退出")


if __name__ == "__main__":
    main()
# api_server.py
import threading
import os
import base64
import time
from typing import Dict, Any
from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename
from werkzeug.serving import make_server
from flask_cors import CORS
from Progress.utils.logger_utils import logger
from Progress.app import get_voice_recognizer
from core.handler import handle_user_input
from database.config import config
# =============================
# 配置加载(来自全局 config)
# =============================
# App / HTTP settings, read once at import time.
ENABLE_API_SERVER = config.get("app", "enable_api_server", default=True)
API_HOST = config.get("app", "api_host", default="127.0.0.1")
API_PORT = config.get("app", "api_port", default=5000)
RUN_MODE = config.get("app", "run_mode", default="auto")
TEMP_DIR = config.get("app", "temp_dir", default="temp_audio")
# Speech-to-text setting.
VOICE_RECOGNIZER_TIMEOUT = config.get("stt", "timeout", default=3)

# Uploaded audio is written here (a relative path resolves against the working directory).
os.makedirs(TEMP_DIR, exist_ok=True)

# Recognizer used by the voice-upload route.
recognizer = get_voice_recognizer()

# Runtime status shared by the routes and returned by GET /api/status.
current_status = dict(
    is_listening=False,
    is_tts_playing=False,
    last_command_result=None,
    timestamp=int(time.time()),
)
class APIServer:
    """
    RESTful API server that lets Web / Mobile / IoT clients reach the AI assistant.

    Text requests are handed to ``core.handler.handle_user_input``. The client
    source (derived from the ``X-Client-Type`` header) decides behaviour such as
    whether TTS is played on this machine and whether voice upload is allowed.
    The server runs on a background daemon thread (see ``start`` / ``stop``).
    """

    # X-Client-Type keywords per source. They are matched against the whole header
    # value and its alphanumeric tokens, never as substrings. Sources are checked
    # in this order and the first hit wins.
    _SOURCE_KEYWORDS = {
        "web": ("web", "browser"),
        "mobile": ("mobile", "android", "ios"),
        "local": ("local", "raspberry", "local-device", "pi", "desktop"),
    }

    def __init__(self):
        self.app = Flask(__name__)
        CORS(self.app)  # allow cross-origin requests from browser clients
        self.server = None   # werkzeug server created by start()
        self.thread = None   # background thread running serve_forever()
        self.running = False
        self._add_routes()
        logger.debug("🔧 APIServer 初始化完成")

    def _update_status(self, **kwargs):
        """Merge ``kwargs`` into the module-level ``current_status`` and refresh its timestamp."""
        current_status.update(kwargs)
        current_status["timestamp"] = int(time.time())

    def _determine_source(self) -> str:
        """
        Classify the current request's client from its ``X-Client-Type`` header.

        Values like ``raspberry-pi`` or ``android-app`` still match through their
        tokens. ``api`` no longer matches the ``pi`` keyword by substring, and the
        documented value ``local`` now maps to 'local'.

        :return: 'web', 'mobile', 'local', or 'api' when nothing matches
        """
        import re  # only needed here

        client_type = request.headers.get("X-Client-Type", "").lower().strip()
        if not client_type:
            return "api"
        tokens = {client_type, *re.split(r"[^a-z0-9]+", client_type)}
        for src, keywords in self._SOURCE_KEYWORDS.items():
            if tokens.intersection(keywords):
                return src
        return "api"

    def _should_play_tts(self, source: str) -> bool:
        """Return True only for local physical devices (they get TTS and voice upload)."""
        return source == "local"

    def _process_text(self, text: str, source: str):
        """
        Run ``text`` through ``handle_user_input`` and build the JSON response.

        Shared by the text-query and voice-upload routes.

        :return: a Flask JSON response, or ``(response, 500)`` on an unexpected error
        """
        logger.info(f"📩 [{source}] 文本请求: '{text}'")
        try:
            result = handle_user_input(user_text=text, source=source)
            response_data = {
                "success": result.get("success", False),
                "response_to_user": result.get("response_to_user", ""),
            }
            # Only local devices get an audio URL (TTS ran on this machine).
            if self._should_play_tts(source) and result.get("tts_audio_url"):
                response_data["tts_audio_url"] = result["tts_audio_url"]
            # Optional extra: task execution details.
            details = result.get("details")
            if details is not None:
                response_data["details"] = details
            return jsonify(response_data)
        except Exception as e:
            logger.exception(f"❌ 处理文本请求失败: {text}")
            return jsonify({
                "success": False,
                "error": "内部服务错误",
                "message": str(e)
            }), 500

    def _add_routes(self):
        """Register every API route on ``self.app``."""
        self._add_health_route()
        self._add_status_route()
        self._add_text_query_route()
        self._add_voice_upload_route()
        self._add_tts_audio_route()

    def _add_health_route(self):
        @self.app.route('/api/health', methods=['GET'])
        def health():
            return jsonify({
                "status": "ok",
                "mode": RUN_MODE,
                "running": True,
                "timestamp": int(time.time())
            })

    def _add_status_route(self):
        @self.app.route('/api/status', methods=['GET'])
        def status():
            return jsonify(current_status.copy())

    def _add_text_query_route(self):
        @self.app.route('/api/text/query', methods=['POST'])
        def text_query():
            # silent=True: a missing or invalid JSON body becomes None and takes the
            # 400 path below, instead of Werkzeug aborting the request itself.
            data = request.get_json(silent=True)
            raw_text = data.get("text") if isinstance(data, dict) else None
            text = raw_text.strip() if isinstance(raw_text, str) else ""
            if not text:
                return jsonify({"error": "缺少文本内容"}), 400
            return self._process_text(text, self._determine_source())

    def _add_voice_upload_route(self):
        @self.app.route('/api/voice/upload', methods=['POST'])
        def voice_upload():
            source = self._determine_source()
            if not self._should_play_tts(source):
                return jsonify({
                    "error": "语音上传功能仅限本地设备使用",
                    "hint": "请设置 Header: X-Client-Type: local"
                }), 403

            audio_path = None
            default_session = f"upload_{int(time.time())}"
            # session_id is client-controlled and ends up in a file path: sanitise it
            # (secure_filename removes path separators such as "../").
            session_id = secure_filename(request.form.get('session_id', default_session)) or default_session
            try:
                if 'file' in request.files:
                    file = request.files['file']
                    if not file.filename:
                        return jsonify({"error": "上传的文件名为空"}), 400
                    ext = os.path.splitext(file.filename)[1] or ".wav"
                    filename = secure_filename(f"{session_id}_{int(time.time())}{ext}")
                    # Assign before writing so the finally-block also removes partial files.
                    audio_path = os.path.join(TEMP_DIR, filename)
                    file.save(audio_path)
                elif 'audio_base64' in request.form:
                    # Accept bare base64 as well as data URLs ("data:audio/wav;base64,...").
                    b64_str = request.form['audio_base64'].split(",")[-1]
                    try:
                        raw_data = base64.b64decode(b64_str)
                    except ValueError:  # binascii.Error is a ValueError subclass
                        return jsonify({"error": "audio_base64 解码失败"}), 400
                    audio_path = os.path.join(TEMP_DIR, f"{session_id}.wav")
                    with open(audio_path, 'wb') as f:
                        f.write(raw_data)
                else:
                    return jsonify({"error": "请提供 'file' 或 'audio_base64' 字段"}), 400

                self._update_status(is_listening=True)
                logger.debug(f"🎤 正在识别语音文件: {audio_path}")
                try:
                    # NOTE(review): this recognizer comes from get_voice_recognizer(), which
                    # main.py's microphone loop also uses; confirm it is safe across threads.
                    recognized_text = recognizer.listen_and_recognize(
                        audio_file=audio_path,
                        timeout=VOICE_RECOGNIZER_TIMEOUT
                    )
                finally:
                    self._update_status(is_listening=False)

                if not recognized_text:
                    logger.warning("⚠️ 语音识别未获取到有效文本")
                    return jsonify({
                        "success": False,
                        "error": "语音识别失败",
                        "response_to_user": "抱歉,我没听清,请再说一遍。"
                    }), 400

                logger.info(f"👂 识别结果: '{recognized_text}'")
                # Reuse the text pipeline directly. request.json is a read-only property,
                # so faking a JSON body and calling the text_query view cannot work.
                return self._process_text(recognized_text, source)
            except Exception as e:
                logger.exception("🎙️ 语音上传处理出错")
                return jsonify({
                    "success": False,
                    "error": "语音处理异常",
                    "detail": str(e)
                }), 500
            finally:
                # Best-effort removal of the temporary upload.
                if audio_path and os.path.exists(audio_path):
                    try:
                        os.remove(audio_path)
                        logger.debug(f"🗑️ 已删除临时语音文件: {audio_path}")
                    except OSError as e:
                        logger.warning(f"⚠️ 删除临时语音文件失败: {audio_path}: {e}")

    def _add_tts_audio_route(self):
        @self.app.route('/api/tts/audio', methods=['GET'])
        def tts_audio():
            filename = request.args.get('file')
            if not filename:
                return jsonify({"error": "缺少参数 'file'"}), 400
            # Serve only files located directly inside TEMP_DIR. Resolving the path rejects
            # "../x", absolute paths and symlinks that lead elsewhere (path traversal).
            temp_root = os.path.realpath(TEMP_DIR)
            file_path = os.path.realpath(os.path.join(temp_root, filename))
            if os.path.dirname(file_path) != temp_root:
                logger.warning(f"🚫 拒绝访问 TEMP_DIR 之外的文件: {filename}")
                return jsonify({"error": "非法文件名"}), 400
            if not os.path.isfile(file_path):
                logger.warning(f"📁 请求的音频文件不存在: {file_path}")
                return jsonify({"error": "文件不存在"}), 404
            logger.debug(f"📥 下载 TTS 音频: {filename}")
            # NOTE(review): the handler only sends the basename of the TTS output path;
            # this assumes the TTS engine writes into TEMP_DIR — confirm.
            return send_file(file_path, mimetype="audio/mpeg")

    def start(self, host=None, port=None):
        """
        Start the API server on a background daemon thread (non-blocking).

        The listening socket is created synchronously. A bind failure (for example,
        the port is already in use) is therefore logged and re-raised to the caller,
        instead of being lost inside the worker thread.

        :param host: bind address (default: API_HOST from config)
        :param port: port number (default: API_PORT from config)
        :raises Exception: whatever ``make_server`` raises when the server cannot be created
        """
        host = host or API_HOST
        port = port or API_PORT
        if self.server is not None:
            logger.warning("⚠️ API 服务器已在运行,忽略重复启动")
            return

        try:
            server = make_server(host, port, self.app)
        except Exception as e:
            logger.error(f"🚨 API 服务启动失败 ({host}:{port}): {e}")
            raise
        self.server = server
        self.running = True
        logger.info(f"🌐 API 服务已启动 → http://{host}:{port} (模式: {RUN_MODE})")

        def run():
            try:
                server.serve_forever()
            except Exception as e:
                logger.error(f"🚨 API 服务意外终止: {e}")
            else:
                logger.debug("🛑 API 服务已正常关闭")
            finally:
                self.running = False

        self.thread = threading.Thread(target=run, daemon=True)
        self.thread.start()

    def stop(self):
        """
        Shut the API server down and release its listening socket.

        This is a no-op if ``start`` never created a server. Waits up to 3 s for the
        worker thread.
        """
        server = self.server
        if server is None:
            return
        logger.info("🛑 正在关闭 API 服务...")
        self.running = False
        try:
            server.shutdown()
        except Exception as e:
            logger.error(f"❌ shutdown 出错: {e}")
        if self.thread:
            self.thread.join(timeout=3)
            if self.thread.is_alive():
                logger.warning("⚠️ API 线程未能及时退出")
        try:
            server.server_close()  # shutdown() only stops the loop; this closes the socket
        except Exception as e:
            logger.error(f"❌ 关闭监听端口出错: {e}")
        self.server = None
        self.thread = None
        logger.info("✅ API 服务已关闭")
# ================
# 全局实例(单例)
# ================
# Lazily created shared APIServer; see get_api_server().
_api_server_instance = None


def get_api_server() -> APIServer:
    """Return the shared APIServer, constructing it on the first call."""
    global _api_server_instance
    instance = _api_server_instance
    if instance is None:
        instance = _api_server_instance = APIServer()
    return instance


# Public names exported by `from api_server import *`.
__all__ = ['APIServer', 'get_api_server']
# database/config.py
import json
import os
import sys
from typing import Any, Dict, Optional
from pathlib import Path
from Progress.utils.logger_utils import logger
# Try to make the `Progress` package importable by putting the project root on sys.path.
# NOTE(review): `from Progress.utils.logger_utils import logger` above has already imported
# the `Progress` package by the time this runs, so 'Progress' is in sys.modules and this
# block is effectively never executed; it would need to run before that import to matter.
if 'Progress' not in sys.modules:
    # Parent of this file's directory (database/ -> project root).
    project_root = str(Path(__file__).parent.parent)
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    try:
        import Progress
    except ImportError as e:
        print(f"⚠️ 无法导入 Progress 模块,请检查路径: {project_root}, 错误: {e}")
class ConfigManager:
    """
    Load, merge, watch and persist the application configuration.

    The effective configuration is the bundled template ``database/base_config.json``
    (resolved with ``get_internal_path``) deep-merged with the user's ``config.json``
    in the application directory (``get_app_path``). A missing user file is created
    from the template. An unreadable or corrupt one is backed up and rebuilt.
    """

    def __init__(self):
        from Progress.utils.resource_helper import get_internal_path, get_app_path
        self.BASE_CONFIG_FILE = get_internal_path("database", "base_config.json")
        self.CONFIG_FILE = os.path.join(get_app_path(), "config.json")
        self.DEFAULT_CONFIG: Optional[Dict] = None  # template, read lazily and cached
        self._watchers = {}  # "a.b.c" key path -> callback(old_value, new_value)
        self.config = self.load_config()

    def watch(self, *keys, callback):
        """
        Register ``callback`` to run when the item at ``keys`` is changed through ``set``.

        Only one callback per key path is kept; registering again replaces it.

        :param keys: path of the item, e.g. ``("tts", "voice")``
        :param callback: called as ``callback(old_value, new_value)``
        """
        key_path = ".".join(str(k) for k in keys)
        self._watchers[key_path] = callback
        logger.debug(f"👀 开始监听配置项: {key_path}")

    def set(self, value, *keys):
        """
        Set a (possibly nested) item in memory and notify its watcher.

        Example: ``config.set("zh-CN-YunxiNeural", "tts", "voice")``.
        Missing or non-dict intermediate levels are replaced by empty dicts. The
        watcher is skipped when the value is unchanged. Call ``save()`` to persist.

        :raises ValueError: if no key is given
        """
        if not keys:
            raise ValueError("必须指定至少一个键")
        old_value = self.get(*keys)

        data = self.config
        for k in keys[:-1]:
            if k not in data or not isinstance(data[k], dict):
                data[k] = {}
            data = data[k]
        current_key = keys[-1]
        # Unchanged value: nothing to do, and no spurious watcher call.
        if current_key in data and data[current_key] == value:
            return
        data[current_key] = value

        key_path = ".".join(str(k) for k in keys)
        callback = self._watchers.get(key_path)
        if callback is not None:
            try:
                callback(old_value, value)
            except Exception as e:
                # A faulty watcher must not undo or block the update itself.
                logger.error(f"❌ 执行监听回调失败 [{key_path}]: {e}")

    def _load_default(self) -> Dict:
        """
        Return a fresh deep copy of the template configuration (read once, then cached).

        :raises FileNotFoundError: if the template file does not exist
        :raises RuntimeError: if the template cannot be read or parsed
        """
        from copy import deepcopy

        if self.DEFAULT_CONFIG is None:
            if not os.path.exists(self.BASE_CONFIG_FILE):
                raise FileNotFoundError(f"❌ 默认配置文件不存在: {self.BASE_CONFIG_FILE}")
            try:
                with open(self.BASE_CONFIG_FILE, 'r', encoding='utf-8') as f:
                    self.DEFAULT_CONFIG = json.load(f)
            except Exception as e:
                raise RuntimeError(f"❌ 无法读取默认配置文件: {e}") from e
        # Deep copy: deep_update() and set() mutate nested dicts in place, and a shallow
        # copy would let those changes leak back into the cached template.
        return deepcopy(self.DEFAULT_CONFIG)

    def load_config(self) -> Dict:
        """
        Return the template deep-merged with the user's config.json.

        A missing user file is created from the template. A corrupt or unreadable
        one is handed to ``_recover_from_corrupted``.

        :raises FileNotFoundError / RuntimeError: template problems (see ``_load_default``).
            These are raised before the user file is touched, so a good user config is
            never renamed to a backup because of a broken template.
        """
        default = self._load_default()

        if not os.path.exists(self.CONFIG_FILE):
            print(f"🔧 配置文件 {self.CONFIG_FILE} 不存在,正在基于默认模板创建...")
            if self.save_config(default):
                print(f"✅ 默认配置已生成: {self.CONFIG_FILE}")
            else:
                print(f"❌ 默认配置生成失败,请检查路径权限: {os.path.dirname(self.CONFIG_FILE)}")
            return default

        try:
            with open(self.CONFIG_FILE, 'r', encoding='utf-8') as f:
                user_config = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"⚠️ 配置文件格式错误或编码异常: {e}")
            return self._recover_from_corrupted()
        except PermissionError as e:
            print(f"⚠️ 无权限读取配置文件: {e}")
            return self._recover_from_corrupted()
        except Exception as e:
            print(f"⚠️ 未知错误导致配置加载失败: {type(e).__name__}: {e}")
            return self._recover_from_corrupted()

        if not isinstance(user_config, dict):
            print(f"⚠️ 配置文件顶层必须是 JSON 对象,实际为 {type(user_config).__name__}")
            return self._recover_from_corrupted()

        self.deep_update(default, user_config)
        return default

    def _recover_from_corrupted(self) -> Dict:
        """
        Back up a broken config.json to ``config.json.backup`` and rebuild it from the template.

        If the backup cannot be made, the original file is left untouched and the
        template is returned in memory only.
        """
        backup_file = self.CONFIG_FILE + ".backup"
        try:
            if os.path.exists(self.CONFIG_FILE):
                # os.replace also overwrites an older backup (os.rename fails on Windows).
                os.replace(self.CONFIG_FILE, backup_file)
                print(f"📁 原始损坏配置已备份为: {backup_file}")
        except OSError as e:
            print(f"❌ 自动恢复失败: {e},将返回内存中默认配置")
            return self._load_default()

        default = self._load_default()
        if self.save_config(default):
            print(f"✅ 已使用默认配置重建 {self.CONFIG_FILE}")
        else:
            print(f"❌ 默认配置生成失败,请检查路径权限: {os.path.dirname(self.CONFIG_FILE)}")
        return default

    def deep_update(self, default: Dict, override: Dict):
        """Recursively merge ``override`` into ``default`` in place (nested dicts merge, other values replace)."""
        for k, v in override.items():
            if k in default and isinstance(default[k], dict) and isinstance(v, dict):
                self.deep_update(default[k], v)
            else:
                default[k] = v

    def save_config(self, config: Dict) -> bool:
        """
        Write ``config`` to config.json as indented UTF-8 JSON.

        The data goes to a temporary file that is then atomically swapped in, so an
        interrupted write cannot leave a truncated config.json behind.

        :return: True on success; False on any failure (the reason is printed)
        """
        tmp_file = self.CONFIG_FILE + ".tmp"
        try:
            from Progress.utils.resource_helper import ensure_directory
            config_dir = os.path.dirname(self.CONFIG_FILE)
            if not ensure_directory(config_dir):
                print(f"❌ 无法创建配置目录: {config_dir}")
                return False
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=4, ensure_ascii=False)
            os.replace(tmp_file, self.CONFIG_FILE)
            return True
        except PermissionError:
            print(f"❌ 权限不足,无法写入配置文件: {self.CONFIG_FILE}")
            return False
        except Exception as e:
            print(f"❌ 保存配置失败: {type(e).__name__}: {e}")
            return False
        finally:
            # Remove a leftover temp file after a failed write (best effort).
            if os.path.exists(tmp_file):
                try:
                    os.remove(tmp_file)
                except OSError:
                    pass

    def get(self, *keys, default=None) -> Any:
        """
        Safely read a nested item; return ``default`` if any level is missing or not indexable.

        Example: ``config.get("ai_model", "api_key", default="none")``
        """
        data = self.config
        try:
            for k in keys:
                data = data[k]
            return data
        except (KeyError, TypeError):
            return default

    def save(self) -> bool:
        """Persist the in-memory configuration to config.json; return whether it succeeded."""
        return self.save_config(self.config)
# Module-level instance shared by importers (`from database.config import config`).
# Constructing it runs load_config(), which reads config.json or creates it from the template.
config = ConfigManager()
from functools import wraps
import inspect
import logging
# Global registries filled by @ai_callable when decorated functions are defined.
REGISTERED_FUNCTIONS = {}  # function name -> metadata dict
FUNCTION_SCHEMA = []       # entries {"name", "description", "parameters"}, one per name
FUNCTION_MAP = {}          # (intent, action) -> function name
logger = logging.getLogger("ai_assistant")


def ai_callable(
    *,
    description: str,
    params: dict,
    intent: str = None,
    action: str = None,
    concurrent: bool = False  # whether the function may run concurrently
):
    """
    Decorator that registers a function as callable by the AI.

    On decoration it:
    - stores metadata in ``REGISTERED_FUNCTIONS``;
    - adds a schema entry to ``FUNCTION_SCHEMA``;
    - maps ``(intent, action)`` to the function name in ``FUNCTION_MAP`` when both are given.

    Re-registering the same function name replaces the earlier metadata and schema
    entry, and logs a warning.

    :raises ValueError: if ``(intent, action)`` is already mapped. Nothing is
        registered in that case.
    """
    def decorator(func):
        func_name = func.__name__
        key = (intent, action) if intent and action else None

        # Validate before touching any registry, so a conflict leaves no partial registration.
        if key is not None and key in FUNCTION_MAP:
            raise ValueError(f"冲突:语义 ({intent}, {action}) 已被函数 {FUNCTION_MAP[key]} 占用")

        metadata = {
            "func": func,
            "description": description,
            "params": params,
            "intent": intent,
            "action": action,
            "signature": str(inspect.signature(func)),
            "concurrent": concurrent
        }

        if func_name in REGISTERED_FUNCTIONS:
            logger.warning(f"⚠️ 函数 {func_name} 被重复注册,旧的定义将被覆盖")
            # Keep FUNCTION_SCHEMA in sync with REGISTERED_FUNCTIONS: drop the stale entry.
            FUNCTION_SCHEMA[:] = [s for s in FUNCTION_SCHEMA if s["name"] != func_name]

        REGISTERED_FUNCTIONS[func_name] = metadata
        FUNCTION_SCHEMA.append({
            "name": func_name,
            "description": description,
            "parameters": params
        })
        if key is not None:
            FUNCTION_MAP[key] = func_name

        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper._ai_metadata = metadata
        return wrapper

    return decorator
# Progress/utils/resource_helper.py
import os
import sys
from typing import Optional
def _project_root() -> str:
    """Return the development project root: two levels above this file's directory."""
    utils_dir = os.path.dirname(os.path.abspath(__file__))  # Progress/utils/
    progress_root = os.path.dirname(utils_dir)             # Progress/
    return os.path.dirname(progress_root)                  # project root (AI_Manager/)


def get_internal_path(*relative_path_parts) -> str:
    """
    Resolve a bundled (read-only) resource such as base_config.json.

    Example: ``get_internal_path("database", "base_config.json")``

    - frozen build: relative to ``sys._MEIPASS`` (PyInstaller's unpack directory),
      or to the executable's directory if the freezer does not set it;
    - development: relative to the project root.
    """
    if getattr(sys, 'frozen', False):
        # _MEIPASS is PyInstaller-specific. Fall back instead of raising AttributeError
        # when frozen by another tool.
        base_path = getattr(sys, '_MEIPASS', os.path.dirname(sys.executable))
    else:
        base_path = _project_root()
    return os.path.join(base_path, *relative_path_parts)


def get_app_path() -> str:
    """
    Return the directory for runtime data (config.json, logs, ...).

    - frozen build: the executable's directory;
    - development: the project root (AI_Manager/).
    """
    if getattr(sys, 'frozen', False):
        return os.path.dirname(sys.executable)
    return _project_root()
def resource_path(*sub_paths: str, base_key: str = "resource_path") -> str:
    """
    Resolve a user resource under the base directory configured at ``paths.<base_key>``.

    Example: ``resource_path("Music", "bgm.mp3")`` -> ``<resource_path>/Music/bgm.mp3``

    Each sub path may use ``/`` or ``\\`` separators. Empty and ``.`` components are
    ignored. ``..`` and drive-qualified components are dropped so the result stays
    under the base directory. Dot-prefixed names such as ``.cache`` are kept intact.

    :param sub_paths: path components below the base directory
    :param base_key: key inside the ``paths`` config section (default "resource_path");
                     a relative value is resolved against ``get_app_path()``
    :raises ValueError: if ``paths.<base_key>`` is not set
    """
    # Deferred import to avoid a circular import with database.config.
    from database.config import config

    raw_base = config.get("paths", base_key)
    if not raw_base:
        raise ValueError(f"配置项 paths.{base_key} 未设置")

    base_path = raw_base if os.path.isabs(raw_base) else os.path.join(get_app_path(), raw_base)
    full_path = os.path.normpath(base_path)

    for part in sub_paths:
        for piece in str(part).strip().replace('\\', '/').split('/'):
            if piece in ('', '.', '..'):
                continue
            if os.path.splitdrive(piece)[0]:
                # e.g. "C:" on Windows: os.path.join would restart from that drive.
                continue
            full_path = os.path.join(full_path, piece)
    return os.path.normpath(full_path)
def ensure_directory(path: str) -> bool:
    """
    Make sure a directory exists; if ``path`` names a file, create its parent directory.

    ``path`` is treated as a file path when it already exists as a non-directory, or
    (when it does not exist yet) when its last component contains a dot that is not
    its first character (``config.json`` yes, ``.cache`` no).
    NOTE: a not-yet-existing directory with a dot in its name (``v1.2``) is treated as
    a file, so only its parent is created.

    :return: True if the directory exists afterwards (or nothing needed creating); False
             on failure (the reason is printed)
    """
    if os.path.isdir(path):
        return True
    basename = os.path.basename(path)
    looks_like_file = '.' in basename and len(basename) > 1 and not basename.startswith('.')
    # An existing file (even one without an extension, e.g. "Makefile") means "its parent";
    # calling makedirs on the file itself would fail.
    dir_path = os.path.dirname(path) if (os.path.exists(path) or looks_like_file) else path
    if not dir_path or dir_path in ('.', './', '..'):
        return True
    try:
        os.makedirs(dir_path, exist_ok=True)
        return True
    except PermissionError:
        print(f"❌ 权限不足,无法创建目录: {dir_path}")
        return False
    except Exception as e:
        print(f"❌ 创建目录失败: {dir_path}, 错误: {type(e).__name__}: {e}")
        return False
检查一下,最后做个整理