MeloTTS-ONNX原模型及量化模型推理

MeloTTS-ONNX模型文本到语音推理

ChatTTS

概述

ChatTTS 是一个基于 MeloTTS ONNX 实现的文本到语音 (TTS) 系统。它允许用户输入文本并实时生成语音音频,支持中英混合语言合成。该项目利用 ONNX 运行时进行高效推理,并包括音频播放的实用工具。

该项目基于针对 ONNX 优化的 MeloTTS 和 OpenVoice Tone Clone 实现。目前,它支持 zh-mix-en(中英混合)语言。

项目结构

./ChatTTS
├── configs
├── gits
│   └── nltk_data
│       ├── collections
│       ├── packages
│       │   ├── chunkers
│       │   ├── corpora
│       │   ├── grammars
│       │   ├── help
│       │   ├── misc
│       │   ├── models
│       │   ├── sentiment
│       │   ├── stemmers
│       │   ├── taggers
│       │   └── tokenizers
│       └── tools
├── model_classes
├── models
│   └── melotts
│       └── MeloTTS-ONNX
│           ├── build
│           │   ├── bdist.linux-x86_64
│           │   └── lib
│           │       └── melo_onnx
│           │           └── text
│           │               ├── english_utils
│           │               ├── es_phonemizer
│           │               └── fr_phonemizer
│           ├── dist
│           ├── melo_onnx
│           │   └── text
│           │       ├── english_utils
│           │       ├── es_phonemizer
│           │       └── fr_phonemizer
│           └── melotts_onnx.egg-info
├── model_utils
└── outputs

38 directories

脚本说明

export_model_info.py -m [onnx_model_path] -o [output_path]

import onnx
from onnx import helper
import argparse

# 数据类型映射表
DATA_TYPE_MAP = {
    1: "float32",  # onnx.TensorProto.FLOAT
    2: "uint8",    # onnx.TensorProto.UINT8
    3: "int8",     # onnx.TensorProto.INT8
    4: "uint16",   # onnx.TensorProto.UINT16
    5: "int16",    # onnx.TensorProto.INT16
    6: "int32",    # onnx.TensorProto.INT32
    7: "int64",    # onnx.TensorProto.INT64
    8: "string",   # onnx.TensorProto.STRING
    9: "bool",     # onnx.TensorProto.BOOL
    10: "float16", # onnx.TensorProto.FLOAT16
    11: "float64", # onnx.TensorProto.DOUBLE
    12: "uint32",  # onnx.TensorProto.UINT32
    13: "uint64",  # onnx.TensorProto.UINT64
    14: "complex64",  # onnx.TensorProto.COMPLEX64
    15: "complex128", # onnx.TensorProto.COMPLEX128
}

def export_onnx_info(model_path, output_file="model.info"):
    """
    将ONNX模型信息导出到文本文件
    
    参数:
        model_path (str): ONNX模型文件路径
        output_file (str): 输出信息文件路径 (默认: model.info)
    """
    # 加载ONNX模型
    model = onnx.load(model_path)
    
    # 打开文件准备写入
    with open(output_file, 'w', encoding='utf-8') as f:
        # 1. 写入模型基本信息
        f.write("=" * 60 + "\n")
        f.write(f"ONNX模型基本信息\n")
        f.write("=" * 60 + "\n")
        f.write(f"模型文件路径: {model_path}\n")
        f.write(f"ONNX版本: {model.ir_version}\n")
        f.write(f"生产者信息: {model.producer_name} {model.producer_version}\n")
        f.write(f"模型版本: {model.model_version}\n")
        f.write(f"描述: {model.doc_string}\n\n")
        
        # 2. 写入模型输入信息
        f.write("=" * 60 + "\n")
        f.write(f"模型输入信息 (共 {len(model.graph.input)} 个输入)\n")
        f.write("=" * 60 + "\n")
        for i, input in enumerate(model.graph.input):
            f.write(f"Input {i+1}: {input.name}\n")
            f.write(f"  数据类型: {input.type.tensor_type.elem_type}\n")
            shape = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
            f.write(f"  形状: {shape}\n")
            if input.doc_string:
                f.write(f"  描述: {input.doc_string}\n")
            f.write("\n")
        
        # 3. 写入模型输出信息
        f.write("=" * 60 + "\n")
        f.write(f"模型输出信息 (共 {len(model.graph.output)} 个输出)\n")
        f.write("=" * 60 + "\n")
        for i, output in enumerate(model.graph.output):
            f.write(f"Output {i+1}: {output.name}\n")
            f.write(f"  数据类型: {output.type.tensor_type.elem_type}\n")
            shape = [dim.dim_value for dim in output.type.tensor_type.shape.dim]
            f.write(f"  形状: {shape}\n")
            if output.doc_string:
                f.write(f"  描述: {output.doc_string}\n")
            f.write("\n")
        
        # 4. 写入模型权重信息
        f.write("=" * 60 + "\n")
        f.write(f"模型权重信息 (共 {len(model.graph.initializer)} 个参数)\n")
        f.write("=" * 60 + "\n")
        total_params = 0
        for initializer in model.graph.initializer:
            dims = initializer.dims
            param_count = 1
            for dim in dims:
                param_count *= dim
            total_params += param_count
            
            # 方法1:直接使用映射表获取类型名
            if initializer.data_type in DATA_TYPE_MAP:
                dtype = DATA_TYPE_MAP[initializer.data_type]
            else:
                try:
                    # 方法2:尝试使用helper函数
                    np_dtype = helper.tensor_dtype_to_np_dtype(initializer.data_type)
                    dtype = np_dtype.name if hasattr(np_dtype, 'name') else str(np_dtype)
                except:
                    # 方法3:完全手动
                    dtype = f"未知类型({initializer.data_type})"
            
            f.write(f"参数名: {initializer.name}\n")
            f.write(f"  数据类型: {dtype}\n")
            f.write(f"  形状: {dims}\n")
            f.write(f"  参数数量: {param_count}\n")
            f.write("\n")
            
        f.write(f"模型总参数量: {total_params}\n\n")
        
        # 5. 写入运算符信息
        f.write("=" * 60 + "\n")
        f.write(f"模型运算符信息 (共 {len(model.graph.node)} 个运算符)\n")
        f.write("=" * 60 + "\n")
        op_counts = {}
        for node in model.graph.node:
            op_type = node.op_type
            op_counts[op_type] = op_counts.get(op_type, 0) + 1
        
        # 按运算符使用频率排序
        sorted_ops = sorted(op_counts.items(), key=lambda x: x[1], reverse=True)
        
        f.write(f"{'运算符类型':<15} | 使用次数\n")
        f.write("-" * 30 + "\n")
        for op_type, count in sorted_ops:
            f.write(f"{op_type:<15} | {count}\n")
        
        f.write("\n详细层信息:\n")
        f.write("-" * 60 + "\n")
        for i, node in enumerate(model.graph.node):
            f.write(f"层 {i+1}: {node.name} ({node.op_type})\n")
            f.write(f"  输入: {', '.join(node.input)}\n")
            f.write(f"  输出: {', '.join(node.output)}\n")
            f.write(f"  属性:\n")
            for attr in node.attribute:
                if attr.name in ['kernel_shape', 'strides', 'pads']:
                    value = list(attr.ints)
                elif attr.type == onnx.AttributeProto.FLOAT:
                    value = attr.f
                else:
                    value = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s
                f.write(f"    - {attr.name}: {value}\n")
            f.write("\n")
        
        # 6. 写入元数据信息
        if model.metadata_props:
            f.write("\n" + "=" * 60 + "\n")
            f.write("模型元数据信息\n")
            f.write("=" * 60 + "\n")
            for key, value in model.metadata_props.items():
                f.write(f"{key}: {value}\n")
        
        # 7. 可选的额外信息
        f.write("\n" + "=" * 60 + "\n")
        f.write("额外信息\n")
        f.write("=" * 60 + "\n")
        f.write(f"导入版本: {', '.join([f'{opset.domain}-{opset.version}' for opset in model.opset_import])}\n")
        f.write(f"模型计算图名称: {model.graph.name}\n")
    
    print(f"模型信息已成功导出到 {output_file}")

# 使用示例
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Export ONNX model information')
    parser.add_argument('--model_path', '-m', type=str, required=True, help='Path to the ONNX model file')
    parser.add_argument('--output_file', '-o', type=str, default='model.info', help='Path to the output file')
    args = parser.parse_args()
    
    # onnx_model_path = "./models/TTS/outetts_onnx/model.onnx"  # 替换为你的ONNX模型路径
    
    onnx_model_path = args.model_path
    output_file = args.output_file
    export_onnx_info(onnx_model_path, output_file)

量化说明

bert_lml_model_ori.onnx量化
模型名称:bert_lml_model_w4a16.onnx
模型导出路径:ChatTTS/outputs/bert_lml_model_w4a16.onnx
量化类型:w4a16
模型大小缩减:583MB -> 196MB
CPU占用率:24% -> 21%

功能

  • 实时文本到语音合成
  • 使用 sounddevice 和 librosa 进行音频重采样和播放
  • 简单的命令行界面用于输入文本
  • 利用预训练的 ONNX 模型进行快速推理

要求

该项目需要 Python 3.x 以及 req.info 中列出的依赖项。该文件包含已安装包的 pip freeze 输出。

关键依赖项包括:

  • melo_onnx(从 models/melotts/MeloTTS-ONNX 的可编辑安装)
  • sounddevice
  • numpy
  • librosa
  • nltk(用于 POS 标记,如 run.py 中所示)

安装

  1. 克隆仓库:

    git clone https://gitee.com/jackroing/ai-tts.git
    
  2. 安装依赖项:

    req.info 文件列出了所有必需的包。您可以使用 pip 安装它们,但请注意它包括 melo_onnx 的可编辑安装。

    pip install -r req.info
    

    如果可编辑路径有问题,请手动安装 melo_onnx 包:

    cd models/melotts/MeloTTS-ONNX
    pip install -e .
    cd ../../..
    

    确保从 req.info 安装所有其他依赖项。

  3. 设置 NLTK 数据:

    NLTK 需要特定的数据文件来运行。将项目 NLTK 数据目录的内容复制到系统 NLTK 数据路径:

    sudo cp -r gits/nltk_data/packages/* /root/nltk_data/
    

    注意: 此步骤对于 NLTK 功能至关重要,例如项目中使用的 POS 标记。目标目录 /root/nltk_data/ 假设具有 root 访问权限;如果您的 NLTK 数据路径不同(通常为 ~/nltk_data/),请调整。

  4. 下载模型(如果必要):

    模型位于 models/melotts/ 中。如果不存在,请从以下位置下载:

    • TTS 模型:ModelScopeHugging Face
    • 确保 tts_model.onnxbert_lml_model.onnx 等文件位于 models/melotts/ 中。

使用

运行主脚本以启动 TTS 接口:

python run.py
  • 脚本将查询音频设备并初始化 TTS 模型。
  • 在提示符处输入文本(例如,“Hello, how are you?”)。
  • 系统将合成音频,将其重采样到 48kHz,并播放。
  • 输入“exit”以退出。

示例

请输入: Hello, this is a test.
[音频播放]
请输入: exit

注意事项

  • 音频设备: 脚本硬编码输出设备=8。请根据系统音频设备(使用 get_devinfo() 输出)在 run.py 中调整。
  • 语言支持: 目前针对中英混合进行了优化。为获得最佳结果,输入文本应为小写。
  • 性能: 默认使用 CPUExecutionProvider。如果可用 GPU,请修改为“CUDAExecutionProvider”。
  • NLTK 检查: 脚本在启动时包含 POS 标记测试以验证 NLTK 安装。
  • 确保您具有必要的音频硬件用于播放。

输出结果

在这里插入图片描述

贡献

欢迎贡献!请打开 issue 或 pull request 以进行改进。

许可证

该项目根据 MIT 许可证授权(基于底层 MeloTTS 的假设;如有需要,请调整)。

有关底层 MeloTTS ONNX 实现的更多详细信息,请参阅 models/melotts/README_cn.md

联系方式

  • 公众号:“CrazyNET”
  • 仓库地址:https://gitee.com/jackroing/ai-tts.git
  • 项目获取:公众号回复 “melotts”
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值