Skip to content

Adding sampling to nbest #2884

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion src/lat/lattice-functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
vector<double> *beta);



/// This is used in CompactLatticeLimitDepth.
struct LatticeArcRecord {
BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
Expand Down Expand Up @@ -1755,4 +1754,90 @@ void ReplaceAcousticScoresFromMap(
}
}


void SampleFromLattice(const Lattice &lat,
vector<Lattice> *nbest_lats, int32 n) {
KALDI_ASSERT(nbest_lats);
nbest_lats->clear();
nbest_lats->resize(n);

vector<double> alpha, beta;

ComputeLatticeAlphasAndBetas(lat, false, &alpha, &beta);

typedef Lattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;

KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
KALDI_ASSERT(lat.Start() == 0);

vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);

for (int32 i = 0; i < n; i++) {
StateId s = 0;
StateId out_state = 0;
Lattice &this_nbest = (*nbest_lats)[i];
this_nbest.AddState();
this_nbest.SetStart(out_state);

bool reached_final = false;
for (int32 t = 0; t <= lat.NumStates(); t++) {
double r = RandUniform(), cum_prob = 0;

Weight f = lat.Final(s);
if (f != Weight::Zero()) {
KALDI_ASSERT(state_times[s] == max_time &&
"Lattice is inconsistent (final-prob not at max_time)");

double final_like = -(f.Value1() + f.Value2());
double prob = Exp(final_like - beta[s]);
cum_prob += prob;

if (cum_prob > r) {
this_nbest.SetFinal(out_state, f);
reached_final = true;
break;
}
}

bool sampled_arc = false;
for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -ConvertToCost(arc.weight);
double prob = Exp(arc_like + beta[arc.nextstate] - beta[s]);
cum_prob += prob;

if (cum_prob > r) {
this_nbest.AddState();
this_nbest.AddArc(out_state, Arc(arc.ilabel, arc.olabel, arc.weight,
out_state + 1));
out_state++;
s = arc.nextstate;
sampled_arc = true;
break;
}
}

if (!sampled_arc) {
KALDI_ERR << "Could not sample an arc from state " << s << " at time "
<< state_times[s] << "; Something wrong with the lattice.";
}
}

KALDI_ASSERT(this_nbest.NumStates() == out_state + 1);
{
vector<int32> this_nbest_state_times;
KALDI_ASSERT(LatticeStateTimes(this_nbest, &this_nbest_state_times) == max_time);
}

if (!reached_final) {
KALDI_ERR << "Did not reach final state after " << lat.NumStates() << " steps; "
<< "Something went wrong with the lattice.";
}
}
}


} // namespace kaldi
13 changes: 13 additions & 0 deletions src/lat/lattice-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,19 @@ void ReplaceAcousticScoresFromMap(
PairHasher<int32> > &acoustic_scores,
Lattice *lat);

/// This function samples 'n' paths from a lattice by the probability of
/// those paths in the lattice. This function returns a vector of paths
/// stored in Lattice format.
///
/// @param [in] lat Input lattice
/// @param [out] nbest_lats
/// Pointer to a vector of Lattice into which the
/// n sampled paths will be stored
/// @param [in] n Number of paths to be sampled
void SampleFromLattice(const Lattice &lat,
std::vector<Lattice> *nbest_lats, int32 n = 1);


} // namespace kaldi

#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
13 changes: 11 additions & 2 deletions src/latbin/lattice-to-nbest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"

int main(int argc, char *argv[]) {
try {
Expand All @@ -44,6 +45,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage);
BaseFloat acoustic_scale = 1.0, lm_scale = 1.0;
bool random = false;
bool weighted_random = false;
int32 srand_seed = 0;
int32 n = 1;

Expand All @@ -53,6 +55,9 @@ int main(int argc, char *argv[]) {
po.Register("random", &random,
"If true, generate n random paths instead of n-best paths"
"In this case, all costs in generated paths will be zero.");
po.Register("weighted-random", &weighted_random,
"If true, generate n paths by sampling the paths based on the "
"probability of the paths");
po.Register("srand", &srand_seed, "Seed for random number generator "
"(only relevant if --random=true)");

Expand Down Expand Up @@ -91,15 +96,19 @@ int main(int argc, char *argv[]) {
std::vector<Lattice> nbest_lats;
{
Lattice nbest_lat;
if (!random) {
if (!random && !weighted_random) {
fst::ShortestPath(lat, &nbest_lat, n);
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
} else if (weighted_random) {
TopSortLatticeIfNeeded(&lat);
SampleFromLattice(lat, &nbest_lats, n);
} else {
fst::UniformArcSelector<LatticeArc> uniform_selector;
fst::RandGenOptions<fst::UniformArcSelector<LatticeArc> > opts(uniform_selector);
opts.npath = n;
fst::RandGen(lat, &nbest_lat, opts);
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
}
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
}

if (nbest_lats.empty()) {
Expand Down