// Silero VAD (ONNX Runtime) + EMA 平滑 + 健壮错误处理
// I/O 形状:
// input : float32 [1, T]
// state : float32 [2, 1, 128] // 注意 rank=3,且第0维是 2
// sr : int64 scalar (0维)
// output : float32 [1]
// stateN : float32 [2, 1, 128]
#include "silero_vad.h"

#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <onnxruntime_c_api.h>
// 调试打印(编译时加 -DVAD_DEBUG 开启)
// #define VAD_DEBUG
// ---- Silero RNN state dimensions ([2,1,128]) ----
// STATE_LAYERS x STATE_BATCH x STATE_HSIZE is the shape of the recurrent state
// tensor; STATE_ELEMS is the total float count and sizes the state_in /
// state_out buffers in Impl.
#define STATE_LAYERS 2
#define STATE_BATCH 1
#define STATE_HSIZE 128
#define STATE_ELEMS (STATE_LAYERS*STATE_BATCH*STATE_HSIZE)
// ---- ORT error-return macro ----
// Evaluates `call` once (an expression yielding OrtStatus*). If the status is
// non-NULL, prints its message to stderr with __FILE__/__LINE__, releases the
// status through `api`, and returns -2 from the enclosing function.
// Note: only usable inside functions that return int (e.g. vad_run_frame).
// It returns immediately, so it performs no cleanup of resources the caller
// has already acquired.
#define ORT_RETURN_ON_ERROR(api, call) do { \
OrtStatus* _st = (call); \
if (_st) { \
const char* _msg = (api)->GetErrorMessage(_st); \
fprintf(stderr, "ORT ERROR %s:%d: %s\n", \
__FILE__, __LINE__, _msg ? _msg : "(null)"); \
(api)->ReleaseStatus(_st); \
return -2; \
} \
} while(0)
// Variant for functions that return a pointer (e.g. vad_create) or that need
// unified cleanup: evaluates `call` once; on a non-NULL OrtStatus* it prints
// the message to stderr with __FILE__/__LINE__, releases the status through
// `api`, and jumps to `label` instead of returning.
#define ORT_GOTO_ON_ERROR(api, call, label) do { \
OrtStatus* _st = (call); \
if (_st) { \
const char* _msg = (api)->GetErrorMessage(_st); \
fprintf(stderr, "ORT ERROR %s:%d: %s\n", \
__FILE__, __LINE__, _msg ? _msg : "(null)"); \
(api)->ReleaseStatus(_st); \
goto label; \
} \
} while(0)
// Private per-instance state of the VAD: ONNX Runtime objects, the model's
// recurrent state buffers, and the smoothing / hysteresis bookkeeping.
typedef struct {
// ONNX Runtime objects (created in vad_create, released in vad_destroy)
const OrtApi* api;
OrtEnv* env;
OrtSessionOptions* opts;
OrtSession* sess;
// Default allocator fetched in vad_create; never released by this file
OrtAllocator* allocator;
// CPU memory info used to wrap caller-owned buffers as tensors
OrtMemoryInfo* mem;
// I/O tensor names
const char* name_input; // "input"
const char* name_state; // "state"
const char* name_sr; // "sr"
const char* name_output; // "output"
const char* name_stateN; // "stateN"
// RNN hidden-state buffers ([2,1,128]): state_in is fed to the model,
// state_out receives the model's new state before it is copied back
float state_in [STATE_ELEMS];
float state_out[STATE_ELEMS];
int64_t sr_scalar; // 16000
// Hysteresis / debounce: on/off thresholds, required consecutive frame
// counts, running counters, and the current debounced decision
float th_on, th_off;
int min_on, min_off;
int cnt_on, cnt_off;
int speech_state;
// ===== Multi-frame smoothing (EMA) =====
int smooth_on; // 0: off; 1: on
float ema_alpha; // a in [0,1); larger means smoother
float prob_ema; // EMA accumulator state
} Impl;
// Definition of the opaque handle: it simply wraps the private Impl state.
// NOTE(review): the `VADHandle` typedef itself is presumably declared in
// silero_vad.h — confirm.
struct VADHandle { Impl v; };
// Create a Silero VAD instance backed by an ONNX Runtime session.
//
// Parameters:
//   model_path - path to the Silero VAD .onnx model; must not be NULL.
//
// Returns a heap-allocated handle that the caller must release with
// vad_destroy(), or NULL on any failure (NULL path, out of memory,
// unsupported ORT API version, or any ONNX Runtime error; ORT errors are
// printed to stderr).
//
// Defaults after creation: sample rate 16000, th_on=0.20, th_off=0.10,
// min_on=3, min_off=5, EMA smoothing disabled with alpha=0.90.
//
// NOTE(review): model_path is passed straight to CreateSession(); on Windows
// the ORT C API expects a wide-character path there — confirm the target
// platform or add a conversion.
VADHandle* vad_create(const char* model_path) {
    if (!model_path) return NULL;

    VADHandle* h = (VADHandle*)calloc(1, sizeof(VADHandle));
    if (!h) return NULL;
    Impl* v = &h->v;

    // GetApi() returns NULL when the loaded runtime library does not support
    // the API version this file was compiled against. Every later call goes
    // through v->api (including the cleanup path), so bail out right here.
    v->api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    if (!v->api) {
        fprintf(stderr, "ORT ERROR %s:%d: %s\n", __FILE__, __LINE__,
                "requested ORT API version is not supported by the runtime");
        free(h);
        return NULL;
    }

    ORT_GOTO_ON_ERROR(v->api, v->api->CreateEnv(
        ORT_LOGGING_LEVEL_WARNING, "silero_vad", &v->env), fail);
    ORT_GOTO_ON_ERROR(v->api, v->api->CreateSessionOptions(&v->opts), fail);

    // The Set* calls also return OrtStatus* and must be checked.
    ORT_GOTO_ON_ERROR(v->api, v->api->SetIntraOpNumThreads(v->opts, 1), fail);
    ORT_GOTO_ON_ERROR(v->api, v->api->SetInterOpNumThreads(v->opts, 1), fail);
    ORT_GOTO_ON_ERROR(v->api, v->api->SetSessionGraphOptimizationLevel(
        v->opts, ORT_ENABLE_EXTENDED), fail);

    ORT_GOTO_ON_ERROR(v->api, v->api->CreateSession(
        v->env, model_path, v->opts, &v->sess), fail);
    ORT_GOTO_ON_ERROR(v->api, v->api->GetAllocatorWithDefaultOptions(&v->allocator), fail);
    ORT_GOTO_ON_ERROR(v->api, v->api->CreateCpuMemoryInfo(
        OrtArenaAllocator, OrtMemTypeDefault, &v->mem), fail);

    // Tensor names expected by the model.
    v->name_input  = "input";
    v->name_state  = "state";
    v->name_sr     = "sr";
    v->name_output = "output";
    v->name_stateN = "stateN";

    // Recurrent state starts at zero.
    memset(v->state_in,  0, sizeof(v->state_in));
    memset(v->state_out, 0, sizeof(v->state_out));
    v->sr_scalar = 16000;

    // Hysteresis defaults.
    v->th_on = 0.20f; v->th_off = 0.10f;
    v->min_on = 3;    v->min_off = 5;
    v->cnt_on = v->cnt_off = 0;
    v->speech_state = 0;

    // EMA smoothing defaults: disabled, with a commonly used alpha.
    v->smooth_on = 0;
    v->ema_alpha = 0.90f;
    v->prob_ema  = 0.0f;

#ifdef VAD_DEBUG
    {
        size_t nin = 0, nout = 0;
        ORT_GOTO_ON_ERROR(v->api, v->api->SessionGetInputCount(v->sess, &nin), fail);
        ORT_GOTO_ON_ERROR(v->api, v->api->SessionGetOutputCount(v->sess, &nout), fail);
        fprintf(stderr, "[VAD] inputs=%zu, outputs=%zu\n", nin, nout);
        // I/O names are intentionally not printed here.
    }
#endif
    return h;

fail:
    // Release whatever was created, in reverse order of creation.
    if (v->mem)  { v->api->ReleaseMemoryInfo(v->mem);      v->mem  = NULL; }
    if (v->sess) { v->api->ReleaseSession(v->sess);        v->sess = NULL; }
    if (v->opts) { v->api->ReleaseSessionOptions(v->opts); v->opts = NULL; }
    if (v->env)  { v->api->ReleaseEnv(v->env);             v->env  = NULL; }
    free(h);
    return NULL;
}
// Configure the hysteresis (debounce) parameters of the detector.
//
//   th_on / th_off   - probability thresholds stored as given.
//   min_on / min_off - required consecutive frame counts; any value that is
//                      not positive is replaced by 1.
//
// A NULL handle is ignored.
void vad_set_hysteresis(VADHandle* h, float th_on, float th_off, int min_on, int min_off) {
    if (h == NULL) {
        return;
    }

    int on_frames = min_on;
    if (on_frames <= 0) {
        on_frames = 1;
    }

    int off_frames = min_off;
    if (off_frames <= 0) {
        off_frames = 1;
    }

    Impl* state = &h->v;
    state->th_on   = th_on;
    state->th_off  = th_off;
    state->min_on  = on_frames;
    state->min_off = off_frames;
}
// Enable or disable EMA smoothing of the speech probability and set its
// coefficient.
//
//   enable - any non-zero value turns smoothing on, zero turns it off.
//   alpha  - weight of the previous EMA value in
//            y_t = alpha*y_{t-1} + (1-alpha)*x_t. It is clamped into [0, 1):
//            negative values and NaN become 0, values >= 1 become 0.9999.
//
// A NULL handle is ignored. The EMA accumulator itself is not touched here.
void vad_set_smoothing(VADHandle* h, int enable, float alpha) {
    if (!h) return;
    Impl* v = &h->v;
    v->smooth_on = (enable ? 1 : 0);

    // Written as a negated ">=" so that NaN also lands in this branch; a NaN
    // coefficient would otherwise turn prob_ema into NaN permanently and
    // freeze the hysteresis (all comparisons against NaN are false).
    if (!(alpha >= 0.0f)) alpha = 0.0f;
    // alpha == 1 would make the EMA ignore every new sample.
    if (alpha >= 1.0f) alpha = 0.9999f;

    v->ema_alpha = alpha;
}
// Run one audio frame through the Silero VAD model and update the detector.
//
// Parameters:
//   h           - handle returned by vad_create().
//   pcm         - n_samp float32 samples; wrapped (not copied) as the
//                 [1, n_samp] "input" tensor for the duration of the call.
//   n_samp      - number of samples in pcm; must be > 0.
//   prob_out    - optional; receives the probability used for the decision
//                 (EMA-smoothed when smoothing is enabled, raw otherwise).
//   speech_flag - optional; receives the debounced speech state (0 or 1).
//
// Returns 0 on success, -1 on invalid arguments, -2 on any ONNX Runtime error
// (the error is printed to stderr). On failure the recurrent state, the EMA
// accumulator and the hysteresis counters are left unchanged.
//
// Hysteresis, as implemented at the end of this function:
//   * prob >= th_on   -> cnt_on++  and cnt_off is reset to 0
//   * prob <= th_off  -> cnt_off++ and cnt_on  is reset to 0
//   * th_off < prob < th_on (dead zone) -> both counters keep their values
//   * silence -> speech once cnt_on  >= min_on
//   * speech  -> silence once cnt_off >= min_off
int vad_run_frame(VADHandle* h, const float* pcm, size_t n_samp,
                  float* prob_out, int* speech_flag) {
    if (!h || !pcm || n_samp == 0) return -1;
    // Reject lengths whose byte size would overflow size_t.
    if (n_samp > SIZE_MAX / sizeof(float)) return -1;

    Impl* v = &h->v;
    int rc = -2;  // every failure past this point is an ORT error

    // Everything the cleanup section releases is declared (and NULL) up front
    // so that any early exit can safely jump to `cleanup`.
    OrtValue* t_input = NULL;
    OrtValue* t_state = NULL;
    OrtValue* t_sr    = NULL;
    OrtValue* outs[2] = { NULL, NULL };
    float* prob_ptr = NULL;
    float* stn_ptr  = NULL;

    int64_t d_input[2] = {1, (int64_t)n_samp};
    int64_t d_state[3] = {STATE_LAYERS, STATE_BATCH, STATE_HSIZE};

#ifdef VAD_DEBUG
    fprintf(stderr, "Silero VAD run: input=[1,%zu], state=[%d,%d,%d], sr=%lld\n",
            n_samp, STATE_LAYERS, STATE_BATCH, STATE_HSIZE, (long long)v->sr_scalar);
#endif

    // input: float32 [1, T]. The API takes a non-const pointer, hence the cast.
    ORT_GOTO_ON_ERROR(v->api,
        v->api->CreateTensorWithDataAsOrtValue(
            v->mem, (void*)pcm, sizeof(float) * n_samp,
            d_input, 2, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &t_input), cleanup);

    // state: float32 [2,1,128] (rank 3)
    ORT_GOTO_ON_ERROR(v->api,
        v->api->CreateTensorWithDataAsOrtValue(
            v->mem, (void*)v->state_in, sizeof(v->state_in),
            d_state, 3, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &t_state), cleanup);

    // sr: int64 scalar (rank 0, shape = NULL)
    ORT_GOTO_ON_ERROR(v->api,
        v->api->CreateTensorWithDataAsOrtValue(
            v->mem, (void*)&v->sr_scalar, sizeof(int64_t),
            NULL, 0, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, &t_sr), cleanup);

    {
        const char* in_names[]  = { v->name_input, v->name_state, v->name_sr };
        const OrtValue* ins[]   = { t_input, t_state, t_sr };
        const char* out_names[] = { v->name_output, v->name_stateN };

        OrtStatus* st = v->api->Run(v->sess, NULL, in_names, ins, 3, out_names, 2, outs);
        if (st) {
            const char* msg = v->api->GetErrorMessage(st);
            fprintf(stderr, "ORT Run ERROR: %s\n", msg ? msg : "(null)");
            v->api->ReleaseStatus(st);
            goto cleanup;
        }
    }

    // Fetch both outputs before committing anything, so a failure here cannot
    // leave the detector half-updated.
    // output: float32[1], stateN: float32[2,1,128]
    ORT_GOTO_ON_ERROR(v->api,
        v->api->GetTensorMutableData(outs[0], (void**)&prob_ptr), cleanup);
    ORT_GOTO_ON_ERROR(v->api,
        v->api->GetTensorMutableData(outs[1], (void**)&stn_ptr), cleanup);

    {
        float prob = prob_ptr[0];

        // Carry the new recurrent state over to the next frame.
        memcpy(v->state_out, stn_ptr, sizeof(v->state_out));
        memcpy(v->state_in, v->state_out, sizeof(v->state_in));

        // ===== Multi-frame smoothing (EMA) =====
        float prob_used = prob;
        if (v->smooth_on) {
            // y_t = a*y_{t-1} + (1-a)*x_t
            v->prob_ema = v->ema_alpha * v->prob_ema + (1.0f - v->ema_alpha) * prob;
            prob_used = v->prob_ema;
        }
        if (prob_out) *prob_out = prob_used;

        // Hysteresis / debounce on the (possibly smoothed) probability.
        // Counters saturate at INT_MAX: on an always-on stream they would
        // otherwise overflow a signed int, which is undefined behavior.
        if (prob_used >= v->th_on) {
            if (v->cnt_on < INT_MAX) v->cnt_on++;
            v->cnt_off = 0;
        } else if (prob_used <= v->th_off) {
            if (v->cnt_off < INT_MAX) v->cnt_off++;
            v->cnt_on = 0;
        }
        if (!v->speech_state && v->cnt_on  >= v->min_on ) v->speech_state = 1;
        if ( v->speech_state && v->cnt_off >= v->min_off) v->speech_state = 0;
        if (speech_flag) *speech_flag = v->speech_state;
    }

    rc = 0;

cleanup:
    // Single exit point: release every tensor that was actually created.
    if (outs[0]) v->api->ReleaseValue(outs[0]);
    if (outs[1]) v->api->ReleaseValue(outs[1]);
    if (t_sr)    v->api->ReleaseValue(t_sr);
    if (t_state) v->api->ReleaseValue(t_state);
    if (t_input) v->api->ReleaseValue(t_input);
    return rc;
}
// Return the detector to its freshly created streaming state: the debounced
// decision and both hysteresis counters go back to 0, the EMA accumulator is
// cleared, and both recurrent-state buffers are zeroed. Thresholds, frame
// counts and the smoothing configuration are kept. A NULL handle is ignored.
void vad_reset(VADHandle* h) {
    if (h == NULL) {
        return;
    }

    Impl* state = &h->v;

    // Decision logic
    state->speech_state = 0;
    state->cnt_off = 0;
    state->cnt_on  = 0;

    // Smoothing accumulator
    state->prob_ema = 0.0f;

    // Model recurrent state
    memset(state->state_out, 0, sizeof state->state_out);
    memset(state->state_in,  0, sizeof state->state_in);
}
// Release the ONNX Runtime objects owned by the handle (memory info, session,
// session options, environment — in that order, skipping any that are NULL)
// and then free the handle itself. A NULL handle is ignored.
void vad_destroy(VADHandle* h) {
    if (h != NULL) {
        Impl* state = &h->v;
        const OrtApi* ort = state->api;

        if (state->mem != NULL) {
            ort->ReleaseMemoryInfo(state->mem);
        }
        if (state->sess != NULL) {
            ort->ReleaseSession(state->sess);
        }
        if (state->opts != NULL) {
            ort->ReleaseSessionOptions(state->opts);
        }
        if (state->env != NULL) {
            ort->ReleaseEnv(state->env);
        }

        free(h);
    }
}
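// 问题:我这里面的迟滞是怎么写的?
// (Question: how is the hysteresis implemented in this file? It is the
// th_on/th_off threshold and cnt_on/cnt_off counter logic at the end of
// vad_run_frame.)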