#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "feat/wave-reader.h"
#include "feat/feature-functions.h"
#include "feat/feature-mfcc.h"
#include "feat/feature-window.h"
#include "feat/feature-spectrogram.h"
#include "feat/feature-plp.h"
#include "feat/feature-fbank.h"
#include "feat/feature-bottleneck.h"
#include "transform/fmllr-diag-gmm.h"
#include "transform/fmllr-diag-gmm-test.h"
#include "hmm/posterior.h"
#include "gmm/diag-gmm.h"
#include "gmm/decodable-am-diag-gmm.h"
#include "fst"
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
const char *usage = "Kaldi GMM alignment example.\n"
"Usage: gmm-align <config-file> <wav-file> <text-file>\n"
" e.g.: gmm-align config.txt 123.wav 123.text\n";
ParseOptions po(usage);
bool binary = true;
std::string feature_type = "mfcc";
std::string config_file, wav_scp, text_file;
po.Register("binary", &binary, "Write output in binary mode");
po.Read(argc, argv);
if (po.NumArgs() != 3) {
po.PrintUsage();
exit(1);
}
config_file = po.GetArg(1);
wav_scp = po.GetArg(2);
text_file = po.GetArg(3);
Config config;
config.ReadConfigFile(config_file);
std::string use_energy = config.GetString("feat.use-energy");
std::string sample_frequency = config.GetString("feat.sample-frequency");
std::string frame_shift = config.GetString("feat.frame-shift");
std::string length_tolerance = config.GetString("feat.length-tolerance");
std::string tree_rxfilename = config.GetString("model.tree-rxfilename");
std::string model_rxfilename = config.GetString("model.model-rxfilename");
std::string alignment_model_rxfilename = config.GetString("model.alignment-model-rxfilename");
std::string lex_rxfilename = config.GetString("model.lex-rxfilename");
std::string read_disambig_syms = config.GetString("model.read-disambig-syms");
std::string phone_symbol_table = config.GetString("model.phone-symbol-table");
std::string word_symbol_table = config.GetString("model.word-symbol-table");
std::string transition_scale = config.GetString("align.transition-scale");
std::string acoustic_scale = config.GetString("align.acoustic-scale");
std::string self_loop_scale = config.GetString("align.self-loop-scale");
std::string beam = config.GetString("align.beam");
std::string retry_beam = config.GetString("align.retry-beam");
std::string write_lengths = config.GetString("align.write-lengths");
std::string custom_output = config.GetString("align.custom-output");
DiagGmm gmm;
TransitionModel trans_model;
{
bool binary;
Input ki(model_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
gmm.Read(ki.Stream(), binary);
}
DiagGmm align_gmm;
TransitionModel align_trans_model;
{
bool binary;
Input ki(alignment_model_rxfilename, &binary);
align_trans_model.Read(ki.Stream(), binary);
align_gmm.Read(ki.Stream(), binary);
}
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(lex_rxfilename);
fst::VectorFst<fst::StdArc> l_fst;
{
std::ifstream is("l.fst", std::ios_base::in | std::ios_base::binary);
l_fst.Read(is, binary ? fst::FstReadOptions() : fst::FstReadOptions("binary=false"));
}
RandomAccessTableReader<WaveHolder> wav_reader(wav_scp);
std::ifstream text_input(text_file);
std::vector<std::string> texts;
std::string line;
while (std::getline(text_input, line)) {
texts.push_back(line);
}
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const WaveData &wave_data = wav_reader.Value();
Matrix<BaseFloat> features;
try {
MfccOptions mfcc_options;
mfcc_options.use_energy = (use_energy == "true") ? true : false;
mfcc_options.frame_opts.frame_shift_ms = std::stof(frame_shift);
mfcc_options.frame_opts.frame_length_ms = 25;
mfcc_options.num_ceps = 13;
Mfcc mfcc(mfcc_options);
mfcc.Compute(wave_data.Data(), wave_data.SampFreq(), 1.0, &features);
} catch (std::exception &e) {
KALDI_WARN << "Failed to compute MFCC features for utterance " << utt;
continue;
}
ApplyCmvn(features, features);
DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, features, std::stof(acoustic_scale));
DecodableAmDiagGmmScaled align_decodable(align_gmm, align_trans_model, features, std::stof(acoustic_scale));
std::vector<int32> alignment;
{
using namespace fst;
using namespace kaldi::metrics;
VectorFst<LatticeArc> decode_fst_copy(*decode_fst);
Lattice lat;
{
LatticeFasterDecoderConfig decoder_config;
decoder_config.transition_scale = std::stof(transition_scale);
decoder_config.acoustic_scale = std::stof(acoustic_scale);
decoder_config.self_loop_scale = std::stof(self_loop_scale);
decoder_config.beam = std::stof(beam);
decoder_config.retry_beam = std::stof(retry_beam);
decoder_config.write_lengths = (write_lengths == "true") ? true : false;
decoder_config.custom_output = (custom_output == "true") ? true : false;
LatticeFasterDecoder decoder(decode_fst_copy, decoder_config);
bool success = decoder.Decode(&align_decodable);
if (!success) {
KALDI_WARN << "Failed to decode alignment for utterance " << utt;
continue;
}
decoder.GetBestPath(&alignment);
}
}
std::vector<int32> ali;
{
using namespace kaldi::metrics;
if (!-align_decodable.Decode(&align_decodable)) {
KALDI_WARN << "Failed to align features for utterance " << utt;
continue;
}
ali.resize(features.NumRows(), -1);
{
std::vector<int32> alignment;
alignment.reserve(lat.NumStates());
for (fst::StateIterator<Lattice> siter(lat); !siter.Done(); siter.Next())
for (fst::ArcIterator<Lattice> aiter(lat, siter.Value()); !aiter.Done(); aiter.Next()) {
const LatticeArc &arc = aiter.Value();
if (arc.ilabel != 0 || arc.olabel != 0)
alignment.push_back(arc.olabel);
}
if (alignment.empty() || alignment.front() != 0) {
KALDI_WARN << "Alignment did not align to utterance " << utt;
continue;
}
std::vector<int32> old2new;
int32 old2new_num_states = ComposeStates(lat, alignment, &old2new);
if (old2new.empty() || old2new.front() != 0) {
KALDI_WARN << "Alignment did not align to utterance " << utt;
continue;
}
for (int32 i = 0; i < lat.NumStates(); ++i)
if (lat.NumArcs(old2new[i]) > 0) {
int32 j = 0;
for (; j < alignment.size(); ++j)
if (alignment[j] == i) {
ali[j] = old2new[i];
break;
}
if (j == alignment.size()) KALDI_ASSERT(j == alignment.size());
}
}
}
Matrix<BaseFloat> fmllr_trans;
{
using namespace fst;
if (!fmllr_est.IsInvertible()) {
KALDI_WARN << "FMLLR transform not invertible for utterance " << utt;
continue;
}
Matrix<BaseFloat> fmllr_trans;
try {
fmllr_trans = fmllr_est.Invert();
} catch (std::exception &e) {
KALDI_WARN << "Failed to invert FMLLR transform for utterance " << utt;
continue;
}
if (fmllr_trans.NumRows() != features.NumRows() || fmllr_trans.NumCols() != features.NumCols()) {
KALDI_WARN << "Invalid dimensions of FMLLR transform for utterance " << utt;
continue;
}
Matrix<BaseFloat> feats_transformed(features.NumRows(), features.NumCols());
feats_transformed.AddMatMat(1.0, features, kNoTrans, fmllr_trans, kTrans, 0.0);
}
Matrix<BaseFloat> final_feats;
{
SpliceFeats(config, feats_transformed, &final_feats);
gmm_align_compiled.Init(l_fst, graph.Features());
std::vector<int32> ali;
if (!gmm_align_compiled.Align(features, &ali)) {
KALDI_WARN << "Failed to align features for utterance " << utt;
continue;
}
}
std::ofstream ali_output("final.ali");
for (const auto &phone : alignment) {
ali_output << phone << " ";
}
ali_output.close();
KALDI_LOG << "Alignment for utterance " << utt << " written to final.ali";
}
return 0;
} catch (const std::exception &e) {
std::cerr << e.what();
return -1;
}
}