From 01affefaafecfc61a3e275a72f7e05260c7f81ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B1=85=E6=88=8E=E6=B0=8F?= Date: Sun, 3 Mar 2024 17:14:54 +0800 Subject: [PATCH] feat(user_dictionary): predict word --- src/rime/dict/user_dictionary.cc | 68 ++++++++++++++++++++++++++---- src/rime/dict/user_dictionary.h | 4 +- src/rime/gear/script_translator.cc | 9 +++- 3 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/rime/dict/user_dictionary.cc b/src/rime/dict/user_dictionary.cc index 73630610fc..9b57a69b0b 100644 --- a/src/rime/dict/user_dictionary.cc +++ b/src/rime/dict/user_dictionary.cc @@ -16,14 +16,17 @@ #include #include #include +#include #include #include #include +#include namespace rime { struct DfsState { size_t depth_limit; + size_t predict_word_from_depth; TickCount present_tick; Code code; vector credibility; @@ -32,13 +35,15 @@ struct DfsState { string key; string value; + size_t depth() const { return code.size(); } + bool IsExactMatch(const string& prefix) { return boost::starts_with(key, prefix + '\t'); } bool IsPrefixMatch(const string& prefix) { return boost::starts_with(key, prefix); } - void RecruitEntry(size_t pos); + void RecruitEntry(size_t pos, map* syllabary = nullptr); bool NextEntry() { if (!accessor->GetNextRecord(&key, &value)) { key.clear(); @@ -63,11 +68,30 @@ struct DfsState { } }; -void DfsState::RecruitEntry(size_t pos) { +void DfsState::RecruitEntry(size_t pos, map* syllabary) { + string full_code; auto e = UserDictionary::CreateDictEntry(key, value, present_tick, - credibility.back()); + credibility.back(), + syllabary ? &full_code : nullptr); if (e) { - e->code = code; + if (syllabary) { + vector syllables = + strings::split(full_code, " ", strings::SplitBehavior::SkipToken); + Code numeric_code; + for (auto s = syllables.begin(); s != syllables.end(); ++s) { + auto found = syllabary->find(*s); + if (found == syllabary->end()) { + LOG(ERROR) << "failed to recruit dict entry '" << e->text + << "', unrecognized syllable: " << *s; + return; + } + numeric_code.push_back(found->second); + } + e->code = numeric_code; + e->matching_code_size = code.size(); + } else { + e->code = code; + } DLOG(INFO) << "add entry at pos " << pos; query_result[pos].push_back(e); } @@ -230,10 +254,36 @@ void UserDictionary::DfsLookup(const SyllableGraph& syll_graph, if (!state->NextEntry()) // reached the end of db break; } - // the caller can limit the number of syllables to look up - if ((!state->depth_limit || state->code.size() < state->depth_limit) && - state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore' - DfsLookup(syll_graph, end_pos, prefix, state); + auto next_index = syll_graph.indices.find(end_pos); + if (next_index == syll_graph.indices.end()) { + // reached the end of input, predict word if requested + if (state->predict_word_from_depth != 0 && + state->depth() >= state->predict_word_from_depth) { + while (state->IsPrefixMatch(prefix)) { + DLOG(INFO) << "prefix match found for '" << prefix << "'."; + if (syllabary_.empty()) { + Syllabary syllabary; + if (!table_->GetSyllabary(&syllabary)) { + LOG(ERROR) << "failed to get syllabary for user dict: " + << name(); + break; + } + SyllableId syllable_id = 0; + for (auto s = syllabary.begin(); s != syllabary.end(); ++s) { + syllabary_[*s] = syllable_id++; + } + } + state->RecruitEntry(end_pos, &syllabary_); + if (!state->NextEntry()) // reached the end of db + break; + } + } + } else { + // the caller can limit the number of syllables to look up + if ((!state->depth_limit || state->depth() < state->depth_limit) && + state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore' + DfsLookup(syll_graph, end_pos, prefix, state); + } } } if (!state->IsPrefixMatch(current_prefix)) // 'b |' vs. 'g o \tGo' @@ -254,12 +304,14 @@ an UserDictionary::Lookup( const SyllableGraph& syll_graph, size_t start_pos, size_t depth_limit, + size_t predict_word_from_depth, double initial_credibility) { if (!table_ || !prism_ || !loaded() || start_pos >= syll_graph.interpreted_length) return nullptr; DfsState state; state.depth_limit = depth_limit; + state.predict_word_from_depth = predict_word_from_depth; FetchTickCount(); state.present_tick = tick_ + 1; state.credibility.push_back(initial_credibility); diff --git a/src/rime/dict/user_dictionary.h b/src/rime/dict/user_dictionary.h index 653164eb26..3e289a842d 100644 --- a/src/rime/dict/user_dictionary.h +++ b/src/rime/dict/user_dictionary.h @@ -59,6 +59,7 @@ class UserDictionary : public Class { an Lookup(const SyllableGraph& syllable_graph, size_t start_pos, size_t depth_limit = 0, + size_t predict_word_from_depth = 0, double initial_credibility = 0.0); size_t LookupWords(UserDictEntryIterator* result, const string& input, @@ -82,7 +83,7 @@ class UserDictionary : public Class { const string& value, TickCount present_tick, double credibility = 0.0, - string* full_code = NULL); + string* full_code = nullptr); protected: bool Initialize(); @@ -98,6 +99,7 @@ class UserDictionary : public Class { an db_; an table_; an prism_; + map syllabary_; TickCount tick_ = 0; time_t transaction_time_ = 0; }; diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc index b71392c7bc..f8a94f1fee 100644 --- a/src/rime/gear/script_translator.cc +++ b/src/rime/gear/script_translator.cc @@ -356,7 +356,11 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) { phrase_ = dict->Lookup(syllable_graph, 0, predict_word); if (user_dict) { - user_phrase_ = user_dict->Lookup(syllable_graph, 0); + const size_t kUnlimitedDepth = 0; + const size_t kNumSyllablesToPredictWord = 4; + user_phrase_ = + user_dict->Lookup(syllable_graph, 0, kUnlimitedDepth, + predict_word ? kNumSyllablesToPredictWord : 0); } if (!phrase_ && !user_phrase_) return false; @@ -371,7 +375,8 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) { phrase_ && phrase_iter_->first == consumed && is_exact_match_phrase(phrase_iter_->second.Peek()); bool has_exact_match_user_phrase = - user_phrase_ && user_phrase_iter_->first == consumed; + user_phrase_ && user_phrase_iter_->first == consumed && + is_exact_match_phrase(user_phrase_iter_->second.Peek()); bool has_at_least_two_syllables = syllable_graph.edges.size() >= 2; if (!has_exact_match_phrase && !has_exact_match_user_phrase && has_at_least_two_syllables) {