Skip to content

Commit a164e13

Browse files
committed
fill lm_probs/context_scores only if LM/ContextGraph is present (make Result smaller)
1 parent 0f1107b commit a164e13

7 files changed

+29
-22
lines changed

sherpa-onnx/csrc/hypothesis.h

+2
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ struct Hypothesis {
3636

3737
// lm_probs[i] contains the lm score for each token in ys.
3838
// Used only in transducer modified beam-search.
39+
// Elements filled only if LM is used.
3940
std::vector<float> lm_probs;
4041

4142
// context_scores[i] contains the context-graph score for each token in ys.
4243
// Used only in transducer modified beam-search.
44+
// Elements filled only if `ContextGraph` is used.
4345
std::vector<float> context_scores;
4446

4547
// The total score of ys in log space.

sherpa-onnx/csrc/online-recognizer.cc

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
2020

2121
/// Helper for `OnlineRecognizerResult::AsJsonString()`
2222
template<typename T>
23-
const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6) {
23+
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
2424
std::ostringstream oss;
2525
oss << std::fixed << std::setprecision(precision);
2626
oss << "[ ";
@@ -35,9 +35,8 @@ const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6)
3535

3636
/// Helper for `OnlineRecognizerResult::AsJsonString()`
3737
template<> // explicit specialization for T = std::string
38-
const std::string& VecToString<std::string>(const std::vector<std::string>& vec,
39-
int32_t) // ignore 2nd arg
40-
{
38+
std::string VecToString<std::string>(const std::vector<std::string>& vec,
39+
int32_t) { // ignore 2nd arg
4140
std::ostringstream oss;
4241
oss << "[ ";
4342
std::string sep = "";
@@ -57,9 +56,10 @@ std::string OnlineRecognizerResult::AsJsonString() const {
5756
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
5857
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
5958
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
60-
os << "\"constext_scores\": " << VecToString(context_scores, 6) << ", ";
59+
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
6160
os << "\"segment\": " << segment << ", ";
62-
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time << ", ";
61+
os << "\"start_time\": " << std::fixed << std::setprecision(2)
62+
<< start_time << ", ";
6363
os << "\"is_final\": " << (is_final ? "true" : "false");
6464
os << "}";
6565
return os.str();

sherpa-onnx/csrc/online-recognizer.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ struct OnlineRecognizerResult {
4242

4343
std::vector<float> ys_probs; //< log-prob scores from ASR model
4444
std::vector<float> lm_probs; //< log-prob scores from language model
45-
std::vector<float> context_scores; //< log-domain scores from "hot-phrase" contextual boosting
45+
//
46+
/// log-domain scores from "hot-phrase" contextual boosting
47+
std::vector<float> context_scores;
4648

4749
/// ID of this segment
4850
/// When an endpoint is detected, it is incremented

sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc

-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ void OnlineTransducerGreedySearchDecoder::Decode(
153153
// probability
154154
r.ys_probs.push_back(p_logprob[y]);
155155
}
156-
157156
}
158157
if (emitted) {
159158
Ort::Value decoder_input = model_->BuildDecoderInput(*result);

sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc

+11-6
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
190190
if (new_token != 0 && new_token != unk_id_) {
191191
const Hypothesis& prev_i = prev[hyp_index];
192192
// subtract 'prev[i]' path scores, which were added before
193-
// for getting topk tokens
193+
// getting topk tokens
194194
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
195195
new_hyp.ys_probs.push_back(y_prob);
196196

197-
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
198-
if (lm_scale_ != 0.0) {
199-
lm_prob /= lm_scale_; // remove lm-scale
197+
if (lm_) { // export only when LM is used
198+
float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
199+
if (lm_scale_ != 0.0) {
200+
lm_prob /= lm_scale_; // remove lm-scale
201+
}
202+
new_hyp.lm_probs.push_back(lm_prob);
200203
}
201-
new_hyp.lm_probs.push_back(lm_prob);
202204

203-
new_hyp.context_scores.push_back(context_score);
205+
// export only when `ContextGraph` is used
206+
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
207+
new_hyp.context_scores.push_back(context_score);
208+
}
204209
}
205210

206211
hyps.Add(std::move(new_hyp));

sherpa-onnx/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ as the device_name.
276276
}
277277
}
278278

279-
using namespace std::chrono_literals;
279+
// Bring only the `ms` duration-literal operator into scope (so `20ms` below
// is well-formed) instead of importing the whole chrono_literals namespace.
using std::chrono_literals::operator""ms;
280280
std::this_thread::sleep_for(20ms); // sleep for 20ms
281281
}
282282

sherpa-onnx/python/csrc/online-recognizer.cc

+6-7
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,16 @@ static void PybindOnlineRecognizerResult(py::module *m) {
3737
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
3838
.def_property_readonly(
3939
"context_scores",
40-
[](PyClass &self) -> std::vector<float> { return self.context_scores; })
41-
.def_property_readonly(
40+
[](PyClass &self) -> std::vector<float> {
41+
return self.context_scores;
42+
})
43+
.def_property_readonly(
4244
"segment",
4345
[](PyClass &self) -> int32_t { return self.segment; })
44-
.def_property_readonly(
45-
"start_time",
46-
[](PyClass &self) -> float { return self.start_time; })
47-
.def_property_readonly(
46+
.def_property_readonly(
4847
"is_final",
4948
[](PyClass &self) -> bool { return self.is_final; })
50-
.def("as_json_string", &PyClass::AsJsonString,
49+
.def("as_json_string", &PyClass::AsJsonString,
5150
py::call_guard<py::gil_scoped_release>());
5251
}
5352

0 commit comments

Comments
 (0)