diff --git a/src/lat/lattice-functions.cc b/src/lat/lattice-functions.cc index 54c856a9403..60c56c7beae 100644 --- a/src/lat/lattice-functions.cc +++ b/src/lat/lattice-functions.cc @@ -517,7 +517,6 @@ double ComputeLatticeAlphasAndBetas(const CompactLattice &lat, vector *beta); - /// This is used in CompactLatticeLimitDepth. struct LatticeArcRecord { BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc, @@ -1755,4 +1754,90 @@ void ReplaceAcousticScoresFromMap( } } + +void SampleFromLattice(const Lattice &lat, + vector *nbest_lats, int32 n) { + KALDI_ASSERT(nbest_lats); + nbest_lats->clear(); + nbest_lats->resize(n); + + vector alpha, beta; + + ComputeLatticeAlphasAndBetas(lat, false, &alpha, &beta); + + typedef Lattice::Arc Arc; + typedef Arc::Weight Weight; + typedef Arc::StateId StateId; + + KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted); + KALDI_ASSERT(lat.Start() == 0); + + vector state_times; + int32 max_time = LatticeStateTimes(lat, &state_times); + + for (int32 i = 0; i < n; i++) { + StateId s = 0; + StateId out_state = 0; + Lattice &this_nbest = (*nbest_lats)[i]; + this_nbest.AddState(); + this_nbest.SetStart(out_state); + + bool reached_final = false; + for (int32 t = 0; t <= lat.NumStates(); t++) { + double r = RandUniform(), cum_prob = 0; + + Weight f = lat.Final(s); + if (f != Weight::Zero()) { + KALDI_ASSERT(state_times[s] == max_time && + "Lattice is inconsistent (final-prob not at max_time)"); + + double final_like = -(f.Value1() + f.Value2()); + double prob = Exp(final_like - beta[s]); + cum_prob += prob; + + if (cum_prob > r) { + this_nbest.SetFinal(out_state, f); + reached_final = true; + break; + } + } + + bool sampled_arc = false; + for (fst::ArcIterator aiter(lat, s); !aiter.Done(); aiter.Next()) { + const Arc &arc = aiter.Value(); + double arc_like = -ConvertToCost(arc.weight); + double prob = Exp(arc_like + beta[arc.nextstate] - beta[s]); + cum_prob += prob; + + if (cum_prob > r) { + this_nbest.AddState(); + this_nbest.AddArc(out_state, Arc(arc.ilabel, arc.olabel, arc.weight, + out_state + 1)); + out_state++; + s = arc.nextstate; + sampled_arc = true; + break; + } + } + + if (!sampled_arc) { + KALDI_ERR << "Could not sample an arc from state " << s << " at time " + << state_times[s] << "; Something wrong with the lattice."; + } + } + + KALDI_ASSERT(this_nbest.NumStates() == out_state + 1); + { + vector this_nbest_state_times; + KALDI_ASSERT(LatticeStateTimes(this_nbest, &this_nbest_state_times) == max_time); + } + + if (!reached_final) { + KALDI_ERR << "Did not reach final state after " << lat.NumStates() << " steps; " + << "Something went wrong with the lattice."; + } + } +} + + } // namespace kaldi diff --git a/src/lat/lattice-functions.h b/src/lat/lattice-functions.h index c7fe4833a4a..eb18b054287 100644 --- a/src/lat/lattice-functions.h +++ b/src/lat/lattice-functions.h @@ -421,6 +421,19 @@ void ReplaceAcousticScoresFromMap( PairHasher > &acoustic_scores, Lattice *lat); +/// This function samples 'n' paths from a lattice by the probability of +/// those paths in the lattice. This function returns a vector of paths +/// stored in Lattice format. +/// +/// @param [in] lat Input lattice +/// @param [out] nbest_lats +/// Pointer to a vector of Lattice into which the +/// n sampled paths will be stored +/// @param [in] n Number of paths to be sampled +void SampleFromLattice(const Lattice &lat, + std::vector *nbest_lats, int32 n = 1); + + } // namespace kaldi #endif // KALDI_LAT_LATTICE_FUNCTIONS_H_ diff --git a/src/latbin/lattice-to-nbest.cc b/src/latbin/lattice-to-nbest.cc index f5ecbe044c3..2b2dbdc2746 100644 --- a/src/latbin/lattice-to-nbest.cc +++ b/src/latbin/lattice-to-nbest.cc @@ -22,6 +22,7 @@ #include "util/common-utils.h" #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" +#include "lat/lattice-functions.h" int main(int argc, char *argv[]) { try { @@ -44,6 +45,7 @@ int main(int argc, char *argv[]) { ParseOptions po(usage); BaseFloat acoustic_scale = 1.0, lm_scale = 1.0; bool random = false; + bool weighted_random = false; int32 srand_seed = 0; int32 n = 1; @@ -53,6 +55,9 @@ int main(int argc, char *argv[]) { po.Register("random", &random, "If true, generate n random paths instead of n-best paths" "In this case, all costs in generated paths will be zero."); + po.Register("weighted-random", &weighted_random, + "If true, generate n paths by sampling the paths based on the " + "probability of the paths"); po.Register("srand", &srand_seed, "Seed for random number generator " "(only relevant if --random=true)"); @@ -91,15 +96,19 @@ int main(int argc, char *argv[]) { std::vector nbest_lats; { Lattice nbest_lat; - if (!random) { + if (!random && !weighted_random) { fst::ShortestPath(lat, &nbest_lat, n); + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + } else if (weighted_random) { + TopSortLatticeIfNeeded(&lat); + SampleFromLattice(lat, &nbest_lats, n); } else { fst::UniformArcSelector uniform_selector; fst::RandGenOptions > opts(uniform_selector); opts.npath = n; fst::RandGen(lat, &nbest_lat, opts); + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); } - fst::ConvertNbestToVector(nbest_lat, &nbest_lats); } if (nbest_lats.empty()) {