I. Introduction
A Chinese-English streaming ASR demo built on the k2 zipformer model. The model used in this example comes from the following open-source project:
https://huggingface.co/csukuangfj/k2fsa-zipformer-bilingual-zh-en-t
I rebuilt the environment from scratch; the steps follow.
II. Environment setup
1. Create a dedicated conda environment named flask
conda create -n flask python==3.10 -y
conda activate flask
2. Install PyTorch
pip3 install torch==2.5.0 torchvision==0.20.0 -i https://pypi.mirrors.ustc.edu.cn/simple/
3. Other dependencies
pip install soundfile
pip install flask
pip install sounddevice
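A quick import check (my own addition, not part of the original steps) confirms the core packages are usable before continuing:
python -c "import torch, soundfile, sounddevice, flask; print(torch.__version__)"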
4. If pyaudio fails to install
pip install pyaudio
If this fails with "Failed to build pyaudio", install it from conda instead:
conda install pyaudio
5. Error: missing kaldifeat module
Traceback (most recent call last):
  File "/home/orangepi/work_11.15/rknn_model_zoo/examples/zipformer/python/zipformer.py", line 4, in <module>
    import kaldifeat
ModuleNotFoundError: No module named 'kaldifeat'
Fix: install the kaldifeat package.
Follow the official docs: "From pre-compiled wheels (Recommended)" — kaldifeat 1.25.5 documentation (https://csukuangfj.github.io/kaldifeat/cpu.html). To build from source instead, clone https://github.com/csukuangfj/kaldifeat and run:
python3 setup.py install
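Once installed, a small sanity check (my own snippet, mirroring the 16 kHz / 80-bin fbank settings used by the demo code below) confirms kaldifeat works:

import torch
import kaldifeat

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
feats = fbank(torch.zeros(16000))  # one second of silence at 16 kHz
print(feats.shape)  # roughly (98, 80): ~100 frames x 80 mel bins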
III. Running
1. Enter the environment
The working directory holds the main model files plus a 1.wav test clip:
cd /home/sxj/zipformer
conda activate flask
2. Run
python app.py
Server (app.py):
import onnxruntime
from rknn.api import RKNN
import torch
import kaldifeat
import soundfile as sf
import numpy as np
import scipy.signal
import os
import tempfile
from flask import Flask, request, jsonify

app = Flask(__name__)

# The remaining code (BaseModel, OnnxModel, RKNNModel, etc.) goes below
class BaseModel:
    def __init__(self):
        # Input shapes of the streaming zipformer encoder: 'x' is one chunk of
        # fbank features, the rest are the cached streaming states.
        self.model_config = {
            'x': [1, 103, 80],
            'cached_len_0': [2, 1], 'cached_len_1': [2, 1], 'cached_len_2': [2, 1],
            'cached_len_3': [2, 1], 'cached_len_4': [2, 1],
            'cached_avg_0': [2, 1, 256], 'cached_avg_1': [2, 1, 256],
            'cached_avg_2': [2, 1, 256], 'cached_avg_3': [2, 1, 256],
            'cached_avg_4': [2, 1, 256],
            'cached_key_0': [2, 192, 1, 192], 'cached_key_1': [2, 96, 1, 192],
            'cached_key_2': [2, 48, 1, 192], 'cached_key_3': [2, 24, 1, 192],
            'cached_key_4': [2, 96, 1, 192],
            'cached_val_0': [2, 192, 1, 96], 'cached_val_1': [2, 96, 1, 96],
            'cached_val_2': [2, 48, 1, 96], 'cached_val_3': [2, 24, 1, 96],
            'cached_val_4': [2, 96, 1, 96],
            'cached_val2_0': [2, 192, 1, 96], 'cached_val2_1': [2, 96, 1, 96],
            'cached_val2_2': [2, 48, 1, 96], 'cached_val2_3': [2, 24, 1, 96],
            'cached_val2_4': [2, 96, 1, 96],
            'cached_conv1_0': [2, 1, 256, 30], 'cached_conv1_1': [2, 1, 256, 30],
            'cached_conv1_2': [2, 1, 256, 30], 'cached_conv1_3': [2, 1, 256, 30],
            'cached_conv1_4': [2, 1, 256, 30],
            'cached_conv2_0': [2, 1, 256, 30], 'cached_conv2_1': [2, 1, 256, 30],
            'cached_conv2_2': [2, 1, 256, 30], 'cached_conv2_3': [2, 1, 256, 30],
            'cached_conv2_4': [2, 1, 256, 30],
        }
    def init_encoder_input(self):
        # Zero-initialize all streaming states; cached_len_* are int64,
        # everything else is float32.
        self.encoder_input = []
        self.encoder_input_dict = {}
        for input_name in self.model_config:
            if 'cached_len' in input_name:
                data = np.zeros(self.model_config[input_name], dtype=np.int64)
            else:
                data = np.zeros(self.model_config[input_name], dtype=np.float32)
            self.encoder_input.append(data)
            self.encoder_input_dict[input_name] = data

    def update_encoder_input(self, out, model_type):
        # Feed the updated cache states back in for the next chunk; RKNN
        # returns the 4-D caches in NCHW, so convert them back to NHWC.
        for idx, input_name in enumerate(self.encoder_input_dict):
            if idx == 0:  # out[0] is the encoder output, not a cache state
                continue
            if idx > 10 and model_type == 'rknn':
                data = self.convert_nchw_to_nhwc(out[idx])
            else:
                data = out[idx]
            self.encoder_input[idx] = data
            self.encoder_input_dict[input_name] = data

    def convert_nchw_to_nhwc(self, src):
        return np.transpose(src, (0, 2, 3, 1))

    def init_model(self, model_path, target, device_id):
        pass

    def release_model(self):
        pass

    def run_encoder(self, x):
        pass

    def run_decoder(self, decoder_input):
        pass

    def run_joiner(self, encoder_out, decoder_out):
        pass
    def run_greedy_search(self, frames, context_size, decoder_out, hyp,
                          num_processed_frames, timestamp, frame_offset):
        # Greedy transducer search over one chunk of encoder output.
        encoder_out = self.run_encoder(frames)
        encoder_out = encoder_out.squeeze(0)
        blank_id = 0
        unk_id = 2
        if decoder_out is None and hyp is None:
            hyp = [blank_id] * context_size
            decoder_input = np.array([hyp], dtype=np.int64)
            decoder_out = self.run_decoder(decoder_input)
        T = encoder_out.shape[0]
        for t in range(T):
            cur_encoder_out = encoder_out[t: t + 1]
            joiner_out = self.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
            y = np.argmax(joiner_out, axis=0)
            if y != blank_id and y != unk_id:
                timestamp.append(frame_offset + t)
                hyp.append(y)
                decoder_input = hyp[-context_size:]
                decoder_input = np.array([decoder_input], dtype=np.int64)
                decoder_out = self.run_decoder(decoder_input)
        frame_offset += T
        return hyp, decoder_out, timestamp, frame_offset
class OnnxModel(BaseModel):
    def __init__(
        self,
        encoder_model_path,
        decoder_model_path,
        joiner_model_path,
        target,
        device_id
    ):
        super().__init__()
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        model = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        return model

    def release_model(self):
        del self.encoder
        del self.decoder
        del self.joiner
        self.encoder = None
        self.decoder = None
        self.joiner = None

    def run_encoder(self, x):
        self.encoder_input[0] = x.numpy()
        self.encoder_input_dict['x'] = x.numpy()
        out = self.encoder.run(None, self.encoder_input_dict)
        self.update_encoder_input(out, 'onnx')
        return out[0]

    def run_decoder(self, decoder_input):
        out = self.decoder.run(None, {self.decoder.get_inputs()[0].name: decoder_input})[0]
        return out

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.run(None, {self.joiner.get_inputs()[0].name: encoder_out,
                                     self.joiner.get_inputs()[1].name: decoder_out})[0]
        return out
class RKNNModel(BaseModel):
    def __init__(
        self,
        encoder_model_path,
        decoder_model_path,
        joiner_model_path,
        target,
        device_id
    ):
        super().__init__()
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        # Create RKNN object
        rknn = RKNN(verbose=False)
        # Load RKNN model
        print('--> Loading model')
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            print('Load RKNN model "{}" failed!'.format(model_path))
            exit(ret)
        print('done')
        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime(target=target, device_id=device_id)
        if ret != 0:
            print('Init runtime environment failed')
            exit(ret)
        return rknn

    def release_model(self):
        self.encoder.release()
        self.decoder.release()
        self.joiner.release()
        self.encoder = None
        self.decoder = None
        self.joiner = None

    def run_encoder(self, x):
        self.encoder_input[0] = x.numpy()
        self.encoder_input_dict['x'] = x.numpy()
        out = self.encoder.inference(inputs=self.encoder_input)
        self.update_encoder_input(out, 'rknn')
        return out[0]

    def run_decoder(self, decoder_input):
        out = self.decoder.inference(inputs=decoder_input)[0]
        return out

    def run_joiner(self, encoder_out, decoder_out):
        out = self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
        return out
def read_vocab(tokens_file):
    # tokens.txt lines look like "<token> <id>"; map id -> token.
    vocab = {}
    with open(tokens_file, 'r') as f:
        for line in f:
            if len(line.strip().split(' ')) < 2:
                key = line.strip().split(' ')[0]
                value = ""
            else:
                value, key = line.strip().split(' ')
            vocab[key] = value
    return vocab


def set_model(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
    if (encoder_model_path.endswith(".rknn") and
            decoder_model_path.endswith(".rknn") and
            joiner_model_path.endswith(".rknn")):
        model = RKNNModel(encoder_model_path, decoder_model_path, joiner_model_path,
                          target=target, device_id=device_id)
    elif (encoder_model_path.endswith(".onnx") and
            decoder_model_path.endswith(".onnx") and
            joiner_model_path.endswith(".onnx")):
        model = OnnxModel(encoder_model_path, decoder_model_path, joiner_model_path,
                          target=target, device_id=device_id)
    else:
        raise ValueError("Model files must be either all .onnx or all .rknn")
    return model
def run_model(model, audio_data):
    # Set kaldifeat config
    opts = kaldifeat.FbankOptions()
    opts.frame_opts.samp_freq = 16000  # sample_rate = 16000
    opts.mel_opts.num_bins = 80
    opts.mel_opts.high_freq = -400  # negative means Nyquist + high_freq (Kaldi convention)
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    fbank = kaldifeat.OnlineFbank(opts)

    # Inference
    num_processed_frames = 0
    segment = 103  # feature frames fed to the encoder per chunk
    offset = 96    # hop: frames advanced between consecutive chunks
    context_size = 2
    hyp = None
    decoder_out = None
    fbank.accept_waveform(sampling_rate=16000, waveform=audio_data)
    num_frames = fbank.num_frames_ready
    timestamp = []
    frame_offset = 0
    while num_frames - num_processed_frames > 0:
        if (num_frames - num_processed_frames) < segment:
            # Pad the tail with silence so the last chunk is a full segment
            # (frames are 10 ms each, hence /100.0 to get seconds).
            tail_padding_len = (segment - (num_frames - num_processed_frames)) / 100.0
            tail_padding = torch.zeros(int(tail_padding_len * 16000), dtype=torch.float32)
            fbank.accept_waveform(sampling_rate=16000, waveform=tail_padding)
        frames = []
        for i in range(segment):
            frames.append(fbank.get_frame(num_processed_frames + i))
        frames = torch.cat(frames, dim=0)
        frames = frames.unsqueeze(0)
        hyp, decoder_out, timestamp, frame_offset = model.run_greedy_search(
            frames, context_size, decoder_out, hyp,
            num_processed_frames, timestamp, frame_offset)
        num_processed_frames += offset
    return hyp[context_size:], timestamp
def post_process(hyp, vocab, timestamp):
    text = ""
    for i in hyp:
        text += vocab[str(i)]
    text = text.replace("▁", " ").strip()
    frame_shift_ms = 10
    subsampling_factor = 4
    frame_shift_s = frame_shift_ms / 1000.0 * subsampling_factor
    real_timestamp = [round(frame_shift_s * t, 2) for t in timestamp]
    return text, real_timestamp


def ensure_sample_rate(waveform, original_sample_rate, desired_sample_rate=16000):
    if original_sample_rate != desired_sample_rate:
        print("resample_audio: {} HZ -> {} HZ".format(original_sample_rate, desired_sample_rate))
        desired_length = int(round(float(len(waveform)) / original_sample_rate * desired_sample_rate))
        waveform = scipy.signal.resample(waveform, desired_length)
    return waveform, desired_sample_rate


def ensure_channels(waveform, original_channels, desired_channels=1):
    if original_channels != desired_channels:
        print("convert_channels: {} -> {}".format(original_channels, desired_channels))
        waveform = np.mean(waveform, axis=1)
    return waveform, desired_channels
model = None
vocab = None


def initialize_model():
    global model, vocab
    # Hard-coded model paths and parameters
    args = type('', (), {})()
    args.encoder_model_path = "encoder-epoch-99-avg-1.onnx"  # change to your model path
    args.decoder_model_path = "decoder-epoch-99-avg-1.onnx"  # change to your model path
    args.joiner_model_path = "joiner-epoch-99-avg-1.onnx"    # change to your model path
    args.target = "rk3588"  # change to match your hardware
    args.device_id = None   # set this if you need a specific device ID
    args.vocab_path = "vocab.txt"  # change to your vocabulary file path
    model = set_model(
        args.encoder_model_path,
        args.decoder_model_path,
        args.joiner_model_path,
        args.target,
        args.device_id
    )
    model.init_encoder_input()
    vocab = read_vocab(args.vocab_path)


def process_audio(tmp_file_name):
    global model, vocab
    if model is None or vocab is None:
        print("Model or vocab not initialized.")
        return "", []
    audio, sr = sf.read(tmp_file_name)
    # Audio preprocessing (convert to 16000 Hz, mono, etc.)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    if sr != 16000:
        audio = scipy.signal.resample(audio, int(len(audio) * 16000 / sr))
    audio = torch.from_numpy(audio).float()
    hyp, timestamp = run_model(model, audio)
    text, real_timestamp = post_process(hyp, vocab, timestamp)
    return text, real_timestamp
@app.route('/transcribe', methods=['POST'])
def transcribe_audio():
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400
    audio_file = request.files['audio']
    if audio_file.filename == '':
        return jsonify({"error": "Empty filename"}), 400
    # Save the upload to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        audio_file.save(tmp_file.name)
    try:
        text, timestamps = process_audio(tmp_file.name)
        os.unlink(tmp_file.name)
        return jsonify({
            "text": text,
            "timestamps": timestamps
        })
    except Exception as e:
        os.unlink(tmp_file.name)
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # Initialize the model and vocabulary
    initialize_model()
    # Start the Flask app on the default port 5000
    app.run(host='0.0.0.0', port=5000, threaded=False)  # RKNN may require single-threaded mode
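With the server running, the /transcribe endpoint can be smoke-tested with curl before using the Python client; the multipart form field must be named audio, matching request.files['audio'] in the server:
curl -F "audio=@1.wav" http://localhost:5000/transcribe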
The Python client below does the same from code. Run it with:
python client.py
import requests


def send_audio_for_transcription(audio_path, server_url):
    """Send an audio file to the server for speech recognition."""
    try:
        with open(audio_path, 'rb') as audio_file:
            files = {'audio': (audio_path, audio_file, 'audio/wav')}
            response = requests.post(server_url, files=files)
        if response.status_code == 200:
            result = response.json()
            return result
        else:
            return {'error': f"Server returned status code {response.status_code}",
                    'details': response.text}
    except Exception as e:
        return {'error': str(e)}


def main():
    # Hard-coded audio file path and server address
    audio_path = '1.wav'
    server_url = 'http://localhost:5000/transcribe'
    result = send_audio_for_transcription(audio_path, server_url)
    if 'error' in result:
        print(f"Error: {result['error']}")
        if 'details' in result:
            print(f"Details: {result['details']}")
    else:
        print("\nRecognition result:")
        print(f"Text: {result['text']}")
        print("Timestamps:")
        for timestamp in result['timestamps']:
            print(f"{timestamp:.2f}s", end=" ")
        print("\n")


if __name__ == "__main__":
    main()
Latest version: streaming from an Android client over WebSocket (v1)
app.py:
import onnxruntime
from rknn.api import RKNN
import torch
import kaldifeat
import numpy as np
import json
import websockets
import asyncio
import logging
import traceback
import os
import re

# Configure log output
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Give the log levels Chinese display names
logging.addLevelName(logging.DEBUG, "调试")
logging.addLevelName(logging.INFO, "信息")
logging.addLevelName(logging.WARNING, "警告")
logging.addLevelName(logging.ERROR, "错误")
logging.addLevelName(logging.CRITICAL, "严重")
class BaseModel:
    def __init__(self):
        logger.debug("Initializing BaseModel")
        self.model_config = {
            'x': [1, 103, 80], 'cached_len_0': [2, 1], 'cached_len_1': [2, 1], 'cached_len_2': [2, 1],
            'cached_len_3': [2, 1], 'cached_len_4': [2, 1], 'cached_avg_0': [2, 1, 256],
            'cached_avg_1': [2, 1, 256], 'cached_avg_2': [2, 1, 256], 'cached_avg_3': [2, 1, 256],
            'cached_avg_4': [2, 1, 256], 'cached_key_0': [2, 192, 1, 192], 'cached_key_1': [2, 96, 1, 192],
            'cached_key_2': [2, 48, 1, 192], 'cached_key_3': [2, 24, 1, 192], 'cached_key_4': [2, 96, 1, 192],
            'cached_val_0': [2, 192, 1, 96], 'cached_val_1': [2, 96, 1, 96], 'cached_val_2': [2, 48, 1, 96],
            'cached_val_3': [2, 24, 1, 96], 'cached_val_4': [2, 96, 1, 96],
            'cached_val2_0': [2, 192, 1, 96], 'cached_val2_1': [2, 96, 1, 96], 'cached_val2_2': [2, 48, 1, 96],
            'cached_val2_3': [2, 24, 1, 96], 'cached_val2_4': [2, 96, 1, 96],
            'cached_conv1_0': [2, 1, 256, 30], 'cached_conv1_1': [2, 1, 256, 30],
            'cached_conv1_2': [2, 1, 256, 30], 'cached_conv1_3': [2, 1, 256, 30], 'cached_conv1_4': [2, 1, 256, 30],
            'cached_conv2_0': [2, 1, 256, 30], 'cached_conv2_1': [2, 1, 256, 30],
            'cached_conv2_2': [2, 1, 256, 30], 'cached_conv2_3': [2, 1, 256, 30], 'cached_conv2_4': [2, 1, 256, 30]
        }
    def init_encoder_input(self):
        # Unlike the Flask version, the state dict is returned to the caller
        # so that each client connection keeps its own streaming state.
        logger.debug("Initializing encoder input")
        encoder_input_dict = {}
        for input_name in self.model_config:
            if 'cached_len' in input_name:
                data = np.zeros(self.model_config[input_name], dtype=np.int64)
            else:
                data = np.zeros(self.model_config[input_name], dtype=np.float32)
            encoder_input_dict[input_name] = data
        return encoder_input_dict

    def convert_nchw_to_nhwc(self, src):
        return np.transpose(src, (0, 2, 3, 1))

    def init_model(self, model_path, target, device_id):
        pass

    def release_model(self):
        pass

    def run_encoder(self, x, encoder_input_dict):
        pass

    def run_decoder(self, decoder_input):
        pass

    def run_joiner(self, encoder_out, decoder_out):
        pass
class OnnxModel(BaseModel):
    def __init__(self, encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
        super().__init__()
        logger.debug("Initializing ONNX model")
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        logger.debug(f"Loading ONNX model: {model_path}")
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
        try:
            model = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
            logger.debug(f"ONNX model {model_path} loaded successfully")
            return model
        except Exception as e:
            logger.error(f"Failed to load ONNX model: {model_path}, error: {str(e)}")
            raise

    def release_model(self):
        logger.debug("Releasing ONNX model")
        del self.encoder, self.decoder, self.joiner
        self.encoder = self.decoder = self.joiner = None

    def run_encoder(self, x, encoder_input_dict):
        encoder_input_dict['x'] = x.numpy()
        inputs = {}
        for input_name in self.model_config:
            inputs[input_name] = encoder_input_dict[input_name]
        logger.debug("Running ONNX encoder inference")
        try:
            out = self.encoder.run(None, inputs)
            new_encoder_input_dict = {}
            for idx, input_name in enumerate(self.model_config):
                if idx == 0:
                    encoder_output = out[0]
                else:
                    new_encoder_input_dict[input_name] = out[idx]
            return encoder_output, new_encoder_input_dict
        except Exception as e:
            logger.error(f"ONNX encoder run error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_decoder(self, decoder_input):
        logger.debug(f"Running ONNX decoder, input shape: {decoder_input.shape}")
        try:
            out = self.decoder.run(None, {self.decoder.get_inputs()[0].name: decoder_input})[0]
            return out
        except Exception as e:
            logger.error(f"ONNX decoder run error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_joiner(self, encoder_out, decoder_out):
        logger.debug(f"Running ONNX joiner, shapes: {encoder_out.shape}, {decoder_out.shape}")
        try:
            out = self.joiner.run(None, {
                self.joiner.get_inputs()[0].name: encoder_out,
                self.joiner.get_inputs()[1].name: decoder_out
            })[0]
            return out
        except Exception as e:
            logger.error(f"ONNX joiner run error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise
class RKNNModel(BaseModel):
    def __init__(self, encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
        super().__init__()
        logger.debug("Initializing RKNN model")
        self.encoder = self.init_model(encoder_model_path, target, device_id)
        self.decoder = self.init_model(decoder_model_path, target, device_id)
        self.joiner = self.init_model(joiner_model_path, target, device_id)

    def init_model(self, model_path, target, device_id):
        logger.debug(f"Loading RKNN model: {model_path}")
        if not os.path.exists(model_path):
            logger.error(f"Model file not found: {model_path}")
            raise FileNotFoundError(f"Model file not found: {model_path}")
        rknn = RKNN(verbose=False)
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            logger.error('Failed to load RKNN model!')
            exit(ret)
        ret = rknn.init_runtime(target=target, device_id=device_id)
        if ret != 0:
            logger.error('Failed to init runtime environment')
            exit(ret)
        logger.debug(f"RKNN model {model_path} loaded successfully")
        return rknn

    def release_model(self):
        logger.debug("Releasing RKNN model")
        self.encoder.release()
        self.decoder.release()
        self.joiner.release()
        self.encoder = self.decoder = self.joiner = None

    def run_encoder(self, x, encoder_input_dict):
        encoder_input_dict['x'] = x.numpy()
        input_list = []
        for input_name in self.model_config:
            input_list.append(encoder_input_dict[input_name])
        logger.debug("Running RKNN encoder inference")
        try:
            out = self.encoder.inference(inputs=input_list)
            new_encoder_input_dict = {}
            for idx, input_name in enumerate(self.model_config):
                if idx == 0:
                    encoder_output = out[0]
                else:
                    data = out[idx]
                    if idx > 10:
                        # RKNN returns the 4-D caches in NCHW; convert back to NHWC
                        data = self.convert_nchw_to_nhwc(data)
                    new_encoder_input_dict[input_name] = data
            return encoder_output, new_encoder_input_dict
        except Exception as e:
            logger.error(f"RKNN encoder inference error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_decoder(self, decoder_input):
        logger.debug(f"Running RKNN decoder, input shape: {decoder_input.shape}")
        try:
            out = self.decoder.inference(inputs=decoder_input)[0]
            return out
        except Exception as e:
            logger.error(f"RKNN decoder inference error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise

    def run_joiner(self, encoder_out, decoder_out):
        logger.debug(f"Running RKNN joiner, shapes: {encoder_out.shape}, {decoder_out.shape}")
        try:
            out = self.joiner.inference(inputs=[encoder_out, decoder_out])[0]
            return out
        except Exception as e:
            logger.error(f"RKNN joiner inference error: {str(e)}")
            logger.debug(traceback.format_exc())
            raise
def read_vocab(tokens_file):
    logger.debug(f"Reading vocabulary from {tokens_file}")
    vocab = {}
    with open(tokens_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(' ')
            if len(parts) < 2:
                key, value = parts[0], ''
            else:
                value, key = parts[0], parts[1]
            vocab[key] = value
    logger.debug(f"Loaded {len(vocab)} vocabulary entries")
    return vocab
def set_model(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id):
    logger.debug("Setting up model")
    encoder_ext = os.path.splitext(encoder_model_path)[1]
    decoder_ext = os.path.splitext(decoder_model_path)[1]
    joiner_ext = os.path.splitext(joiner_model_path)[1]
    if encoder_ext != decoder_ext or encoder_ext != joiner_ext:
        logger.error("All model files must share the same extension (e.g. .onnx or .rknn)")
        raise ValueError("All model files must share the same extension (e.g. .onnx or .rknn)")
    if encoder_ext == ".onnx":
        return OnnxModel(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id)
    elif encoder_ext == ".rknn":
        return RKNNModel(encoder_model_path, decoder_model_path, joiner_model_path, target, device_id)
    else:
        logger.error("Unsupported model format: must be .onnx or .rknn")
        raise ValueError("Unsupported model format: must be .onnx or .rknn")
def post_process(hyp, vocab, timestamp):
    logger.debug("Post-processing results")
    if not hyp or len(timestamp) != len(hyp):
        return "", []
    frame_shift_ms = 10
    subsampling_factor = 4
    frame_shift_s = frame_shift_ms / 1000.0 * subsampling_factor
    real_timestamp = [round(t * frame_shift_s, 2) for t in timestamp]
    # Punctuation insertion thresholds (seconds): a long enough pause between
    # tokens is treated as a comma or a full stop.
    comma_threshold = 0.4
    period_threshold = 0.8
    tokens_with_punct = []
    for i in range(len(hyp)):
        token_id = hyp[i]
        token = vocab.get(str(token_id), "")
        if not token:
            continue
        token = token.replace("▁", " ").strip()
        punctuation = ''
        if i < len(timestamp) - 1:
            current_time = real_timestamp[i]
            next_time = real_timestamp[i + 1]
            delta = next_time - current_time
            if delta >= period_threshold:
                punctuation = '。'
            elif delta >= comma_threshold:
                punctuation = ','
        tokens_with_punct.append(token + punctuation)
    # Make sure the last token ends with terminal punctuation
    if tokens_with_punct:
        last_token = tokens_with_punct[-1]
        if not last_token.endswith(('。', '.', '!', '!', '?', '?')):
            if last_token.endswith(','):
                tokens_with_punct[-1] = last_token[:-1] + '。'
            else:
                tokens_with_punct[-1] += '。'
    # Join and collapse runs of punctuation
    text = ''.join(tokens_with_punct)
    text = re.sub(r'([。,!?\.\?!])[。,!?\.\?!]*', r'\1', text)
    logger.debug(f"Generated text: {text}, timestamps: {real_timestamp}")
    return text, real_timestamp
model = None
vocab = None


def initialize_model():
    global model, vocab
    logger.info("Initializing model")
    # Hard-coded model paths and parameters
    args = type('', (), {})()
    args.encoder_model_path = "encoder-epoch-99-avg-1.onnx"
    args.decoder_model_path = "decoder-epoch-99-avg-1.onnx"
    args.joiner_model_path = "joiner-epoch-99-avg-1.onnx"
    args.target = "rk3588"
    args.device_id = None
    args.vocab_path = "vocab.txt"
    try:
        # Release any previously loaded model first
        if model is not None:
            model.release_model()
            logger.debug("Released existing model")
        for model_path in [args.encoder_model_path, args.decoder_model_path, args.joiner_model_path]:
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                raise FileNotFoundError(f"Model file not found: {model_path}")
        model = set_model(args.encoder_model_path, args.decoder_model_path,
                          args.joiner_model_path, args.target, args.device_id)
        vocab = read_vocab(args.vocab_path)
        logger.info("Model initialized successfully")
    except Exception as e:
        logger.error(f"Model initialization failed: {str(e)}")
        logger.debug(traceback.format_exc())
        raise SystemExit(1)
async def handle_client(websocket, path=None):
    global model, vocab
    if model is None or vocab is None:
        logger.error("Model not initialized")
        raise RuntimeError("Model not initialized")
    try:
        logger.info(f"Client connected, path: {path if path is not None else 'N/A'}")
        # Each client keeps its own streaming state
        encoder_input_dict = model.init_encoder_input()
        opts = kaldifeat.FbankOptions()
        opts.frame_opts.samp_freq = 16000
        opts.mel_opts.num_bins = 80
        opts.frame_opts.dither = 0
        opts.frame_opts.snip_edges = False
        fbank = kaldifeat.OnlineFbank(opts)
        context_size = 2
        blank_id = 0
        hyp = [blank_id] * context_size
        decoder_out = None
        num_processed_frames = 0
        timestamp = []
        frame_offset = 0
        segment = 103
        offset = 96
        while True:
            try:
                data = await websocket.recv()
                logger.debug(f"Received data size: {len(data) if isinstance(data, bytes) else 0}")
                if not isinstance(data, bytes):
                    continue
                audio_np = np.frombuffer(data, dtype=np.int16)
                logger.debug(f"Decoded {len(audio_np)} samples")
                audio_float = audio_np.astype(np.float32) / 32768.0
                audio_tensor = torch.from_numpy(audio_float)
                fbank.accept_waveform(sampling_rate=16000, waveform=audio_tensor)
                while True:
                    num_frames = fbank.num_frames_ready
                    if num_frames - num_processed_frames < segment:
                        break
                    frames = []
                    for i in range(segment):
                        frame = fbank.get_frame(num_processed_frames + i)
                        frames.append(frame)
                    frames = torch.cat(frames, dim=0).unsqueeze(0)
                    # Run the encoder and carry the updated cache state forward
                    encoder_out, encoder_input_dict = model.run_encoder(frames, encoder_input_dict)
                    encoder_out = encoder_out.squeeze(0)
                    blank_id = 0
                    unk_id = 2
                    if decoder_out is None:
                        decoder_input = np.array([hyp], dtype=np.int64)
                        decoder_out = model.run_decoder(decoder_input)
                    T = encoder_out.shape[0]
                    for t in range(T):
                        cur_encoder_out = encoder_out[t: t + 1]
                        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
                        y = np.argmax(joiner_out, axis=0)
                        if y != blank_id and y != unk_id:
                            timestamp.append(frame_offset + t)
                            hyp.append(y)
                            decoder_input = hyp[-context_size:]
                            decoder_input = np.array([decoder_input], dtype=np.int64)
                            decoder_out = model.run_decoder(decoder_input)
                    frame_offset += T
                    current_text, real_timestamp = post_process(hyp[context_size:], vocab, timestamp)
                    logger.info(f"Recognition result: {current_text}")
                    response = json.dumps({
                        "text": current_text,
                        "timestamps": real_timestamp
                    }).encode('utf-8')
                    await websocket.send(response)
                    # Reset per-segment state, keeping only the context tokens
                    hyp = hyp[:context_size]
                    timestamp = []
                    frame_offset = 0
                    decoder_out = None
                    num_processed_frames += offset
            except websockets.exceptions.ConnectionClosed:
                break
            except Exception as e:
                logger.error(f"Error occurred: {str(e)}")
                logger.debug(traceback.format_exc())
                try:
                    error_msg = json.dumps({"error": str(e)}).encode('utf-8')
                    await websocket.send(error_msg)
                except Exception:
                    pass
    finally:
        logger.info("Client handler exiting")
async def main():
    try:
        initialize_model()
        server = await websockets.serve(
            handle_client, "0.0.0.0", 5000,
            ping_interval=20,
            ping_timeout=None,
        )
        logger.info("WebSocket server started, listening on port 5000...")
        await server.wait_closed()
    except Exception as e:
        logger.error(f"Server error: {str(e)}")
        logger.debug(traceback.format_exc())
    finally:
        try:
            if model is not None:
                model.release_model()
                logger.info("Model released")
        except Exception:
            pass


if __name__ == "__main__":
    asyncio.run(main())
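To exercise this server without the Android app, here is a minimal test-client sketch (my own addition, not from the original post). It assumes 1.wav is already 16 kHz audio and that the server is reachable at ws://localhost:5000; it streams raw little-endian int16 PCM, which is what np.frombuffer(data, dtype=np.int16) expects on the server side.

# test_client.py - minimal streaming test client (sketch; file name,
# chunk size, and server URL are illustrative assumptions)
import asyncio
import json

import numpy as np
import soundfile as sf
import websockets


async def stream_file(path="1.wav", url="ws://localhost:5000", chunk_ms=200):
    audio, sr = sf.read(path, dtype="int16")  # int16 PCM, as the server expects
    if audio.ndim > 1:
        audio = audio.mean(axis=1).astype(np.int16)  # downmix to mono
    # The server assumes 16 kHz input; resample beforehand if sr differs.
    samples_per_chunk = int(sr * chunk_ms / 1000)
    async with websockets.connect(url) as ws:
        for start in range(0, len(audio), samples_per_chunk):
            await ws.send(audio[start:start + samples_per_chunk].tobytes())
            try:
                # Results only arrive once a full 103-frame segment is ready,
                # so poll briefly and keep streaming otherwise.
                reply = await asyncio.wait_for(ws.recv(), timeout=0.05)
                print(json.loads(reply))
            except asyncio.TimeoutError:
                pass
        # Drain any results still in flight after the last chunk.
        try:
            while True:
                reply = await asyncio.wait_for(ws.recv(), timeout=2.0)
                print(json.loads(reply))
        except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
            pass


if __name__ == "__main__":
    asyncio.run(stream_file())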
Here I reuse my existing yolo conda environment, which mainly provides the RKNN packages.
Converting ONNX to RKNN requires the official rknn_model_zoo tool: rknn_model_zoo-2.2.0.
Copy the environment package over to the RK3588 board and install it there, and the setup is done.
(If it is unclear how to set up the environment on the NPU board: just copy the requirements file I marked out to the board and install with
pip install -r <requirements file> -i https://pypi.tuna.tsinghua.edu.cn/simple and that's it.)
Tested to run correctly both on the PC and on the RK3588.
Running on the RK3588 board initially failed with a missing package:
Traceback (most recent call last):
  File "/home/orangepi/work_11.15/rknn_model_zoo/examples/zipformer/python/zipformer.py", line 4, in <module>
    import kaldifeat
ModuleNotFoundError: No module named 'kaldifeat'
Fix: install the kaldifeat package, as above. Official docs: "From pre-compiled wheels (Recommended)" — kaldifeat 1.25.5 documentation. Building from source:
git clone https://github.com/csukuangfj/kaldifeat
cd kaldifeat
python3 setup.py install
The build starts and eventually finishes.
Supported platforms:
RK3566, RK3568, RK3588, RK3562, RK3576
IV. Downloading the models (on the RK3588)
1. Download the model files
cd model
./download_model.sh
2. Model conversion
cd python
python convert.py ../model/encoder-epoch-99-avg-1.onnx rk3588
# output model will be saved as ../model/encoder-epoch-99-avg-1.rknn
python convert.py ../model/decoder-epoch-99-avg-1.onnx rk3588
# output model will be saved as ../model/decoder-epoch-99-avg-1.rknn
python convert.py ../model/joiner-epoch-99-avg-1.onnx rk3588
# output model will be saved as ../model/joiner-epoch-99-avg-1.rknn
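convert.py here is the script shipped with rknn_model_zoo. In outline it follows the standard RKNN-Toolkit2 flow; the sketch below is my reconstruction of that flow for one model, not the actual script:

# Sketch of an ONNX -> RKNN conversion with RKNN-Toolkit2 (assumed flow,
# not the real rknn_model_zoo convert.py).
from rknn.api import RKNN

rknn = RKNN(verbose=False)
rknn.config(target_platform='rk3588')          # build for the RK3588 NPU
if rknn.load_onnx(model='../model/encoder-epoch-99-avg-1.onnx') != 0:
    raise RuntimeError('load_onnx failed')
if rknn.build(do_quantization=False) != 0:     # keep float weights, no quantization
    raise RuntimeError('build failed')
if rknn.export_rknn('../model/encoder-epoch-99-avg-1.rknn') != 0:
    raise RuntimeError('export_rknn failed')
rknn.release()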
3. Run
kaldifeat package docs:
From pre-compiled wheels (Recommended) — kaldifeat 1.25.5 documentation
https://csukuangfj.github.io/kaldifeat/cpu.html
# Install kaldifeat
# Refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html for installation.
# This python demo is tested under version: kaldifeat-1.25.4.dev20240223
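For reference, the pre-compiled-wheel route from that page boils down to a single pip command along these lines (the exact version pin must match your torch build, so treat it as an assumption):
pip install kaldifeat==1.25.4.dev20240223 -f https://csukuangfj.github.io/kaldifeat/cpu.html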
On the RK3588 board:
cd python
python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.rknn --decoder_model_path decoder-epoch-99-avg-1.rknn --joiner_model_path joiner-epoch-99-avg-1.rknn --target rk3588
It runs to completion and prints the recognition output.
On the PC:
conda activate yuyin
python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.onnx --decoder_model_path decoder-epoch-99-avg-1.onnx --joiner_model_path joiner-epoch-99-avg-1.onnx
Successful run; the full invocation with an explicit sample rate:
python zipformer.py --encoder_model_path encoder-epoch-99-avg-1.onnx --decoder_model_path decoder-epoch-99-avg-1.onnx --joiner_model_path joiner-epoch-99-avg-1.onnx --sample_rate 16000