import time
import numpy as np
import pyaudio
import requests
import threading
import wave
import re
import io
import queue
import torch
import torchaudio
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor
# Global thread pool shared by the TTS and audio-processing workers
global_executor = ThreadPoolExecutor(max_workers=4)
class Config:
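    """Shared tunables for audio capture, voice detection, and barge-in handling."""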
def __init__(self):
        self.sampling_rate = 16000
        # Voice activity detection (VAD) parameters
        self.vad_threshold = 0.05        # RMS energy above which a frame counts as speech
        self.chunk_size = 2048           # samples per captured frame
        self.silence_threshold = 0.01
        self.min_speech_duration = 0.6   # seconds of speech before an utterance starts
        self.silence_duration = 1.0      # seconds of silence that ends an utterance
        # Barge-in (interruption) parameters
        self.interrupt_frames = 20       # consecutive voice frames needed to interrupt playback
        self.max_audio_length = 30       # seconds; intended cap on the capture buffer
class FunASRProcessor:
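    """Sends captured audio to an HTTP ASR endpoint and returns the transcription.

    A FunASR-backed server is assumed to be listening on /transcribe/; no local
    model is loaded here.
    """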
def __init__(self, config):
self.config = config
def recognize(self, audio_data):
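        # Pack the raw int16 PCM frames into an in-memory mono WAV container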
audio_bytes = BytesIO()
with wave.open(audio_bytes, 'wb') as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(self.config.sampling_rate)
wf.writeframes(b''.join(audio_data))
audio_bytes.seek(0)
try:
url = "http://0.0.0.0:8000/transcribe/"
files = {"file": ("audio.wav", audio_bytes, "audio/wav")}
            # Time the request and report payload size and throughput
            start_time = time.time()
            response = requests.post(url, files=files)
            elapsed_time = time.time() - start_time
            data_size = len(audio_bytes.getvalue())
            speed = data_size / elapsed_time if elapsed_time > 0 else 0
            print(f"ASR API call took {elapsed_time:.2f} s, payload {data_size} bytes, {speed:.2f} bytes/s")
            if response.status_code == 200:
                result = response.json()
                text = result.get("transcription", "")
                print("ASR result: " + text)
                return text
            else:
                print(f"ASR API call failed: {response.status_code} - {response.text}")
                return ""
        except Exception as e:
            print(f"ASR recognition failed: {e}")
            return ""
class LLMProcessor:
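    """Thin client for an OpenAI-compatible /v1/chat/completions endpoint,
    keeping the running conversation history in memory."""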
def __init__(self):
self.url = "http://0.0.0.0:8080/v1/chat/completions"
self.headers = {"Content-Type": "application/json", "Authorization": "Bearer sk-no-key-required"}
        # System prompt (Chinese): "You are a voice assistant; keep answers concise, under 200 characters."
        self.system_prompt = "你是一个语音助手,回答简洁不超过200字。"
self.history = [{"role": "system", "content": self.system_prompt}]
def get_response(self, text):
if not text.strip():
return "抱歉,我没有听清楚你说什么。"
self.history.append({"role": "user", "content": text})
payload = {"model": "Qwen2.5-7B-Instruct", "messages": self.history}
try:
            # Time the request and report payload size and throughput
            start_time = time.time()
            response = requests.post(self.url, headers=self.headers, json=payload)
            elapsed_time = time.time() - start_time
            data_size = len(response.content)
            speed = data_size / elapsed_time if elapsed_time > 0 else 0
            print(f"LLM API call took {elapsed_time:.2f} s, payload {data_size} bytes, {speed:.2f} bytes/s")
            if response.status_code == 200:
                assistant_message = response.json()["choices"][0]["message"]["content"]
                self.history.append({"role": "assistant", "content": assistant_message})
                print("LLM reply: " + assistant_message)
                return assistant_message
            return "抱歉,我现在不太方便回答。"
        except Exception as e:
            print(f"LLM processing failed: {e}")
            return "抱歉,我遇到了一些技术问题。"
class GPTSoVITSClient:
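    """Splits replies into sentences, streams them to a TTS HTTP endpoint, and
    plays the resulting audio segment by segment.

    Despite the class name, the /inference_sft endpoint with tts_text/spk_id
    parameters follows a CosyVoice-style FastAPI server, not GPT-SoVITS.
    """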
def __init__(self):
self.config = Config()
self.url = "http://0.0.0.0:50000/inference_sft"
self.prompt_sr, self.target_sr = 16000, 22050
self.segment_queue = queue.Queue()
self.tts_thread = None
self.processing = False
self._lock = threading.Lock()
self.executor = global_executor
def split_text(self, text):
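        # Split on sentence-ending punctuation (CJK and ASCII); the lookbehind
        # keeps each delimiter attached to its sentence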
segments = re.split(r'(?<=[。!?.!?])\s*', text)
return [segment.strip() for segment in segments if segment.strip()]
def text_to_speech(self, text):
segments = self.split_text(text)
with self._lock:
while not self.segment_queue.empty():
try:
self.segment_queue.get_nowait()
except queue.Empty:
break
for segment in segments:
self.segment_queue.put(segment)
if not self.processing:
self.start_processing()
def start_processing(self):
self.processing = True
self.executor.submit(self._process_segments)
def _process_segments(self):
try:
while True:
try:
segment = self.segment_queue.get_nowait()
except queue.Empty:
break
try:
                    payload = {
                        'tts_text': segment,
                        'spk_id': '中文女'
                    }
                    # Time the request including the streamed body, so the
                    # reported throughput reflects the actual transfer
                    start_time = time.time()
                    response = requests.request("GET", self.url, data=payload, stream=True)
                    tts_audio = b''
                    for r in response.iter_content(chunk_size=16000):
                        tts_audio += r
                    elapsed_time = time.time() - start_time
                    data_size = len(tts_audio)
                    speed = data_size / elapsed_time if elapsed_time > 0 else 0
                    print(f"TTS API call took {elapsed_time:.2f} s, payload {data_size} bytes, {speed:.2f} bytes/s")
tts_speech = torch.frombuffer(tts_audio, dtype=torch.int16).unsqueeze(dim=0)
audio_buffer = io.BytesIO()
torchaudio.save(audio_buffer, tts_speech, self.target_sr, format='wav')
audio_buffer.seek(0)
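                    # 'audio_player' is the module-level global bound in
                    # InteractiveVoiceAssistant.run()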
audio_player.play_audio(audio_buffer)
while audio_player.playing and not audio_player.stop_flag.is_set():
time.sleep(0.1)
except Exception as e:
print(f"Error processing TTS segment: {e}")
continue
finally:
self.segment_queue.task_done()
finally:
self.processing = False
class AudioPlayer:
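    """Plays WAV buffers on a background thread; playback can be cut short at
    any chunk boundary via stop_flag, which is how barge-in interrupts work."""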
def __init__(self):
self.playing = False
self.stop_flag = threading.Event()
self.play_thread = None
self._lock = threading.Lock()
self.pyaudio_instance = pyaudio.PyAudio()
    def play_audio(self, audio_buffer):
        # Stop any in-progress playback first. This must happen outside the
        # lock: stop_playing() and the playback thread's cleanup both acquire
        # the same non-reentrant lock, so holding it here would deadlock.
        if self.playing:
            self.stop_playing()
            time.sleep(0.1)
        with self._lock:
            self.stop_flag.clear()
            self.playing = True
        def playback_worker():
try:
chunk = 2048
with wave.open(audio_buffer, 'rb') as wf:
stream = self.pyaudio_instance.open(
format=self.pyaudio_instance.get_format_from_width(wf.getsampwidth()),
channels=wf.getnchannels(),
rate=wf.getframerate(),
output=True
)
data = wf.readframes(chunk)
while data and not self.stop_flag.is_set():
stream.write(data)
data = wf.readframes(chunk)
stream.stop_stream()
stream.close()
except Exception as e:
print(f"Error playing audio: {e}")
finally:
with self._lock:
self.playing = False
        self.play_thread = threading.Thread(target=playback_worker)
self.play_thread.daemon = True
self.play_thread.start()
    def stop_playing(self):
        # Signal the playback thread and wait for it to exit. join() happens
        # outside the lock so the thread's finally block can acquire it.
        self.stop_flag.set()
        if self.play_thread and self.play_thread.is_alive():
            self.play_thread.join(timeout=1.0)
        with self._lock:
            self.playing = False
def __del__(self):
self.pyaudio_instance.terminate()
class EnhancedVoiceDetector:
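    """Energy-based VAD with hysteresis: speech starts only after
    min_speech_duration of loud frames and ends only after silence_duration of
    quiet ones, which suppresses brief noise spikes and short pauses."""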
def __init__(self, config):
self.threshold = config.vad_threshold
self.silence_threshold = config.silence_threshold
self.min_speech_frames = int(config.min_speech_duration * config.sampling_rate / config.chunk_size)
self.silence_frames = int(config.silence_duration * config.sampling_rate / config.chunk_size)
self.speech_frames_counter = 0
self.silence_frames_counter = 0
self.is_speaking = False
def calculate_energy(self, audio_data):
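        # RMS energy of the frame, with int16 samples normalized to [-1.0, 1.0)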
audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
energy = np.sqrt(np.mean(np.square(audio_array)))
return energy
def is_voice_detected(self, audio_data):
current_energy = self.calculate_energy(audio_data)
if current_energy > self.threshold:
self.speech_frames_counter += 1
self.silence_frames_counter = 0
if self.speech_frames_counter >= self.min_speech_frames:
self.is_speaking = True
else:
self.silence_frames_counter += 1
if self.silence_frames_counter >= self.silence_frames:
self.speech_frames_counter = 0
self.is_speaking = False
return self.is_speaking
class InteractiveVoiceAssistant:
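    """Wires capture -> VAD -> ASR -> LLM -> TTS -> playback together, with
    barge-in: sustained speech during playback stops the current answer and is
    processed as a new utterance."""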
def __init__(self):
self.config = Config()
self.asr_processor = FunASRProcessor(self.config)
self.llm_processor = LLMProcessor()
self.tts_client = GPTSoVITSClient()
self.audio_player = AudioPlayer()
self.voice_detector = EnhancedVoiceDetector(self.config)
self.audio_queue = queue.Queue()
self.is_processing = False
        # Capture-buffer cap in frames (currently informational; not enforced below)
        self.max_buffer_size = int(self.config.max_audio_length * self.config.sampling_rate / self.config.chunk_size)
self.interrupt_flag = threading.Event()
self.processing_thread = None
self._lock = threading.Lock()
self.executor = global_executor
def process_audio(self):
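        # Consumer loop: drain queued utterances, run ASR -> LLM -> TTS, and
        # stop any ongoing playback when a barge-in arrives mid-answer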
while True:
try:
if not self.is_processing:
audio_frames = []
while not self.audio_queue.empty():
segment = self.audio_queue.get_nowait()
if isinstance(segment, list):
audio_frames.extend(segment)
else:
audio_frames.append(segment)
if audio_frames:
with self._lock:
self.is_processing = True
                        if self.audio_player.playing:
                            print("Barge-in detected; stopping current playback")
                            self.audio_player.stop_playing()
                            # Log the stop time so barge-in latency can be measured
                            print(f"Playback stopped at: {time.time()}")
                        text = self.asr_processor.recognize(audio_frames)
                        print(f"Recognized text: {text}")
                        if text.strip():
                            response = self.llm_processor.get_response(text)
                            print(f"Generated reply: {response}")
self.tts_client.text_to_speech(response)
with self._lock:
self.is_processing = False
self.interrupt_flag.clear()
except Exception as e:
print(f"处理音频时出错: {e}")
with self._lock:
self.is_processing = False
time.sleep(0.1)
def run(self):
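        # Producer loop: capture microphone frames, run VAD on each one, and
        # flush complete utterances to the processing thread. The AudioPlayer
        # is also published as the module-level 'audio_player' global used by
        # GPTSoVITSClient.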
try:
global audio_player
audio_player = self.audio_player
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16,
channels=1,
rate=self.config.sampling_rate,
input=True,
frames_per_buffer=self.config.chunk_size)
self.processing_thread = threading.Thread(target=self.process_audio)
self.processing_thread.daemon = True
self.processing_thread.start()
print("开始录音,按Ctrl+C停止... (说话时可以打断当前播放)")
audio_buffer = []
consecutive_voice_frames = 0
try:
while True:
data = stream.read(self.config.chunk_size, exception_on_overflow=False)
if self.voice_detector.is_voice_detected(data):
audio_buffer.append(data)
consecutive_voice_frames += 1
                        if consecutive_voice_frames > self.config.interrupt_frames and self.audio_player.playing:
                            print("User barge-in detected")
                            self.interrupt_flag.set()
                            # Log the detection time so barge-in latency can be measured
                            print(f"Barge-in detected at: {time.time()}")
                            self.audio_player.stop_playing()
                            # Queue the buffered audio for processing, then reset the buffer
                            self.audio_queue.put(audio_buffer.copy())
                            audio_buffer = []
                            consecutive_voice_frames = 0
else:
consecutive_voice_frames = 0
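                        # Flush the utterance: queue it if it is long enough,
                        # otherwise discard it as noise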
if audio_buffer:
if len(audio_buffer) > self.voice_detector.min_speech_frames:
self.audio_queue.put(audio_buffer.copy())
audio_buffer = []
except KeyboardInterrupt:
print("\n录音结束")
finally:
stream.stop_stream()
stream.close()
p.terminate()
self.audio_player.stop_playing()
except Exception as e:
print(f"运行时出错: {e}")
raise
def main():
assistant = InteractiveVoiceAssistant()
assistant.run()
if __name__ == "__main__":
main()