测速提升性能

import io
import os
import queue
import re
import threading
import time
import wave
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from io import BytesIO
from uuid import uuid4

import numpy as np
import pyaudio
import requests
import torch
import torchaudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Process-wide thread pool with 4 workers. GPTSoVITSClient and
# InteractiveVoiceAssistant both keep a reference to it as ``self.executor``;
# GPTSoVITSClient submits its background TTS worker here.
global_executor = ThreadPoolExecutor(max_workers=4)

class Config:
    """Tunable parameters for audio capture, voice detection and barge-in.

    Every value may be overridden by keyword; ``Config()`` keeps the original
    defaults, so existing callers are unaffected.
    """

    def __init__(
        self,
        *,
        sampling_rate=16000,
        vad_threshold=0.05,
        chunk_size=2048,
        silence_threshold=0.01,
        min_speech_duration=0.6,
        silence_duration=1.0,
        interrupt_frames=20,
        max_audio_length=30,
    ):
        # PCM sample rate in Hz, used for microphone capture and the ASR WAV header.
        self.sampling_rate = sampling_rate

        # --- Voice detection ---
        # RMS level (samples normalized to [-1, 1)) above which a chunk is "loud".
        self.vad_threshold = vad_threshold
        # Frames per microphone read; all frame counts below are in these chunks.
        self.chunk_size = chunk_size
        # Copied into EnhancedVoiceDetector; its visible logic compares
        # against vad_threshold only.
        self.silence_threshold = silence_threshold
        # Seconds of loud audio needed before speech is reported.
        self.min_speech_duration = min_speech_duration
        # Seconds of quiet audio needed before speech is considered over.
        self.silence_duration = silence_duration

        # --- Interruption (barge-in) ---
        # Consecutive voiced chunks required to interrupt ongoing playback.
        self.interrupt_frames = interrupt_frames
        # Longest single utterance, in seconds.
        self.max_audio_length = max_audio_length

class FunASRProcessor:
    """Client for a remote speech-recognition HTTP endpoint.

    Raw PCM frames are wrapped in an in-memory mono 16-bit WAV and uploaded
    as a multipart ``file`` field. The reply is parsed as JSON and its
    ``"transcription"`` field is returned.
    """

    def __init__(self, config, *, url="http://0.0.0.0:8000/transcribe/", timeout=60.0):
        """Store settings used by :meth:`recognize`.

        Args:
            config: object exposing ``sampling_rate`` (Hz) of the PCM frames.
            url: transcription endpoint.
            timeout: seconds forwarded to ``requests.post``; without one a
                stalled server would block the caller indefinitely.
        """
        self.config = config
        self.url = url
        self.timeout = timeout

    def _encode_wav(self, audio_data):
        """Return a rewound ``BytesIO`` holding ``audio_data`` as a mono 16-bit WAV.

        ``audio_data`` is an iterable of ``bytes`` chunks concatenated in
        order. They are assumed to already be 16-bit mono PCM at
        ``config.sampling_rate``; nothing here converts them.
        """
        audio_bytes = BytesIO()
        with wave.open(audio_bytes, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.config.sampling_rate)
            wf.writeframes(b''.join(audio_data))
        audio_bytes.seek(0)
        return audio_bytes

    def recognize(self, audio_data):
        """Transcribe ``audio_data``, an iterable of PCM ``bytes`` chunks.

        Returns:
            The recognized text, or ``""`` if the request fails, times out,
            returns a non-200 status, or the reply cannot be parsed.
        """
        audio_bytes = self._encode_wav(audio_data)
        data_size = len(audio_bytes.getvalue())

        try:
            files = {"file": ("audio.wav", audio_bytes, "audio/wav")}

            # perf_counter is monotonic, so wall-clock adjustments can't skew it.
            start_time = time.perf_counter()
            response = requests.post(self.url, files=files, timeout=self.timeout)
            elapsed_time = time.perf_counter() - start_time

            # "Speed" is upload size over round-trip time, so it also includes
            # server-side recognition time, not just network transfer.
            speed = data_size / elapsed_time if elapsed_time > 0 else 0
            print(f"ASR API 调用耗时: {elapsed_time:.2f} 秒, 数据量: {data_size} 字节, 传输速度: {speed:.2f} 字节/秒")

            if response.status_code != 200:
                print(f"API 调用失败: {response.status_code} - {response.text}")
                return ""

            # A JSON null transcription is treated the same as "no text".
            text = response.json().get("transcription") or ""
            print("识别结果:" + text)
            return text
        except Exception as e:
            print(f"ASR 识别失败: {e}")
            return ""

class LLMProcessor:
    """Multi-turn chat client for an OpenAI-style chat-completions endpoint.

    ``history`` always starts with the system prompt. A user/assistant pair
    is appended only after a successful reply, so failed calls leave no
    orphan user turn. Older pairs are dropped once more than
    ``max_history_turns`` are stored, so the prompt cannot grow past the
    model's context window.
    """

    def __init__(self, *, url="http://0.0.0.0:8080/v1/chat/completions",
                 model="Qwen2.5-7B-Instruct", timeout=60.0, max_history_turns=10):
        """
        Args:
            url: chat-completions endpoint.
            model: value sent as ``"model"`` in every request.
            timeout: seconds forwarded to ``requests.post``.
            max_history_turns: user/assistant exchanges kept after the system
                prompt; ``None`` keeps everything (unbounded).
        """
        self.url = url
        self.model = model
        self.timeout = timeout
        self.max_history_turns = max_history_turns
        self.headers = {"Content-Type": "application/json", "Authorization": "Bearer sk-no-key-required"}
        self.system_prompt = "你是一个希望助手,回答简洁不超过200字。"
        self.history = [{"role": "system", "content": self.system_prompt}]

    def _trim_history(self):
        """Keep the system prompt plus at most the last ``max_history_turns`` exchanges."""
        if self.max_history_turns is None:
            return
        keep = 2 * max(self.max_history_turns, 0)
        dialogue = self.history[1:]
        if len(dialogue) > keep:
            # Slice by start index: dialogue[-0:] would keep everything.
            self.history = self.history[:1] + dialogue[len(dialogue) - keep:]

    def get_response(self, text):
        """Send ``text`` as the next user turn and return the assistant's reply.

        Returns a fixed apology string instead of raising when ``text`` is
        blank, the server answers with a non-200 status, or the request or
        reply parsing fails. In those cases ``history`` is left unchanged.
        """
        if not text.strip():
            return "抱歉,我没有听清楚你说什么。"

        user_message = {"role": "user", "content": text}
        payload = {"model": self.model, "messages": self.history + [user_message]}

        try:
            start_time = time.perf_counter()
            response = requests.post(self.url, headers=self.headers, json=payload, timeout=self.timeout)
            elapsed_time = time.perf_counter() - start_time

            data_size = len(response.content)
            speed = data_size / elapsed_time if elapsed_time > 0 else 0
            print(f"LLM API 调用耗时: {elapsed_time:.2f} 秒, 数据量: {data_size} 字节, 传输速度: {speed:.2f} 字节/秒")

            if response.status_code != 200:
                print(f"LLM API 调用失败: {response.status_code} - {response.text}")
                return "抱歉,我现在不太方便回答。"

            assistant_message = response.json()["choices"][0]["message"]["content"]
            print("模型回答:" + assistant_message)
        except Exception as e:
            print(f"LLM processing failed: {e}")
            return "抱歉,我遇到了一些技术问题。"

        # Commit the exchange only once it fully succeeded.
        self.history.append(user_message)
        self.history.append({"role": "assistant", "content": assistant_message})
        self._trim_history()
        return assistant_message

class GPTSoVITSClient:
    """Sentence-level TTS client that synthesizes and plays one sentence at a time.

    :meth:`text_to_speech` splits a reply into sentences and replaces
    whatever is still queued. A single worker submitted to
    ``global_executor`` then requests audio for each sentence from
    ``self.url``. Each sentence is played through the module-level
    ``audio_player`` (assigned in ``InteractiveVoiceAssistant.run``), and the
    worker waits for it to finish before starting the next.

    NOTE(review): the class name says GPT-SoVITS, but the endpoint and
    parameters (``inference_sft``, ``spk_id``) look like a different TTS
    server's SFT API; confirm which service is deployed.
    """

    # Split after CJK terminators (。！？) and ASCII ! ? always, and after "."
    # only when no digit follows, so decimals such as "25.5" stay whole.
    _SENTENCE_END = re.compile(r'(?:(?<=[。！？!?])|(?<=\.)(?!\d))\s*')

    def __init__(self, *, timeout=60.0):
        self.config = Config()
        self.url = "http://0.0.0.0:50000/inference_sft"
        self.prompt_sr, self.target_sr = 16000, 22050
        self.segment_queue = queue.Queue()
        self.tts_thread = None
        # True while a worker owns segment_queue. Read and written under
        # _lock so exactly one worker runs and no queued sentence is stranded.
        self.processing = False
        self._lock = threading.Lock()
        self.executor = global_executor
        # Seconds forwarded to requests so a stalled TTS server can't wedge the worker.
        self.timeout = timeout

    def split_text(self, text):
        """Split ``text`` into stripped, non-empty sentences."""
        return [segment.strip() for segment in self._SENTENCE_END.split(text) if segment.strip()]

    def _discard_pending(self):
        """Drop every queued sentence. Caller must hold ``_lock``."""
        while True:
            try:
                self.segment_queue.get_nowait()
            except queue.Empty:
                return
            self.segment_queue.task_done()

    def text_to_speech(self, text):
        """Queue ``text`` for playback, replacing sentences not yet started.

        The sentence currently being synthesized or played is not affected.
        """
        segments = self.split_text(text)
        with self._lock:
            self._discard_pending()
            for segment in segments:
                self.segment_queue.put(segment)
            # Checked under the same lock the worker uses to clear the flag,
            # so either that worker sees these segments or a new one starts.
            if not self.processing:
                self.start_processing()

    def start_processing(self):
        """Mark a worker as running and submit it to the executor."""
        self.processing = True
        self.executor.submit(self._process_segments)

    def _process_segments(self):
        """Worker: speak queued sentences until the queue is empty."""
        try:
            while True:
                with self._lock:
                    try:
                        segment = self.segment_queue.get_nowait()
                    except queue.Empty:
                        # Cleared atomically with the emptiness check; see text_to_speech.
                        self.processing = False
                        return
                try:
                    self._speak_segment(segment)
                except Exception as e:
                    print(f"Error processing TTS segment: {e}")
                finally:
                    self.segment_queue.task_done()
        except BaseException:
            # Safety net: never leave the flag stuck on if the worker dies.
            with self._lock:
                self.processing = False
            raise

    def _pcm_to_wav(self, pcm):
        """Wrap raw 16-bit mono PCM at ``target_sr`` in a rewound in-memory WAV."""
        if len(pcm) % 2:
            pcm = pcm[:-1]  # drop a dangling half-sample instead of failing
        audio_buffer = io.BytesIO()
        with wave.open(audio_buffer, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.target_sr)
            wf.writeframes(pcm)
        audio_buffer.seek(0)
        return audio_buffer

    def _speak_segment(self, segment):
        """Synthesize ``segment``, play it, and block until playback ends or is stopped.

        The response body is treated as raw 16-bit mono PCM at ``target_sr``.
        Raises on HTTP errors or an empty body; the worker loop logs these.
        """
        payload = {
            'tts_text': segment,
            'spk_id': '中文女'
        }

        start_time = time.perf_counter()
        with requests.request("GET", self.url, data=payload, stream=True, timeout=self.timeout) as response:
            # Don't play an HTTP error page as if it were audio.
            response.raise_for_status()
            tts_audio = b''.join(response.iter_content(chunk_size=16000))
        # Timed after the body is read: with stream=True, request() returns
        # as soon as headers arrive, so timing it alone ignores the download.
        elapsed_time = time.perf_counter() - start_time
        data_size = len(tts_audio)
        speed = data_size / elapsed_time if elapsed_time > 0 else 0
        print(f"TTS API 调用耗时: {elapsed_time:.2f} 秒, 数据量: {data_size} 字节, 传输速度: {speed:.2f} 字节/秒")

        if not tts_audio:
            raise ValueError("TTS service returned no audio")

        audio_player.play_audio(self._pcm_to_wav(tts_audio))

        # Event.wait returns as soon as stop_flag is set, instead of sleeping a full tick.
        # NOTE(review): after a stop, the worker moves on to the next queued
        # sentence of the same reply; confirm whether barge-in should drop it.
        while audio_player.playing:
            if audio_player.stop_flag.wait(0.1):
                break

class AudioPlayer:
    """Plays in-memory WAV data on a background thread with cooperative stop.

    ``playing`` is True from the moment :meth:`play_audio` starts a playback
    until that playback finishes or :meth:`stop_playing` is called.
    ``stop_flag`` is set by :meth:`stop_playing` and cleared when a new
    playback starts; other code polls both attributes.
    """

    def __init__(self):
        self.playing = False
        self.stop_flag = threading.Event()
        self.play_thread = None
        # Non-reentrant: no method may call another lock-taking method while holding it.
        self._lock = threading.Lock()
        self.pyaudio_instance = pyaudio.PyAudio()

    def play_audio(self, audio_buffer):
        """Stop any current playback, then play ``audio_buffer`` (a seekable WAV file object)."""
        # Must run before taking _lock: stop_playing acquires the same
        # non-reentrant lock, so calling it while holding the lock deadlocks.
        self.stop_playing()
        with self._lock:
            self.stop_flag.clear()
            self.playing = True
            thread = threading.Thread(target=self._play, args=(audio_buffer,), daemon=True)
            self.play_thread = thread
            thread.start()

    def _play(self, audio_buffer):
        """Thread body: stream WAV frames to the output device until done or stopped."""
        try:
            chunk = 2048
            with wave.open(audio_buffer, 'rb') as wf:
                stream = self.pyaudio_instance.open(
                    format=self.pyaudio_instance.get_format_from_width(wf.getsampwidth()),
                    channels=wf.getnchannels(),
                    rate=wf.getframerate(),
                    output=True
                )
                try:
                    data = wf.readframes(chunk)
                    while data and not self.stop_flag.is_set():
                        stream.write(data)
                        data = wf.readframes(chunk)
                finally:
                    # Release the device even if write() raised.
                    stream.stop_stream()
                    stream.close()
        except Exception as e:
            print(f"Error playing audio: {e}")
        finally:
            with self._lock:
                # Only the current playback may clear the flag; a stale thread
                # must not mark a newer playback as finished.
                if self.play_thread is threading.current_thread():
                    self.playing = False

    def stop_playing(self):
        """Signal the current playback to stop and wait briefly (up to 1 s) for it."""
        with self._lock:
            if not self.playing:
                return
            self.stop_flag.set()
            thread = self.play_thread
        # Join without holding the lock: the playback thread needs it to
        # record that it finished, so joining under the lock would always
        # burn the whole timeout.
        if thread is not None and thread.is_alive() and thread is not threading.current_thread():
            thread.join(timeout=1.0)
        with self._lock:
            if self.play_thread is thread:
                self.playing = False

    def __del__(self):
        # Guarded: __init__ may have failed before pyaudio_instance was set.
        instance = getattr(self, "pyaudio_instance", None)
        if instance is not None:
            instance.terminate()

class EnhancedVoiceDetector:
    """Energy-based voice activity detector with onset and hang-over smoothing.

    A chunk counts as loud when its RMS level (int16 samples scaled to
    [-1, 1)) exceeds ``threshold``. Speech is reported once
    ``min_speech_frames`` loud chunks have accumulated. It keeps being
    reported until ``silence_frames`` consecutive quiet chunks are seen,
    which also resets the loud-chunk count.
    """

    def __init__(self, config):
        rate, chunk = config.sampling_rate, config.chunk_size
        self.threshold = config.vad_threshold
        # Stored but not consulted by is_voice_detected.
        self.silence_threshold = config.silence_threshold
        # Durations in seconds converted to a whole (truncated) number of chunks.
        self.min_speech_frames = int(config.min_speech_duration * rate / chunk)
        self.silence_frames = int(config.silence_duration * rate / chunk)
        self.speech_frames_counter = 0
        self.silence_frames_counter = 0
        self.is_speaking = False

    def calculate_energy(self, audio_data):
        """Return the RMS level of int16 PCM bytes, normalized to full scale."""
        samples = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
        return np.sqrt(np.mean(samples ** 2))

    def is_voice_detected(self, audio_data):
        """Update the detector with one chunk and return whether speech is ongoing."""
        loud = self.calculate_energy(audio_data) > self.threshold
        if loud:
            self.silence_frames_counter = 0
            self.speech_frames_counter += 1
            if self.speech_frames_counter >= self.min_speech_frames:
                self.is_speaking = True
            return self.is_speaking

        self.silence_frames_counter += 1
        if self.silence_frames_counter >= self.silence_frames:
            self.speech_frames_counter = 0
            self.is_speaking = False
        return self.is_speaking

class InteractiveVoiceAssistant:
    """Voice assistant loop: microphone -> VAD -> ASR -> LLM -> TTS playback.

    :meth:`run` captures microphone chunks on the calling thread and groups
    voiced chunks into utterances. A daemon thread running
    :meth:`process_audio` transcribes each utterance, asks the LLM for a
    reply and hands the reply to the TTS client. Speaking while a reply is
    playing stops the playback.
    """

    def __init__(self):
        self.config = Config()
        self.asr_processor = FunASRProcessor(self.config)
        self.llm_processor = LLMProcessor()
        self.tts_client = GPTSoVITSClient()
        self.audio_player = AudioPlayer()
        self.voice_detector = EnhancedVoiceDetector(self.config)
        # Utterances (lists of raw PCM chunks) waiting for process_audio.
        self.audio_queue = queue.Queue()
        self.is_processing = False
        # Longest utterance, in chunks, allowed by Config.max_audio_length.
        self.max_buffer_size = int(self.config.max_audio_length * self.config.sampling_rate / self.config.chunk_size)
        self.interrupt_flag = threading.Event()
        self.processing_thread = None
        self._lock = threading.Lock()
        self.executor = global_executor
        self._reset_capture_state()

    def _reset_capture_state(self):
        """Reset the per-utterance state used by :meth:`_on_frame`."""
        # Chunks of the utterance currently being captured.
        self._utterance = []
        # Most recent chunks the detector classified as non-speech. The
        # detector only confirms speech after min_speech_frames loud chunks,
        # so the true onset lands here and is prepended when speech starts.
        self._pre_roll = deque(maxlen=max(self.voice_detector.min_speech_frames, 0))
        # Consecutive chunks reported as speech, for barge-in detection.
        self._voice_run = 0

    def _flush_utterance(self):
        """Hand the captured utterance to process_audio and start a new one."""
        self.audio_queue.put(self._utterance)
        self._utterance = []

    def _on_frame(self, data):
        """Feed one captured chunk through VAD, barge-in and utterance grouping."""
        if self.voice_detector.is_voice_detected(data):
            if not self._utterance and self._pre_roll:
                # Speech just started: recover the onset chunks.
                self._utterance.extend(self._pre_roll)
                self._pre_roll.clear()
            self._utterance.append(data)
            self._voice_run += 1

            if self._voice_run > self.config.interrupt_frames and self.audio_player.playing:
                print("检测到用户打断")
                self.interrupt_flag.set()
                interrupt_time = time.time()
                print(f"检测到打断时间: {interrupt_time}")
                self.audio_player.stop_playing()
                # Send what has been said so far and start a fresh buffer.
                self._flush_utterance()
                self._voice_run = 0
            elif 0 < self.max_buffer_size <= len(self._utterance):
                # Enforce Config.max_audio_length so continuous sound can't grow the buffer forever.
                self._flush_utterance()
            return

        self._voice_run = 0
        if self._utterance:
            if len(self._utterance) > self.voice_detector.min_speech_frames:
                self.audio_queue.put(self._utterance)
            self._utterance = []
        self._pre_roll.append(data)

    def _drain_audio_queue(self, timeout=0.1):
        """Wait up to ``timeout`` s for queued audio, then return all of it as one flat chunk list.

        Queue items may be a list of chunks or a single chunk. Returns ``[]``
        if nothing arrives in time.
        """
        try:
            item = self.audio_queue.get(timeout=timeout)
        except queue.Empty:
            return []
        frames = []
        while True:
            if isinstance(item, list):
                frames.extend(item)
            else:
                frames.append(item)
            try:
                item = self.audio_queue.get_nowait()
            except queue.Empty:
                return frames

    def process_audio(self):
        """Worker loop: transcribe queued utterances, query the LLM, speak the reply.

        Runs forever on the thread started by :meth:`run`. Errors are printed
        and the loop continues.
        """
        while True:
            try:
                # Block briefly for the next utterance instead of polling
                # queue.empty() and sleeping.
                audio_frames = self._drain_audio_queue(timeout=0.1)
                if not audio_frames:
                    continue

                with self._lock:
                    self.is_processing = True

                if self.audio_player.playing:
                    print("检测到打断,停止当前播放")
                    self.audio_player.stop_playing()
                    stop_time = time.time()
                    print(f"停止播放时间: {stop_time}")

                text = self.asr_processor.recognize(audio_frames)
                print(f"识别文本: {text}")

                if text.strip():
                    response = self.llm_processor.get_response(text)
                    print(f"生成回复: {response}")
                    self.tts_client.text_to_speech(response)

                with self._lock:
                    self.is_processing = False
                self.interrupt_flag.clear()
            except Exception as e:
                print(f"处理音频时出错: {e}")
                with self._lock:
                    self.is_processing = False
                # Throttle so a persistent error cannot spin the CPU.
                time.sleep(0.1)

    def run(self):
        """Capture microphone audio until Ctrl+C, feeding each chunk to :meth:`_on_frame`.

        Also publishes ``self.audio_player`` as the module-level
        ``audio_player`` and starts :meth:`process_audio` on a daemon thread.
        Other exceptions are printed and re-raised.
        """
        global audio_player
        try:
            audio_player = self.audio_player

            p = pyaudio.PyAudio()
            try:
                stream = p.open(format=pyaudio.paInt16,
                                channels=1,
                                rate=self.config.sampling_rate,
                                input=True,
                                frames_per_buffer=self.config.chunk_size)
            except Exception:
                # Don't leak the PyAudio session if the input device can't be opened.
                p.terminate()
                raise

            self.processing_thread = threading.Thread(target=self.process_audio, daemon=True)
            self.processing_thread.start()

            print("开始录音,按Ctrl+C停止... (说话时可以打断当前播放)")
            self._reset_capture_state()
            try:
                while True:
                    data = stream.read(self.config.chunk_size, exception_on_overflow=False)
                    self._on_frame(data)
            except KeyboardInterrupt:
                print("\n录音结束")
            finally:
                stream.stop_stream()
                stream.close()
                p.terminate()
                self.audio_player.stop_playing()

        except Exception as e:
            print(f"运行时出错: {e}")
            raise

def main():
    """Build the voice assistant and run its blocking capture loop."""
    InteractiveVoiceAssistant().run()


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值