Deploying Zipformer speech recognition

I. Introduction

This post deploys the k2 Zipformer streaming Chinese/English ASR model.

The model used in this example comes from the following open-source project:

https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t

The environment here was set up again from scratch.

II. Setting up the environment

1. Create a dedicated conda environment named flask

conda create -n flask python==3.10 -y

conda activate flask

2. Install PyTorch

pip3 install torch==2.5.0 torchvision==0.20.0 -i https://pypi.mirrors.ustc.edu.cn/simple/ 

3. Install the other packages

pip install soundfile
pip install flask
pip install sounddevice
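
A quick sanity check that the core packages import cleanly (an illustrative snippet, not part of the project):

import torch
import soundfile
import flask
import sounddevice

print("torch:", torch.__version__)  # expect 2.5.0
print("all imports OK")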

4. If installing pyaudio fails

pip install pyaudio

may abort with:

Failed to build pyaudio

In that case install it through conda instead:

conda install pyaudio

5. Missing kaldifeat

Running the demo may fail with:

Traceback (most recent call last):
  File "/home/orangepi/work_11.15/rknn_model_zoo/examples/zipformer/python/zipformer.py", line 4, in <module>
    import kaldifeat

Fix: install the kaldifeat package. Either use the pre-compiled wheels described in the official docs (From pre-compiled wheels (Recommended) — kaldifeat 1.25.5 documentation), or build from source:

git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
python3 setup.py install
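
Once installed, a minimal smoke test using the same fbank settings as the demo below confirms the package works (illustrative only):

import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False

# 1 s of audio at 16 kHz gives roughly 100 frames at a 10 ms frame shift
# (a few tail frames may wait for more input in online mode).
fbank = kaldifeat.OnlineFbank(opts)
fbank.accept_waveform(sampling_rate=16000, waveform=torch.zeros(16000))
print("frames ready:", fbank.num_frames_ready)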

III. Running

1. Enter the environment

The working directory contains the main model files plus a 1.wav test clip:

cd /home/sxj/zipformer

conda activate flask

2. Run

python app.py

Server side (app.py):

import onnxruntime
from rknn.api import RKNN
import torch
import kaldifeat
import soundfile as sf
import numpy as np
import scipy.signal  # resample lives in the signal submodule; importing scipy alone is not enough
import os

import tempfile

from flask import Flask, request, jsonify

app = Flask(__name__)

# The model classes (BaseModel, OnnxModel, RKNNModel, etc.) follow below

class BaseModel():
    def __init__(self):
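        # Input shapes for the encoder: 'x' is one 103-frame fbank chunk; the
        # cached_* entries are the streaming state (lengths, averages, attention
        # key/value caches, conv caches) carried over between chunks.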
        self.model_config = {'x': [1, 103, 80], 'cached_len_0': [2, 1], 'cached_len_1': [2, 1], 'cached_len_2': [2, 1],
                             'cached_len_3': [2, 1],
                             'cached_len_4': [2, 1], 'cached_avg_0': [2, 1, 256], 'cached_avg_1': [2, 1, 256],
                             'cached_avg_2': [2, 1, 256],
                             'cached_avg_3': [2, 1, 256], 'cached_avg_4': [2, 1, 256], 'cached_key_0': [2, 192, 1, 192],
                             'cached_key_1': [2, 96, 1, 192],
                             'cached_key_2': [2, 48, 1, 192], 'cached_key_3': [2, 24, 1, 192],
                             'cached_key_4': [2, 96, 1, 192], 'cached_val_0': [2, 192, 1, 96],
                             'cached_val_1': [2, 96, 1, 96], 'cached_val_2': [2, 48, 1, 96],
                             'cached_val_3': [2, 24, 1, 96], 'cached_val_4': [2, 96, 1, 96],
                             'cached_val2_0': [2, 192, 1, 96], 'cached_val2_1': [2, 96, 1, 96],
                             'cached_val2_2': [2, 48, 1, 96], 'cached_val2_3': [2, 24, 1, 96],
                             'cached_val2_4': [2, 96, 1, 96], 'cached_conv1_0': [2, 1, 256, 30],
                             'cached_conv1_1': [2, 1, 256, 30], 'cached_conv1_2': [2, 1, 256, 30],
                             'cached_conv1_3': [2, 1, 256, 30], 'cached_conv1_4': [2, 1, 256, 30],
                             'cached_conv2_0': [2, 1, 256, 30], 'cached_conv2_1': [2, 1, 256, 30],
                             'cached_conv2_2': [2, 1, 256, 30], 'cached_conv2_3': [2, 1, 256, 30],
                             'cached_conv2_4': [2, 1, 256, 30]}

    def init_encoder_input(self):
        self.encoder_input = []
        self.encoder_input_dict = {}
        for input_name in self.model_config:
            if 'cached_len' in input_name:
                data = np.zeros((self.model_config[input_name]), dtype=np.int64)
                self.encoder_input.append(data)
                self.encoder_input_dict[input_name] = data
            else:
                data = np.zeros((self.model_config[input_name]), dtype=np.float32)
                self.encoder_input.append(data)
                self.encoder_input_dict[input_name] = data

    def update_encoder_input(self, out, model_type):
        for idx, input_name in enumerate(self.encoder_input_dict):
            if idx == 0:
                continue
            if idx > 10 and model_type == 'rknn':
                data = self.convert_nchw_to_nhwc(out[idx])
            else:
                data = out[idx]
            self.encoder_input[idx] = data
            self.encoder_input_dict[input_name] = data

    def convert_nchw_to_nhwc(self, src):
        dst = np.transpose(src, (0, 2, 3, 1))
        return dst

    def init_model(self, model_path, target, device_id):
        pass

    def release_model(self):
        pass

    def run_encoder(self, x):
        pass

    def run_decoder(self, decoder_input):
        pass

    def run_joiner(self, encoder_out, decoder_out):
        pass

    def run_greedy_search(self, frames, context_size, decoder_out, hyp, num_processed_frames, timestamp, frame_offset):
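        # Transducer greedy search over one chunk: encode the frames, then for
        # each encoder frame query the joiner; every symbol that is neither
        # blank nor unk is appended to hyp and fed back through the decoder.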
        encoder_out = self.run_encoder(frames)
        encoder_out = encoder_out.squeeze(0)

        blank_id = 0
        unk_id = 2
        if decoder_out is None and hyp is None:
            hyp = [blank_id] * context_size
            decoder_input = np.array([hyp], dtype=np.int64)
            decoder_out = self.run_decoder(decoder_input)

        T = encoder_out.shape[0]
        for t in range(T):
            cur_encoder_out = encoder_out[t: t + 1]
            joiner_out = self.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
            y = np.argmax(joiner_out, axis=0)
            if y != blank_id and y != unk_id:
                timestamp.append(frame_offset + t)
                hyp.append(y)
                decoder_input = hyp[-context_size:]
                decoder_input = np.array([decoder_input], dtype=np.int64)
                decoder_out = self.run_decoder(decoder_input)
        frame_offset += T

        return hyp, decoder_out, timestamp, frame_offset


class OnnxModel(BaseModel):
    def __init__(
            self,
            encoder_model_path,
            decoder_model_path,
            joiner_model_path,
            target,
            device_id
    ):
        super().__init__()

        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        model = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        return model

    def release_model(self):
        del self.encoder
        del self.decoder
        del self.joiner
        self.encoder = None
        self.decoder = None
        self.joiner = None

    def run_encoder(self, x):
        self.encoder_input[0] = x.numpy()
        self.encoder_input_dict['x'] = x.numpy()
        out = self.encoder.run(None, self.encoder_input_dict)
        self.update_encoder_input(out, 'onnx')
        return out[0]

    def run_decoder(self, decoder_input):
        out = self.decoder.run(None, {self.decoder.get_inputs()[0].name: decoder_input})[0]
        return out

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.run(None, {self.joiner.get_inputs()[0].name: encoder_out,
                                     self.joiner.get_inputs()[1].name: decoder_out})[0]
        return out


class RKNNModel(BaseModel):
    def __init__(
            self,
            encoder_model_path,
            decoder_model_path,
            joiner_model_path,
            target,
            device_id
    ):
        super().__init__()

        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        # Create RKNN object
        rknn = RKNN(verbose=False)

        # Load RKNN model
        print('--> Loading model')
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            print('Load RKNN model \"{}\" failed!'.format(model_path))
            exit(ret)
        print('done')

        # init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime(
            target=target, device_id=device_id)
        if ret != 0:
            print('Init runtime environment failed')
            exit(ret)

        return rknn

    def release_model(self):
        self.encoder.release()
        self.decoder.release()
        self.joiner.release()
        self.encoder = None
        self.decoder = None
        self.joiner = None

    def run_encoder(self, x):
        self.encoder_input[0] = x.numpy()
        self.encoder_input_dict['x'] = x.numpy()
        out = self.encoder.inference(inputs=self.encoder_input)
        self.update_encoder_input(out, 'rknn')
        return out[0]

    def run_decoder(self, decoder_input):
        out = self.decoder.inference(inputs=decoder_input)[0]
        return out

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
        return out


def read_vocab(tokens_file):
    vocab = {}
    with open(tokens_file, 'r') as f:
        for line in f:
            if len(line.strip().split(' ')) < 2:
                key = line.strip().split(' ')[0]
                value = ""
            else:
                value, key = line.strip().split(' ')
            vocab[key] = value
    return vocab


def set_model(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
    if (encoder_model_path.endswith(".rknn") and
            decoder_model_path.endswith(".rknn") and
            joiner_model_path.endswith(".rknn")):
        model = RKNNModel(encoder_model_path, decoder_model_path, joiner_model_path,
                          target=target, device_id=device_id)
    elif (encoder_model_path.endswith(".onnx") and
          decoder_model_path.endswith(".onnx") and
          joiner_model_path.endswith(".onnx")):
        model = OnnxModel(encoder_model_path, decoder_model_path, joiner_model_path,
                          target=target, device_id=device_id)
    else:
        raise ValueError("Model files must be either all .onnx or all .rknn")
    return model

def run_model(model, audio_data):
    # Set kaldifeat config
    opts = kaldifeat.FbankOptions()
    opts.frame_opts.samp_freq = 16000 # sample_rate=16000
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400  # negative means offset from Nyquist: 8000 - 400 = 7600 Hz
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    fbank = kaldifeat.OnlineFbank(opts)

    # Inference
    num_processed_frames = 0
    segment = 103
    offset = 96
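    # Chunking: each encoder call sees segment (103) feature frames, but the
    # stream advances by offset (96) frames, so consecutive chunks overlap by 7 frames.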
    context_size = 2
    hyp = None
    decoder_out = None

    fbank.accept_waveform(sampling_rate=16000, waveform=audio_data)
    num_frames = fbank.num_frames_ready
    timestamp = []
    frame_offset = 0

    while num_frames - num_processed_frames > 0:
        if (num_frames - num_processed_frames) < segment:
            tail_padding_len = (segment - (num_frames - num_processed_frames)) / 100.0
            tail_padding = torch.zeros(int(tail_padding_len * 16000), dtype=torch.float32)
            fbank.accept_waveform(sampling_rate=16000, waveform=tail_padding)
        frames = []
        for i in range(segment):
            frames.append(fbank.get_frame(num_processed_frames + i))

        frames = torch.cat(frames, dim=0)
        frames = frames.unsqueeze(0)
        hyp, decoder_out, timestamp, frame_offset = model.run_greedy_search(frames, context_size, decoder_out, hyp, num_processed_frames, timestamp, frame_offset)
        num_processed_frames += offset

    return hyp[context_size:], timestamp

def post_process(hyp, vocab, timestamp):
    text = ""
    for i in hyp:
        text += vocab[str(i)]
    text = text.replace("▁", " ").strip()

    frame_shift_ms = 10
    subsampling_factor = 4
    frame_shift_s = frame_shift_ms / 1000.0 * subsampling_factor
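    # i.e. each encoder frame covers 10 ms * 4 (subsampling) = 40 ms, so frame
    # index t maps to a time of 0.04 * t seconds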
    real_timestamp = [round(frame_shift_s * t, 2) for t in timestamp]
    return text, real_timestamp

def ensure_sample_rate(waveform, original_sample_rate, desired_sample_rate=16000):
    if original_sample_rate != desired_sample_rate:
        print("resample_audio: {} HZ -> {} HZ".format(original_sample_rate, desired_sample_rate))
        desired_length = int(round(float(len(waveform)) / original_sample_rate * desired_sample_rate))
        waveform = scipy.signal.resample(waveform, desired_length)
    return waveform, desired_sample_rate

def ensure_channels(waveform, original_channels, desired_channels=1):
    if original_channels != desired_channels:
        print("convert_channels: {} -> {}".format(original_channels, desired_channels))
        waveform = np.mean(waveform, axis=1)
    return waveform, desired_channels

model = None
vocab = None

def initialize_model():
    global model, vocab
    # Hard-coded model paths and parameters
    args = type('', (), {})()
    args.encoder_model_path = "encoder-epoch-99-avg-1.onnx"  # change to your encoder model path
    args.decoder_model_path = "decoder-epoch-99-avg-1.onnx"  # change to your decoder model path
    args.joiner_model_path = "joiner-epoch-99-avg-1.onnx"    # change to your joiner model path
    args.target = "rk3588"                                    # adjust for your hardware
    args.device_id = None                                     # set if you need a specific device ID
    args.vocab_path = "vocab.txt"                             # change to your vocabulary file path

    model = set_model(
        args.encoder_model_path,
        args.decoder_model_path,
        args.joiner_model_path,
        args.target,
        args.device_id
    )
    model.init_encoder_input()
    vocab = read_vocab(args.vocab_path)

def process_audio(tmp_file_name):
    global model, vocab
    if model is None or vocab is None:
        print("Model or vocab not initialized.")
        return "", []
    audio, sr = sf.read(tmp_file_name)
    # Audio preprocessing (resample to 16000 Hz, downmix to mono, etc.)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    if sr != 16000:
        audio = scipy.signal.resample(audio, int(len(audio) * 16000 / sr))
    audio = torch.from_numpy(audio).float()
    hyp, timestamp = run_model(model, audio)
    text, real_timestamp = post_process(hyp, vocab, timestamp)
    return text, real_timestamp

@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400

    audio_file = request.files['audio']
    if audio_file.filename == '':
        return jsonify({"error": "Empty filename"}), 400

    # Create a temporary file for the upload
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        audio_file.save(tmp_file.name)
        try:
            text, timestamps = process_audio(tmp_file.name)
            os.unlink(tmp_file.name)
            return jsonify({
                "text": text,
                "timestamps": timestamps
            })
        except Exception as e:
            os.unlink(tmp_file.name)
            return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Initialize the model and vocabulary
    initialize_model()

    # Start the Flask app on port 5000
    app.run(host='0.0.0.0', port=5000, threaded=False)  # RKNN may require single-threaded mode
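
With the server up, the endpoint can be exercised directly before trying the client below (assuming 1.wav is in the current directory):

curl -F "audio=@1.wav" http://localhost:5000/transcribe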

Client side, client.py (run with: python client.py):

import requests

def send_audio_for_transcription(audio_path, server_url):
    """发送音频文件到服务端进行语音识别"""
    try:
        with open(audio_path, 'rb') as audio_file:
            files = {'audio': (audio_path, audio_file, 'audio/wav')}
            response = requests.post(server_url, files=files)
            
        if response.status_code == 200:
            result = response.json()
            return result
        else:
            return {'error': f"Server returned status code {response.status_code}", 'details': response.text}
    except Exception as e:
        return {'error': str(e)}

def main():
    # Hard-coded audio file path and server URL
    audio_path = '1.wav'
    server_url = 'http://localhost:5000/transcribe'

    result = send_audio_for_transcription(audio_path, server_url)
    
    if 'error' in result:
        print(f"Error: {result['error']}")
        if 'details' in result:
            print(f"Details: {result['details']}")
    else:
        print("\nRecognition result:")
        print(f"Text: {result['text']}")
        print("Timestamps:")
        for timestamp in result['timestamps']:
            print(f"{timestamp:.2f}s", end=" ")
        print("\n")

if __name__ == "__main__":
    main()
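
If the server runs on the RK3588 board rather than locally, point server_url at the board's IP address; the server binds to 0.0.0.0, so it is reachable over the LAN.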

Latest version: WebSocket server for communicating with an Android client (v1)

app.py

import onnxruntime
from rknn.api import RKNN
import torch
import kaldifeat
import numpy as np
import json
import websockets
import asyncio
import logging
import traceback
import os
import re

# Configure logging output
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Localize the log level names to Chinese
logging.addLevelName(logging.DEBUG, "调试")
logging.addLevelName(logging.INFO, "信息")
logging.addLevelName(logging.WARNING, "警告")
logging.addLevelName(logging.ERROR, "错误")
logging.addLevelName(logging.CRITICAL, "严重")

class BaseModel:
    def __init__(self):
        logger.debug("初始化BaseModel")
        self.model_config = {
            'x': [1, 103, 80], 'cached_len_0': [2, 1], 'cached_len_1': [2, 1], 'cached_len_2': [2, 1],
            'cached_len_3': [2, 1], 'cached_len_4': [2, 1], 'cached_avg_0': [2, 1, 256],
            'cached_avg_1': [2, 1, 256], 'cached_avg_2': [2, 1, 256], 'cached_avg_3': [2, 1, 256],
            'cached_avg_4': [2, 1, 256], 'cached_key_0': [2, 192, 1, 192], 'cached_key_1': [2, 96, 1, 192],
            'cached_key_2': [2, 48, 1, 192], 'cached_key_3': [2, 24, 1, 192], 'cached_key_4': [2, 96, 1, 192],
            'cached_val_0': [2, 192, 1, 96], 'cached_val_1': [2, 96, 1, 96], 'cached_val_2': [2, 48, 1, 96],
            'cached_val_3': [2, 24, 1, 96], 'cached_val_4': [2, 96, 1, 96],
            'cached_val2_0': [2, 192, 1, 96], 'cached_val2_1': [2, 96, 1, 96], 'cached_val2_2': [2, 48, 1, 96],
            'cached_val2_3': [2, 24, 1, 96], 'cached_val2_4': [2, 96, 1, 96],
            'cached_conv1_0': [2, 1, 256, 30], 'cached_conv1_1': [2, 1, 256, 30],
            'cached_conv1_2': [2, 1, 256, 30], 'cached_conv1_3': [2, 1, 256, 30], 'cached_conv1_4': [2, 1, 256, 30],
            'cached_conv2_0': [2, 1, 256, 30], 'cached_conv2_1': [2, 1, 256, 30],
            'cached_conv2_2': [2, 1, 256, 30], 'cached_conv2_3': [2, 1, 256, 30], 'cached_conv2_4': [2, 1, 256, 30]
        }

    def init_encoder_input(self):
        logger.debug("初始化编码器输入")
        encoder_input_dict = {}
        for input_name in self.model_config:
            if 'cached_len' in input_name:
                data = np.zeros((self.model_config[input_name]), dtype=np.int64)
                encoder_input_dict[input_name] = data
            else:
                data = np.zeros((self.model_config[input_name]), dtype=np.float32)
                encoder_input_dict[input_name] = data
        return encoder_input_dict

    def convert_nchw_to_nhwc(self, src):
        dst = np.transpose(src, (0, 2, 3, 1))
        return dst

    def init_model(self, model_path, target, device_id):
        pass

    def release_model(self):
        pass

    def run_encoder(self, x, encoder_input_dict):
        pass

    def run_decoder(self, decoder_input):
        pass

    def run_joiner(self, encoder_out, decoder_out):
        pass


class OnnxModel(BaseModel):
    def __init__(self, encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
        super().__init__()
        logger.debug("初始化ONNX模型")
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        logger.debug(f"加载ONNX模型: {model_path}")
        if not os.path.exists(model_path):
            logger.error(f"模型文件未找到: {model_path}")
            raise FileNotFoundError(f"模型文件未找到: {model_path}")
        try:
            model = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            logger.debug(f"ONNX模型 {model_path} 加载成功")
            return model
        except Exception as e:
            logger.error(f"加载ONNX模型失败: {model_path}, 错误: {str(e)}")
            raise

    def release_model(self):
        logger.debug("释放ONNX模型")
        del self.encoder, self.decoder, self.joiner
        self.encoder = self.decoder = self.joiner = None

    def run_encoder(self, x, encoder_input_dict):
        encoder_input_dict['x'] = x.numpy()
        inputs = {}
        for input_name in self.model_config:
            inputs[input_name] = encoder_input_dict[input_name]
        logger.debug("运行ONNX编码器推理")
        try:
            out = self.encoder.run(None, inputs)
            new_encoder_input_dict = {}
            for idx, input_name in enumerate(self.model_config):
                if idx == 0:
                    encoder_output = out[0]
                else:
                    new_encoder_input_dict[input_name] = out[idx]
            return encoder_output, new_encoder_input_dict
        except Exception as e:
            logger.error(f"ONNX编码器运行错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_decoder(self, decoder_input):
        logger.debug(f"运行ONNX解码器,输入形状: {decoder_input.shape}")
        try:
            out = self.decoder.run(None, {self.decoder.get_inputs()[0].name: decoder_input})[0]
            return out
        except Exception as e:
            logger.error(f"ONNX解码器运行错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_joiner(self, encoder_out, decoder_out):
        logger.debug(f"运行ONNX连接器,形状: {encoder_out.shape}, {decoder_out.shape}")
        try:
            out = self.joiner.run(None, {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out
            })[0]
            return out
        except Exception as e:
            logger.error(f"ONNX连接器运行错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise


class RKNNModel(BaseModel):
    def __init__(self, encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
        super().__init__()
        logger.debug("初始化RKNN模型")
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        logger.debug(f"加载RKNN模型: {model_path}")
        if not os.path.exists(model_path):
            logger.error(f"模型文件未找到: {model_path}")
            raise FileNotFoundError(f"模型文件未找到: {model_path}")
        rknn = RKNN(verbose=False)
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            logger.error('加载RKNN模型失败!')
            exit(ret)
        ret = rknn.init_runtime(target=target, device_id=device_id)
        if ret != 0:
            logger.error('初始化运行环境失败')
            exit(ret)
        logger.debug(f"RKNN模型 {model_path} 加载成功")
        return rknn

    def release_model(self):
        logger.debug("释放RKNN模型")
        self.encoder.release()
        self.decoder.release()
        self.joiner.release()
        self.encoder = self.decoder = self.joiner = None

    def run_encoder(self, x, encoder_input_dict):
        encoder_input_dict['x'] = x.numpy()
        input_list = []
        for input_name in self.model_config:
            input_list.append(encoder_input_dict[input_name])
        logger.debug("运行RKNN编码器推理")
        try:
            out = self.encoder.inference(inputs=input_list)
            new_encoder_input_dict = {}
            for idx, input_name in enumerate(self.model_config):
                if idx == 0:
                    encoder_output = out[0]
                else:
                    data = out[idx]
                    if idx > 10:
                        data = self.convert_nchw_to_nhwc(data)
                    new_encoder_input_dict[input_name] = data
            return encoder_output, new_encoder_input_dict
        except Exception as e:
            logger.error(f"RKNN编码器推理错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_decoder(self, decoder_input):
        logger.debug(f"运行RKNN解码器,输入形状: {decoder_input.shape}")
        try:
            out = self.decoder.inference(inputs=decoder_input)[0]
            return out
        except Exception as e:
            logger.error(f"RKNN解码器推理错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_joiner(self, encoder_out, decoder_out):
        logger.debug(f"运行RKNN连接器,形状: {encoder_out.shape}, {decoder_out.shape}")
        try:
            out = self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
            return out
        except Exception as e:
            logger.error(f"RKNN连接器推理错误: {str(e)}")
            logger.debug(traceback.format_exc())
            raise


def read_vocab(tokens_file):
    logger.debug(f"从 {tokens_file} 读取词汇表")
    vocab = {}
    with open(tokens_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(' ')
            if len(parts) < 2:
                key, value = parts[0], ''
            else:
                value, key = parts[0], parts[1]
            vocab[key] = value
    logger.debug(f"加载了 {len(vocab)} 个词汇条目")
    return vocab


def set_model(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
    logger.debug("设置模型")
    encoder_ext = os.path.splitext(encoder_model_path)[1]
    decoder_ext = os.path.splitext(decoder_model_path)[1]
    joiner_ext = os.path.splitext(joiner_model_path)[1]

    if encoder_ext != decoder_ext or encoder_ext != joiner_ext:
        logger.error("所有模型文件必须具有相同的扩展名(例如 .onnx 或 .rknn)")
        raise ValueError("所有模型文件必须具有相同的扩展名(例如 .onnx 或 .rknn)")

    if encoder_ext == ".onnx":
        return OnnxModel(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id)
    elif encoder_ext == ".rknn":
        return RKNNModel(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id)
    else:
        logger.error("不支持的模型格式:必须是 .onnx 或 .rknn")
        raise ValueError("不支持的模型格式:必须是 .onnx 或 .rknn")


def post_process(hyp, vocab, timestamp):
    logger.debug("后处理结果")
    if not hyp or len(timestamp) != len(hyp):
        return "", []

    frame_shift_ms = 10
    subsampling_factor = 4
    frame_shift_s = frame_shift_ms / 1000.0 * subsampling_factor
    real_timestamp = [round(t * frame_shift_s, 2) for t in timestamp]

    # Punctuation insertion thresholds (seconds)
    comma_threshold = 0.4
    period_threshold = 0.8

    tokens_with_punct = []

    for i in range(len(hyp)):
        token_id = hyp[i]
        token = vocab.get(str(token_id), "")
        if not token:
            continue
        token = token.replace("▁", " ").strip()

        punctuation = ''
        if i < len(timestamp) - 1:
            current_time = real_timestamp[i]
            next_time = real_timestamp[i + 1]
            delta = next_time - current_time
            if delta >= period_threshold:
                punctuation = '。'
            elif delta >= comma_threshold:
                punctuation = ','
        tokens_with_punct.append(token + punctuation)

    # Ensure the final token ends with punctuation
    if tokens_with_punct:
        last_token = tokens_with_punct[-1]
        if not (last_token.endswith(('。', '.', '!', '!', '?', '?'))):
            if last_token.endswith(','):
                tokens_with_punct[-1] = last_token[:-1] + '。'
            else:
                tokens_with_punct[-1] += '。'

    # Join tokens and collapse runs of punctuation
    text = ''.join(tokens_with_punct)
    text = re.sub(r'([。,!?\.\?!])[。,!?\.\?!]*', r'\1', text)

    logger.debug(f"生成文本: {text}, 时间戳: {real_timestamp}")
    return text, real_timestamp


model = None
vocab = None


def initialize_model():
    global model, vocab
    logger.info("初始化模型")
    args = type('', (), {})()
    args.encoder_model_path = "encoder-epoch-99-avg-1.onnx"
    args.decoder_model_path = "decoder-epoch-99-avg-1.onnx"
    args.joiner_model_path = "joiner-epoch-99-avg-1.onnx"
    args.target = "rk3588"
    args.device_id = None
    args.vocab_path = "vocab.txt"

    try:
        # Release any previously loaded model first
        if model is not None:
            model.release_model()
            logger.debug("释放已有模型")

        for model_path in [args.encoder_model_path, args.decoder_model_path, args.joiner_model_path]:
            if not os.path.exists(model_path):
                logger.error(f"模型文件未找到: {model_path}")
                raise FileNotFoundError(f"模型文件未找到: {model_path}")

        model = set_model(args.encoder_model_path, args.decoder_model_path, args.joiner_model_path, args.target, args.device_id)
        vocab = read_vocab(args.vocab_path)
        logger.info("模型初始化成功")
    except Exception as e:
        logger.error(f"模型初始化失败: {str(e)}")
        logger.debug(traceback.format_exc())
        raise SystemExit(1)


async def handle_client(websocket, path=None):
    global model, vocab
    if model is None or vocab is None:
        logger.error("模型未初始化")
        raise RuntimeError("模型未初始化")

    try:
        logger.info(f"客户端已连接,路径: {path if path is not None else 'N/A'}")

        # Each client keeps its own streaming state
        encoder_input_dict = model.init_encoder_input()
        opts = kaldifeat.FbankOptions()
        opts.frame_opts.samp_freq = 16000
        opts.mel_opts.num_bins = 80
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        fbank = kaldifeat.OnlineFbank(opts)

        context_size = 2
        blank_id = 0
        hyp = [blank_id] * context_size
        decoder_out = None
        num_processed_frames = 0
        timestamp = []
        frame_offset = 0

        segment = 103
        offset = 96

        while True:
            try:
                data = await websocket.recv()
                logger.debug(f"收到数据大小: {len(data) if isinstance(data, bytes) else 0}")

                if not isinstance(data, bytes):
                    continue

                audio_np = np.frombuffer(data, dtype=np.int16)
                logger.debug(f"解码 {len(audio_np)} 个采样点")

                audio_float = audio_np.astype(np.float32) / 32768.0
                audio_tensor = torch.from_numpy(audio_float)
                fbank.accept_waveform(sampling_rate=16000, waveform=audio_tensor)

                while True:
                    num_frames = fbank.num_frames_ready
                    if num_frames - num_processed_frames < segment:
                        break

                    frames = []
                    for i in range(segment):
                        frame = fbank.get_frame(num_processed_frames + i)
                        frames.append(frame)
                    frames = torch.cat(frames, dim=0).unsqueeze(0)

                    # Run the encoder and update the cached streaming state
                    encoder_out, encoder_input_dict = model.run_encoder(frames, encoder_input_dict)
                    encoder_out = encoder_out.squeeze(0)

                    blank_id = 0
                    unk_id = 2

                    if decoder_out is None:
                        decoder_input = np.array([hyp], dtype=np.int64)
                        decoder_out = model.run_decoder(decoder_input)

                    T = encoder_out.shape[0]
                    for t in range(T):
                        cur_encoder_out = encoder_out[t: t + 1]
                        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
                        y = np.argmax(joiner_out, axis=0)
                        if y != blank_id and y != unk_id:
                            timestamp.append(frame_offset + t)
                            hyp.append(y)
                            decoder_input = hyp[-context_size:]
                            decoder_input = np.array([decoder_input], dtype=np.int64)
                            decoder_out = model.run_decoder(decoder_input)
                    frame_offset += T

                    current_text, real_timestamp = post_process(hyp[context_size:], vocab, timestamp)
                    logger.info(f"识别结果: {current_text}")

                    response = json.dumps({
                        "text": current_text,
                        "timestamps": real_timestamp
                    }).encode('utf-8')
                    await websocket.send(response)

                    # Reset decoding state for the next segment
                    hyp = hyp[:context_size]  # keep only the blank context prefix
                    timestamp = []
                    frame_offset = 0
                    decoder_out = None

                    num_processed_frames += offset

            except websockets.exceptions.ConnectionClosed:
                break

    except Exception as e:
        logger.error(f"发生错误: {str(e)}")
        logger.debug(traceback.format_exc())
        try:
            error_msg = json.dumps({"error": str(e)}).encode('utf-8')
            await websocket.send(error_msg)
        except:
            pass


async def main():
    try:
        initialize_model()
        server = await websockets.serve(
            handle_client, "0.0.0.0", 5000,
            ping_interval=20,
            ping_timeout=None,
        )
        logger.info("WebSocket服务器启动,监听端口5000...")
        await server.wait_closed()
    except Exception as e:
        logger.error(f"服务器错误: {str(e)}")
        logger.debug(traceback.format_exc())
    finally:
        try:
            if model is not None:
                model.release_model()
                logger.info("模型已释放")
        except:
            pass


if __name__ == "__main__":
    asyncio.run(main())
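
The real client here is an Android app, which is not included. For a quick PC-side test, a minimal sketch of a WebSocket client (a hypothetical helper; it assumes 1.wav is 16 kHz mono, since the server reads raw int16 PCM and replies with JSON):

import asyncio
import json

import soundfile as sf
import websockets

async def main():
    # The server's np.frombuffer(..., dtype=np.int16) expects raw 16-bit PCM.
    audio, sr = sf.read("1.wav", dtype="int16")
    async with websockets.connect("ws://localhost:5000") as ws:
        chunk = 3200  # 0.2 s of audio at 16 kHz
        for i in range(0, len(audio), chunk):
            await ws.send(audio[i:i + chunk].tobytes())
            try:
                # A result only arrives once a full 103-frame segment is ready.
                reply = await asyncio.wait_for(ws.recv(), timeout=0.1)
                print(json.loads(reply))
            except asyncio.TimeoutError:
                pass

asyncio.run(main())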

On my side I reuse an existing yolo conda environment, which mainly provides the RKNN toolchain.

Converting the ONNX models to RKNN requires the official rknn_model_zoo tool: rknn_model_zoo-2.2.0.

Copy the environment package list over to the RK3588 board and install it there, and the setup is complete.

(If the NPU-side setup above is unclear: copy the requirements file to the board and run

pip install -r <requirements file> -i https://pypi.tuna.tsinghua.edu.cn/simple

and the installation is done.)

Tested on both the PC and the RK3588 board; it runs correctly on both.

The first run on the RK3588 board failed with a missing package: the same kaldifeat import error already shown in Part II, step 5. The fix is the same: build kaldifeat from source (git clone https://github.com/csukuangfj/kaldifeat, then python3 setup.py install inside it); the build compiles and completes successfully.

Supported platforms

RK3566, RK3568, RK3588, RK3562, RK3576

IV. Downloading and converting the model (on the RK3588)

1. Download the model files

cd model

./download_model.sh

2. Convert the models

cd python

python convert.py ../model/encoder-epoch-99-avg-1.onnx rk3588



# output model will be saved as ../model/encoder-epoch-99-avg-1.rknn

python convert.py ../model/decoder-epoch-99-avg-1.onnx rk3588


# output model will be saved as ../model/decoder-epoch-99-avg-1.rknn

python convert.py ../model/joiner-epoch-99-avg-1.onnx rk3588


# output model will be saved as ../model/joiner-epoch-99-avg-1.rknn
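
convert.py from rknn_model_zoo is the authoritative conversion script; as a rough illustration of what such a conversion does with the RKNN-Toolkit2 API (the paths and the no-quantization choice are assumptions, not taken from convert.py):

from rknn.api import RKNN

# Illustrative ONNX -> RKNN conversion sketch, not the actual convert.py.
rknn = RKNN(verbose=False)
rknn.config(target_platform="rk3588")
if rknn.load_onnx(model="../model/encoder-epoch-99-avg-1.onnx") != 0:
    raise SystemExit("load_onnx failed")
if rknn.build(do_quantization=False) != 0:  # keep float weights
    raise SystemExit("build failed")
if rknn.export_rknn("../model/encoder-epoch-99-avg-1.rknn") != 0:
    raise SystemExit("export_rknn failed")
rknn.release()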

3. Run

kaldifeat package documentation:

From pre-compiled wheels (Recommended) — kaldifeat 1.25.5 documentation

https://csukuangfj.github.io/kaldifeat/cpu.html

# Install kaldifeat
# Refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html for installation.
# This python demo is tested under version: kaldifeat-1.25.4.dev20240223

Running on the RK3588 board:

cd python

python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.rknn --decoder_model_path decoder-epoch-99-avg-1.rknn --joiner_model_path joiner-epoch-99-avg-1.rknn --target rk3588

It finishes and prints the recognized text.

Running on the PC:

conda activate yuyin

python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.onnx --decoder_model_path decoder-epoch-99-avg-1.onnx --joiner_model_path joiner-epoch-99-avg-1.onnx 

It runs successfully. The equivalent invocation with an explicit sample rate:

python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.onnx --decoder_model_path decoder-epoch-99-avg-1.onnx --joiner_model_path joiner-epoch-99-avg-1.onnx --sample_rate 16000
