gmm对齐代码

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "feat/wave-reader.h"
#include "feat/feature-functions.h"
#include "feat/feature-mfcc.h"
#include "feat/feature-window.h"
#include "feat/feature-spectrogram.h"
#include "feat/feature-plp.h"
#include "feat/feature-fbank.h"
#include "feat/feature-bottleneck.h"
#include "transform/fmllr-diag-gmm.h"
#include "transform/fmllr-diag-gmm-test.h"
#include "hmm/posterior.h"
#include "gmm/diag-gmm.h"
#include "gmm/decodable-am-diag-gmm.h"
#include "fst"

// Per-utterance GMM alignment driver sketch built on Kaldi headers.
//
// Intended flow (per the numbered step comments below): read two models and a
// lexicon FST, then for every utterance compute MFCC features, decode/align
// them, estimate and apply a feature transform, re-align, and write the
// resulting label sequence to "final.ali".
//
// Command line: gmm-align <config-file> <wav-file> <text-file>. Any other
// positional-argument count prints usage and exits with status 1.
// Returns 0 after the utterance loop, or -1 if a std::exception escapes.
//
// NOTE(review): this block cannot compile as written. Several names used below
// are never declared in the visible source (am_gmm, fmllr_est,
// gmm_align_compiled, graph, ComposeStates, the Config type). Others are used
// after the scope that declares them has closed (lat, feats_transformed).
// Treat it as pseudocode for compile-train-graphs -> gmm-align-compiled ->
// fMLLR -> transform-feats -> gmm-align-compiled, not as a working tool.
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    typedef kaldi::int32 int32;
    typedef kaldi::int64 int64;

    const char *usage = "Kaldi GMM alignment example.\n"
                        "Usage: gmm-align <config-file> <wav-file> <text-file>\n"
                        " e.g.: gmm-align config.txt 123.wav 123.text\n";

    ParseOptions po(usage);
    // Despite its help text, this flag only selects the read options for
    // "l.fst" below; nothing in this block writes binary output.
    bool binary = true;
    // NOTE(review): never read; MFCC features are computed unconditionally.
    std::string feature_type = "mfcc";
    std::string config_file, wav_scp, text_file;
    po.Register("binary", &binary, "Write output in binary mode");
    po.Read(argc, argv);

    if (po.NumArgs() != 3) {
      po.PrintUsage();
      exit(1);
    }

    config_file = po.GetArg(1);
    // NOTE(review): this argument is handed to a Kaldi table reader below. It
    // presumably needs to be an rspecifier such as "scp:wav.scp", not the bare
    // "123.wav" shown in the usage string — verify.
    wav_scp = po.GetArg(2);
    text_file = po.GetArg(3);

    // Read configuration file.
    // NOTE(review): `Config` (ReadConfigFile / GetString) is not declared by any
    // header included here; confirm which library is meant to provide it.
    Config config;
    config.ReadConfigFile(config_file);

    // Read feature configuration. Values stay strings and are parsed or
    // compared at their point of use.
    // NOTE(review): sample_frequency and length_tolerance are never used.
    std::string use_energy = config.GetString("feat.use-energy");
    std::string sample_frequency = config.GetString("feat.sample-frequency");
    std::string frame_shift = config.GetString("feat.frame-shift");
    std::string length_tolerance = config.GetString("feat.length-tolerance");

    // Read model configuration.
    // NOTE(review): tree_rxfilename, read_disambig_syms, phone_symbol_table and
    // word_symbol_table are never used below.
    std::string tree_rxfilename = config.GetString("model.tree-rxfilename");
    std::string model_rxfilename = config.GetString("model.model-rxfilename");
    std::string alignment_model_rxfilename = config.GetString("model.alignment-model-rxfilename");
    std::string lex_rxfilename = config.GetString("model.lex-rxfilename");
    std::string read_disambig_syms = config.GetString("model.read-disambig-syms");
    std::string phone_symbol_table = config.GetString("model.phone-symbol-table");
    std::string word_symbol_table = config.GetString("model.word-symbol-table");

    // Read alignment configuration. These strings are converted with
    // std::stof / compared against "true" inside the utterance loop, i.e.
    // once per utterance.
    std::string transition_scale = config.GetString("align.transition-scale");
    std::string acoustic_scale = config.GetString("align.acoustic-scale");
    std::string self_loop_scale = config.GetString("align.self-loop-scale");
    std::string beam = config.GetString("align.beam");
    std::string retry_beam = config.GetString("align.retry-beam");
    std::string write_lengths = config.GetString("align.write-lengths");
    std::string custom_output = config.GetString("align.custom-output");

    // Load the main model: a TransitionModel followed by a DiagGmm from the same
    // stream. The inner `binary` shadows the option and is set by Input.
    // NOTE(review): Kaldi model files normally store a TransitionModel followed
    // by an AmDiagGmm (one GMM per pdf). The decodables below also reference an
    // undeclared `am_gmm` rather than `gmm`, and `gmm` is not used after this
    // read.
    DiagGmm gmm;
    TransitionModel trans_model;
    {
      bool binary;
      Input ki(model_rxfilename, &binary);
      trans_model.Read(ki.Stream(), binary);
      gmm.Read(ki.Stream(), binary);
    }

    // Same layout, read from the separate alignment model.
    DiagGmm align_gmm;
    TransitionModel align_trans_model;
    {
      bool binary;
      Input ki(alignment_model_rxfilename, &binary);
      align_trans_model.Read(ki.Stream(), binary);
      align_gmm.Read(ki.Stream(), binary);
    }

    // Lexicon FST named by the config.
    // NOTE(review): this raw pointer is never deleted (presumably heap-allocated
    // by ReadFstKaldi); holding it in a std::unique_ptr would release it.
    fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(lex_rxfilename);
    // Second FST, loaded from the hard-coded path "l.fst" in the working
    // directory, independent of lex_rxfilename. A failed open is not checked.
    // NOTE(review): in OpenFst, VectorFst<Arc>::Read(istream, opts) is a static
    // factory that returns a new FST (it does not fill `l_fst`).
    // FstReadOptions' string argument is a source name, not an option string.
    // Verify both.
    fst::VectorFst<fst::StdArc> l_fst;
    {
      std::ifstream is("l.fst", std::ios_base::in | std::ios_base::binary);
      l_fst.Read(is, binary ? fst::FstReadOptions() : fst::FstReadOptions("binary=false"));
    }
    
    // Read audio.
    // NOTE(review): the loop below uses Done()/Next()/Key(). In Kaldi those
    // belong to SequentialTableReader; RandomAccessTableReader offers keyed
    // lookup (HasKey / Value(key)). Verify, and most likely switch reader type.
    RandomAccessTableReader<WaveHolder> wav_reader(wav_scp);
    
    // Read transcripts, one per line.
    // NOTE(review): `texts` is never consulted afterwards, so the transcripts do
    // not influence alignment. A failed open silently yields an empty list.
    std::ifstream text_input(text_file);
    std::vector<std::string> texts;
    std::string line;
    while (std::getline(text_input, line)) {
      texts.push_back(line);
    }
    
    // Process each utterance. A failure at any step logs a warning and
    // `continue`s to the next utterance.
    for (; !wav_reader.Done(); wav_reader.Next()) {
      std::string utt = wav_reader.Key();
      const WaveData &wave_data = wav_reader.Value();
      
      // Compute MFCC features.
      // Only use_energy and frame_shift come from the config. Frame length
      // (25 ms) and 13 cepstra are fixed, and sample_frequency is not applied.
      // std::stof runs inside the try, so a malformed feat.frame-shift is
      // caught and reported as an MFCC failure; e.what() is dropped from the
      // warning.
      Matrix<BaseFloat> features;
      try {
        MfccOptions mfcc_options;
        mfcc_options.use_energy = (use_energy == "true") ? true : false;
        mfcc_options.frame_opts.frame_shift_ms = std::stof(frame_shift);
        mfcc_options.frame_opts.frame_length_ms = 25;
        mfcc_options.num_ceps = 13;
        Mfcc mfcc(mfcc_options);
      
        // Process wave data.
        // NOTE(review): check this call against the Kaldi version in use.
        // WaveData::Data() is a channels x samples matrix, whereas the offline
        // MFCC computation takes a single-channel vector (e.g. Data().Row(0)).
        mfcc.Compute(wave_data.Data(), wave_data.SampFreq(), 1.0, &features);
      } catch (std::exception &e) {
        KALDI_WARN << "Failed to compute MFCC features for utterance " << utt;
        continue;
      }
      
      // Normalize features.
      // NOTE(review): Kaldi's ApplyCmvn takes accumulated CMVN stats, a
      // norm_vars flag and a feature-matrix pointer. Passing `features` as both
      // arguments does not match that — verify. No CMVN stats are computed in
      // this block.
      ApplyCmvn(features, features);
      
      // (4) Set up decodables for the model and the alignment model. Both wrap
      // the same features and the same acoustic scale.
      // NOTE(review): `am_gmm` is not declared (only `gmm` is), and
      // `gmm_decodable` is never used afterwards.
      DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, features, std::stof(acoustic_scale));
      DecodableAmDiagGmmScaled align_decodable(align_gmm, align_trans_model, features, std::stof(acoustic_scale));
      
      // (5) compile-train-graphs (as labelled): decodes align_decodable against
      // a copy of the lexicon FST to fill `alignment`.
      // NOTE(review): no per-utterance training graph is compiled from the
      // transcript; the lexicon FST is decoded directly.
      std::vector<int32> alignment;
      {
        using namespace fst;
        // NOTE(review): no `kaldi::metrics` namespace is declared by the
        // visible includes.
        using namespace kaldi::metrics;
        
        // NOTE(review): decode_fst has StdArc arcs but the copy is declared
        // with LatticeArc. This constructor does not convert arc types, so an
        // explicit conversion (or keeping StdArc) is needed — verify.
        VectorFst<LatticeArc> decode_fst_copy(*decode_fst);
        // `lat` is never filled, and its scope ends with this step, yet step
        // (6) still refers to it.
        Lattice lat;
        {
          // NOTE(review): several of these fields (transition_scale,
          // self_loop_scale, retry_beam, write_lengths, custom_output,
          // acoustic_scale) look like gmm-align command-line options rather than
          // LatticeFasterDecoderConfig members — verify against the headers.
          LatticeFasterDecoderConfig decoder_config;
          decoder_config.transition_scale = std::stof(transition_scale);
          decoder_config.acoustic_scale = std::stof(acoustic_scale);
          decoder_config.self_loop_scale = std::stof(self_loop_scale);
          decoder_config.beam = std::stof(beam);
          decoder_config.retry_beam = std::stof(retry_beam);
          decoder_config.write_lengths = (write_lengths == "true") ? true : false;
          decoder_config.custom_output = (custom_output == "true") ? true : false;
          
          LatticeFasterDecoder decoder(decode_fst_copy, decoder_config);
          bool success = decoder.Decode(&align_decodable);
          if (!success) {
            KALDI_WARN << "Failed to decode alignment for utterance " << utt;
            // Applies to the enclosing utterance loop.
            continue;
          }
          // NOTE(review): Kaldi's LatticeFasterDecoder::GetBestPath fills a
          // Lattice, not a std::vector<int32> — verify.
          decoder.GetBestPath(&alignment);
        }
      }
      
      // (6) gmm-align-compiled: map lattice arcs to per-frame labels in `ali`
      // (initialised to -1 for every frame).
      std::vector<int32> ali;
      {
        using namespace kaldi::metrics;
        // NOTE(review): the unary minus is almost certainly a typo (for a bool
        // result `!-b` equals `!b`). The decodable is also passed to its own
        // Decode(); confirm such a method exists, since a decoder normally
        // consumes the decodable.
        if (!-align_decodable.Decode(&align_decodable)) {
          KALDI_WARN << "Failed to align features for utterance " << utt;
          continue;
        }
        ali.resize(features.NumRows(), -1);
        {
          // This `alignment` shadows the step-(5) vector of the same name, so
          // that result is not used here. The output at the end of the loop
          // still writes the step-(5) vector.
          std::vector<int32> alignment;
          // NOTE(review): `lat` is out of scope here (declared inside step (5)).
          alignment.reserve(lat.NumStates());
          // Collect the output label of every arc that is not epsilon on both
          // sides.
          for (fst::StateIterator<Lattice> siter(lat); !siter.Done(); siter.Next())
            for (fst::ArcIterator<Lattice> aiter(lat, siter.Value()); !aiter.Done(); aiter.Next()) {
              const LatticeArc &arc = aiter.Value();
              if (arc.ilabel != 0 || arc.olabel != 0)
                alignment.push_back(arc.olabel);
            }
          if (alignment.empty() || alignment.front() != 0) {
            KALDI_WARN << "Alignment did not align to utterance " << utt;
            continue;
          }
          // NOTE(review): ComposeStates is not declared in the visible source,
          // and old2new_num_states is unused.
          std::vector<int32> old2new;
          int32 old2new_num_states = ComposeStates(lat, alignment, &old2new);
          if (old2new.empty() || old2new.front() != 0) {
            KALDI_WARN << "Alignment did not align to utterance " << utt;
            continue;
          }
          // For each state with outgoing arcs, store old2new[i] at the first
          // position j where alignment[j] == i.
          // NOTE(review): j ranges over alignment.size() while `ali` has
          // features.NumRows() entries, so ali[j] is out of bounds if alignment
          // is longer than the frame count. old2new[i] is also indexed without
          // a size check.
          for (int32 i = 0; i < lat.NumStates(); ++i)
            if (lat.NumArcs(old2new[i]) > 0) {
              int32 j = 0;
              for (; j < alignment.size(); ++j)
                if (alignment[j] == i) {
                  ali[j] = old2new[i];
                  break;
                }
              // NOTE(review): tautological assertion; it is only reached when
              // its own condition already holds.
              if (j == alignment.size()) KALDI_ASSERT(j == alignment.size());
            }
        }
      }
      
      // (7) Estimate a linear transform of the acoustic features (fMLLR).
      // NOTE(review): this outer fmllr_trans is shadowed by the one declared in
      // the block below and therefore stays empty.
      Matrix<BaseFloat> fmllr_trans;
      {
        using namespace fst;
        // NOTE(review): `fmllr_est` is never declared, and no fMLLR statistics
        // are accumulated anywhere in this block.
        if (!fmllr_est.IsInvertible()) {
          KALDI_WARN << "FMLLR transform not invertible for utterance " << utt;
          continue;
        }
        Matrix<BaseFloat> fmllr_trans;
        try {
          fmllr_trans = fmllr_est.Invert();
        } catch (std::exception &e) {
          KALDI_WARN << "Failed to invert FMLLR transform for utterance " << utt;
          continue;
        }
        // Requires the transform to be (num frames) x (feature dim).
        if (fmllr_trans.NumRows() != features.NumRows() || fmllr_trans.NumCols() != features.NumCols()) {
          KALDI_WARN << "Invalid dimensions of FMLLR transform for utterance " << utt;
          continue;
        }
        // NOTE(review): features is F x D and fmllr_trans^T is D x F, so the
        // product is F x F. That does not match the F x D destination unless
        // F == D. Kaldi fMLLR transforms are typically D x (D+1) applied per
        // frame (e.g. ApplyAffineTransform) — verify. feats_transformed also
        // goes out of scope at the closing brace, before step (8) uses it.
        Matrix<BaseFloat> feats_transformed(features.NumRows(), features.NumCols());
        feats_transformed.AddMatMat(1.0, features, kNoTrans, fmllr_trans, kTrans, 0.0);
      }
      
      // (8) transform-feats -> gmm-align-compiled: splice the transformed
      // features and re-align.
      Matrix<BaseFloat> final_feats;
      {
        // NOTE(review): feats_transformed is not in scope here (declared inside
        // step (7)).
        SpliceFeats(config, feats_transformed, &final_feats);
        // NOTE(review): gmm_align_compiled and graph are not declared in the
        // visible source.
        gmm_align_compiled.Init(l_fst, graph.Features());
        // This `ali` shadows the step-(6) vector and is discarded when the
        // block ends.
        std::vector<int32> ali;
        // NOTE(review): aligns the unspliced `features`, not `final_feats`, so
        // the splicing above has no effect on the result.
        if (!gmm_align_compiled.Align(features, &ali)) {
          KALDI_WARN << "Failed to align features for utterance " << utt;
          continue;
        }
      }
      
      // Output the alignment as space-separated labels.
      // NOTE(review): several problems with this output:
      //   - The file is reopened in the default (truncating) mode for every
      //     utterance, so only the last utterance's labels survive.
      //   - No utterance key or trailing newline is written, and the open is
      //     not checked.
      //   - The vector written is the step-(5) `alignment`, not `ali`.
      // Opening the file once before the loop and writing `utt` on each line
      // would fix the overwrite.
      std::ofstream ali_output("final.ali");
      for (const auto &phone : alignment) {
        ali_output << phone << " ";
      }
      ali_output.close();
      
      KALDI_LOG << "Alignment for utterance " << utt << " written to final.ali";
    }

    return 0;
  } catch (const std::exception &e) {
    // Report any escaping std::exception and fail with -1. The message is
    // printed without a trailing newline.
    std::cerr << e.what();
    return -1;
  }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值