From 04c7defd241b586e4f2f3badb74e09d7e23d2f41 Mon Sep 17 00:00:00 2001
From: RoberLopez
Date: Mon, 23 Dec 2024 18:53:17 +0100
Subject: [PATCH] clean

---
 examples/translation/main.cpp |   7 +-
 opennn/language_data_set.cpp  | 656 +++-------------------------------
 opennn/language_data_set.h    |  13 +-
 opennn/strings_utilities.cpp  |  20 --
 opennn/strings_utilities.h    |   8 +-
 opennn/tensors.cpp            |   7 +
 opennn/tensors.h              |   3 +
 opennn/transformer.cpp        |  14 +-
 8 files changed, 62 insertions(+), 666 deletions(-)

diff --git a/examples/translation/main.cpp b/examples/translation/main.cpp
index 9234a017e..8146d6e5a 100644
--- a/examples/translation/main.cpp
+++ b/examples/translation/main.cpp
@@ -37,11 +37,8 @@ int main()
 
         // Data set
 
-        LanguageDataSet language_data_set("C:/Users/Roberto Lopez/Documents/opennn/examples/amazon_reviews/data/amazon_cells_labelled.txt");
-        // LanguageDataSet language_data_set("/home/artelnics/Escritorio/andres_alonso/ViT/dataset/amazon_reviews/amazon_cells_reduced.txt");
-        // LanguageDataSet language_data_set("/home/artelnics/Escritorio/andres_alonso/ViT/dataset/ENtoES_dataset50000.txt");
-        // LanguageDataSet language_data_set("/home/artelnics/Escritorio/andres_alonso/ViT/dataset/dataset_ingles_espanol.txt");
-
+        LanguageDataSet language_data_set("C:/translation.csv");
+/*
         // cout<<…

diff --git a/opennn/language_data_set.cpp b/opennn/language_data_set.cpp
--- a/opennn/language_data_set.cpp
+++ b/opennn/language_data_set.cpp
@@ … @@ vector<pair<string, Index>> filter_inputs(const vector<pair<string, Index>>& tri
 
 vector<string> generate_override_vocabulary(const vector<string>& reserved_tokens,
-                                            const set<string>& character_tokens,
-                                            const map<string, Index>& current_tokens)
+                                            const set<string>& character_tokens,
+                                            const map<string, Index>& current_tokens)
 {
     vector<string> vocabulary;
 
     vocabulary.insert(vocabulary.end(), reserved_tokens.begin(), reserved_tokens.end());
@@ -1111,22 +1109,22 @@ vector<string> calculate_vocabulary_binary_search(const vector<…
 
 vector<string> LanguageDataSet::create_vocabulary(const vector<vector<string>>& tokens,
-                                                  const Index& vocabulary_size,
-                                                  const vector<string>& reserved_tokens,
-                                                  const Index& upper_threshold,
-                                                  const Index& lower_threshold,
-                                                  const Index& iterations_number,
-                                                  const Index& max_input_tokens,
-                                                  const Index& max_token_length,
-                                                  const Index& max_unique_characters,
-                                                  const float& slack_ratio,
-                                                  const bool& include_joiner_token,
-                                                  const string& joiner)
+                                                  const Index& vocabulary_size,
+                                                  const Index& upper_threshold,
+                                                  const Index& lower_threshold,
+                                                  const Index& iterations_number,
+                                                  const Index& max_input_tokens,
+                                                  const Index& max_token_length,
+                                                  const Index& max_unique_characters,
+                                                  const float& slack_ratio,
+                                                  const bool& include_joiner_token,
+                                                  const string& joiner)
 {
+
     const vector<string> total_tokens = tokens_list(tokens);
-
+
     const vector<pair<string, Index>> word_counts = count_words(total_tokens);
-
+
     const auto [upper_search, lower_search] = calculate_thresholds(word_counts,
                                                                    upper_threshold,
                                                                    lower_threshold);
@@ -1141,6 +1139,7 @@ vector<string> LanguageDataSet::create_vocabulary(const vector<vector<string>>&
     const vector<pair<string, Index>> filtered_counts = filter_inputs(trimmed_counts,
                                                                       allowed_characters,
                                                                       max_input_tokens);
 
+    /*
     const vector<string> vocabulary = calculate_vocabulary_binary_search(filtered_counts,
                                                                          lower_search,
@@ -1149,576 +1148,10 @@ vector<string> LanguageDataSet::create_vocabulary(const vector<vector<string>>&
 
     return vocabulary;
 */
-
     return vector<string>();
 }


void LanguageDataSet::load_documents()
{
}


void LanguageDataSet::read_csv_3_language_model()
{
    ifstream file(data_path);

    const bool is_float = is_same<type, float>::value;

    const string separator_string = get_separator_string();

    string line;

    //skip_header(file);

    // Read data

    const Index raw_variables_number = has_sample_ids ? get_raw_variables_number() + 1 : get_raw_variables_number();

    vector<string> tokens(raw_variables_number);

    const Index samples_number = data.dimension(0);

    if(has_sample_ids) sample_ids.resize(samples_number);

    if(display) cout << "Reading data..." << endl;

    Index sample_index = 0;
    Index raw_variable_index = 0;

    while(getline(file, line))
    {
        prepare_line(line);

        if(line.empty()) continue;

        fill_tokens(line, separator_string, tokens);

        for(Index j = 0; j < raw_variables_number; j++)
        {
            trim(tokens[j]);

            if(has_sample_ids && j == 0)
            {
                sample_ids[sample_index] = tokens[j];

                continue;
            }

            if(tokens[j] == missing_values_label || tokens[j].empty())
                data(sample_index, raw_variable_index) = type(NAN);
            else if(is_float)
                data(sample_index, raw_variable_index) = type(strtof(tokens[j].data(), nullptr));
            else
                data(sample_index, raw_variable_index) = type(stof(tokens[j]));

            raw_variable_index++;
        }

        raw_variable_index = 0;
        sample_index++;
    }

    const Index data_file_preview_index = has_header ? 3 : 2;

    data_file_preview[data_file_preview_index] = tokens;

    file.close();

    if(display) cout << "Data read successfully..." << endl;
}


void LanguageDataSet::read_csv_1()
{
    if(display) cout << "Path: " << data_path << endl;

    if(data_path.empty())
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: DataSet class.\n"
               << "void read_csv() method.\n"
               << "Data file name is empty.\n";

        throw runtime_error(buffer.str());
    }

    regex accent_regex("[\\xC0-\\xFF]");
    ifstream file(data_path);

    if(!file.is_open())
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: DataSet class.\n"
               << "void read_csv() method.\n"
               << "Cannot open data file: " << data_path << "\n";

        throw runtime_error(buffer.str());
    }

    const string separator_char = get_separator_string();

    if(display) cout << "Setting data file preview..." << endl;

    Index lines_number = has_binary_raw_variables() ? 4 : 3;

    data_file_preview.resize(lines_number);

    string line;

    Index lines_count = 0;

    while(file.good())
    {
        getline(file, line);

        decode(line);

        trim(line);

        erase(line, '"');

        if(line.empty()) continue;

        check_separators(line);

        data_file_preview[lines_count] = get_tokens(line, separator_char);

        lines_count++;

        if(lines_count == lines_number) break;
    }

    file.close();

    // Check empty file

    if(data_file_preview[0].size() == 0)
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: DataSet class.\n"
               << "void read_csv_1() method.\n"
               << "File " << data_path << " is empty.\n";

        throw runtime_error(buffer.str());
    }

    // Set rows labels and raw_variables names

    if(display) cout << "Setting rows labels..." << endl;

    string first_name = data_file_preview[0][0];
    transform(first_name.begin(), first_name.end(), first_name.begin(), ::tolower);

    const Index raw_variables_number = get_has_rows_labels() ? data_file_preview[0].size()-1 : data_file_preview[0].size();

    raw_variables.resize(raw_variables_number);

    // Check if header has numeric value

    if(has_binary_raw_variables() && has_numbers(data_file_preview[0]))
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: DataSet class.\n"
               << "void read_csv_1() method.\n"
               << "Some raw_variables names are numeric.\n";

        throw runtime_error(buffer.str());
    }

    // raw_variables names

    if(display) cout << "Setting raw_variables names..." << endl;

    if(has_binary_raw_variables())
    {
/*
        get_has_rows_labels() ? set_raw_variable_names(data_file_preview[0].slice(Eigen::array<Index, 1>({1}),
                                                                                  Eigen::array<Index, 1>({data_file_preview[0].size()-1})))
                              : set_raw_variable_names(data_file_preview[0]);
*/
    }
    else
    {
        set_raw_variable_names(get_default_raw_variables_names(raw_variables_number));
    }

    // Check raw_variables with all missing values

    bool has_nans_raw_variables = false;

    do
    {
        has_nans_raw_variables = false;

        if(lines_number > 10)
            break;

        for(size_t i = 0; i < data_file_preview[0].size(); i++)
        {
            if(get_has_rows_labels() && i == 0) continue;

            // Check if all are missing values

            if(data_file_preview[1][i] == missing_values_label
            && data_file_preview[2][i] == missing_values_label
            && data_file_preview[lines_number-2][i] == missing_values_label
            && data_file_preview[lines_number-1][i] == missing_values_label)
                has_nans_raw_variables = true;
            else
                has_nans_raw_variables = false;

            if(has_nans_raw_variables)
            {
                lines_number++;
                data_file_preview.resize(lines_number);

                lines_count = 0;

                file.open(data_path.c_str());

                if(!file.is_open())
                {
                    ostringstream buffer;

                    buffer << "OpenNN Exception: DataSet class.\n"
                           << "void read_csv() method.\n"
                           << "Cannot open data file: " << data_path << "\n";

                    throw runtime_error(buffer.str());
                }

                while(file.good())
                {
                    getline(file, line);
                    decode(line);
                    trim(line);
                    erase(line, '"');
                    if(line.empty()) continue;
                    check_separators(line);
                    data_file_preview[lines_count] = get_tokens(line, separator_char);
                    lines_count++;
                    if(lines_count == lines_number) break;
                }

                file.close();
            }
        }
    }while(has_nans_raw_variables);

    // raw_variables types

    if(display) cout << "Setting raw_variables types..." << endl;

    Index raw_variable_index = 0;

    for(size_t i = 0; i < data_file_preview[0].size(); i++)
    {
        if(get_has_rows_labels() && i == 0) continue;

        string data_file_preview_1 = data_file_preview[1][i];
        string data_file_preview_2 = data_file_preview[2][i];
        string data_file_preview_3 = data_file_preview[lines_number-2][i];
        string data_file_preview_4 = data_file_preview[lines_number-1][i];

        /* if(nans_columns(column_index))
        {
            columns(column_index).type = ColumnType::Constant;
            column_index++;
        }
        else*/ if((is_date_time_string(data_file_preview_1) && data_file_preview_1 != missing_values_label)
               || (is_date_time_string(data_file_preview_2) && data_file_preview_2 != missing_values_label)
               || (is_date_time_string(data_file_preview_3) && data_file_preview_3 != missing_values_label)
               || (is_date_time_string(data_file_preview_4) && data_file_preview_4 != missing_values_label))
        {
            raw_variables[raw_variable_index].type = RawVariableType::DateTime;
            // time_column = raw_variables[raw_variable_index].name;
            raw_variable_index++;
        }
        else if(((is_numeric_string(data_file_preview_1) && data_file_preview_1 != missing_values_label) || data_file_preview_1.empty())
             || ((is_numeric_string(data_file_preview_2) && data_file_preview_2 != missing_values_label) || data_file_preview_2.empty())
             || ((is_numeric_string(data_file_preview_3) && data_file_preview_3 != missing_values_label) || data_file_preview_3.empty())
             || ((is_numeric_string(data_file_preview_4) && data_file_preview_4 != missing_values_label) || data_file_preview_4.empty()))
        {
            raw_variables[raw_variable_index].type = RawVariableType::Numeric;
            raw_variable_index++;
        }
        else
        {
            raw_variables[raw_variable_index].type = RawVariableType::Categorical;
            raw_variable_index++;
        }
    }

    // Resize data file preview to original

    if(data_file_preview.size() > 4)
    {
        lines_number = has_binary_raw_variables() ? 4 : 3;

        vector<vector<string>> data_file_preview_copy(data_file_preview);

        data_file_preview.resize(lines_number);

        data_file_preview[0] = data_file_preview_copy[1];
        data_file_preview[1] = data_file_preview_copy[1];
        data_file_preview[2] = data_file_preview_copy[2];
        data_file_preview[lines_number - 2] = data_file_preview_copy[data_file_preview_copy.size()-2];
        data_file_preview[lines_number - 1] = data_file_preview_copy[data_file_preview_copy.size()-1];
    }
}


void LanguageDataSet::read_csv_2_simple()
{
    regex accent_regex("[\\xC0-\\xFF]");
    ifstream file(data_path);

    if(!file.is_open())
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: DataSet class.\n"
               << "void read_csv_2_simple() method.\n"
               << "Cannot open data file: " << data_path << "\n";

        throw runtime_error(buffer.str());
    }

    string line;
    Index line_number = 0;

    if(has_binary_raw_variables())
    {
        while(file.good())
        {
            line_number++;

            getline(file, line);

            trim(line);

            erase(line, '"');

            if(line.empty()) continue;

            break;
        }
    }

    Index samples_count = 0;

    Index tokens_count;

    if(display) cout << "Setting data dimensions..." << endl;

    const string separator_string = get_separator_string();

    const Index raw_variables_number = get_raw_variables_number();
    const Index raw_raw_variables_number = get_has_rows_labels() ? raw_variables_number + 1 : raw_variables_number;

    while(file.good())
    {
        line_number++;

        getline(file, line);

        trim(line);

        erase(line, '"');

        if(line.empty()) continue;

        tokens_count = count_tokens(line, separator_string);

        if(tokens_count != raw_raw_variables_number)
        {
            ostringstream buffer;

            buffer << "OpenNN Exception: DataSet class.\n"
                   << "void read_csv_2_simple() method.\n"
                   << "Line " << line_number << ": Size of tokens("
                   << tokens_count << ") is not equal to number of raw_variables("
                   << raw_raw_variables_number << ").\n";

            throw runtime_error(buffer.str());
        }

        samples_count++;
    }

    file.close();

    data.resize(samples_count, raw_variables_number);

    set_default_raw_variables_uses();

    sample_uses.resize(samples_count);
    // sample_uses.set(SampleUse::Training);

    split_samples_random();
}


void LanguageDataSet::read_csv()
{
    read_csv_1();

    read_csv_2_simple();

    read_csv_3_language_model();
}

// void LanguageDataSet::read_txt()
// {
//    cout << "Reading .txt file..." << endl;

//    load_documents(data_path);

//    Index entry_number = documents(0).size();


//    for(Index i = 1; i < documents.size(); i++)
//        entry_number += documents[i].size();

//    Index completion_entry_number = targets(0).size();

//    for(Index i = 1; i < targets.size(); i++)
//        completion_entry_number += targets(i).size();

//    if(entry_number != completion_entry_number)
//        throw runtime_error("Context number of entries (" + to_string(entry_number) + ") not equal to completion number of entries (" + to_string(completion_entry_number) + ").\n");

//    vector<string> context(entry_number);

//    Index entry_index = 0;

//    for(Index i = 0; i < documents.size(); i++)
//        for(Index j = 0; j < documents[i].size(); j++)
//            context(entry_index++) = documents[i][j];


//    vector<string> completion(entry_number);

//    entry_index = 0;

//    for(Index i = 0; i < targets.size(); i++)
//        for(Index j = 0; j < targets(i).size(); j++)
//            completion(entry_index++) = targets(i)(j);

//    cout << "Processing documents..." << endl;

//    const vector<vector<string>> context_tokens = preprocess_language_documents(context);
//    const vector<vector<string>> completion_tokens = preprocess_language_documents(completion);

//    bool imported_vocabulary = false;

//    if(context_vocabulary_path.empty() || completion_vocabulary_path.empty())
//    {
//        cout << "Calculating vocabularies..." << endl;

//        const Index target_vocabulary_size = 8000;

//        vector<string> reserved_tokens = { "[PAD]", "[UNK]", "[START]", "[END]" };

//        context_vocabulary= create_vocabulary(context_tokens, target_vocabulary_size, reserved_tokens);
//        completion_vocabulary= create_vocabulary(completion_tokens, target_vocabulary_size, reserved_tokens);
//    }
//    else
//    {
//        cout << "Importing vocabularies..." << endl;

//        //imported_vocabulary = true;
//        import_vocabulary(context_vocabulary_path, context_vocabulary);
//        import_vocabulary(completion_vocabulary_path, completion_vocabulary);
//    }

//    const Index LIMIT = 126;

//    Index maximum_context_tokens = context_tokens[0].size();

//    for(Index i = 0; i < entry_number; i++)
//        if(context_tokens[i].size() > maximum_context_tokens)
//            maximum_context_tokens = context_tokens[i].size();

//    maximum_context_length = maximum_context_tokens > LIMIT ? LIMIT : maximum_context_tokens;

//    Index maximum_completion_tokens = completion_tokens[0].size();

//    for(Index i = 0; i < entry_number; i++)
//        if(completion_tokens[i].size() > maximum_completion_tokens)
//            maximum_completion_tokens = completion_tokens[i].size();

//    maximum_completion_length = maximum_completion_tokens > LIMIT + 1 ? LIMIT + 1 : maximum_completion_tokens;

//    // Output

//    cout << "Writing data file..." << endl;

//    string transformed_data_path = data_path;
//    replace(transformed_data_path,".txt","_data.txt");
//    replace(transformed_data_path,".csv","_data.csv");

//    ofstream file(transformed_data_path);

    // @todo maybe context does NOT need start and end tokens

//    for(Index i = type(0); i < maximum_context_length + 2; i++) // there is start and end indicators
//        file << "context_token_position_" << i << ";";

//    for(Index i = type(0); i < maximum_completion_length + 1; i++)
//        file << "input_token_position_" << i << ";";

//    for(Index i = type(0); i < maximum_completion_length; i++)
//        file << "target_token_position_" << i << ";";

//    file << "target_token_position_" << maximum_completion_length << "\n";

//    // Data file preview

//    Index preview_size = 4;

//    data_file_preview.resize(preview_size, 2);

//    for(Index i = 0; i < preview_size - 1; i++)
//    {
//        data_file_preview[i][0] = context[i];
//        data_file_preview[i][1] = completion[i];
//    }

//    data_file_preview(preview_size - 1, 0) = context(context.size()-1);
//    data_file_preview(preview_size - 1, 1) = completion(completion.size()-1);

//    //if(!imported_vocabulary) write_data_file_whitespace(file, context_tokens, completion_tokens);
//    //else
//    write_data_file_wordpiece(file, context_tokens, completion_tokens);

//    file.close();

//    data_path = transformed_data_path;
//    separator = Separator::Semicolon;
//    has_header = true;

//    read_csv_language_model();

//    set_raw_variable_types(RawVariableType::Numeric);
//    cout<<"Works properly"<…

-        vector<string> tokens = get_tokens(line, separator_string);
+        vector<string> tokens = get_tokens(line, separator_string);
 
         const Index tokens_number = tokens.size();
 
         if (tokens_number != 2)
             throw runtime_error("Tokens number must be two.");
 
-        input_tokens[sample_index] = get_tokens(tokens[0], " ");
-        target_tokens[sample_index] = get_tokens(tokens[0], " ");
+        to_lower(tokens);
+        split_punctuation(tokens);
+        delete_extra_spaces(tokens);
+        delete_non_printable_chars(tokens);
+        delete_non_alphanumeric(tokens);
 
-//        const vector<vector<string>> input_tokens = preprocess_language_documents(input);
-//        const vector<vector<string>> target_tokens = preprocess_language_documents(target);
+        input_tokens[sample_index] = get_tokens(tokens[0], " ");
+        target_tokens[sample_index] = get_tokens(tokens[1], " ");
 
         sample_index++;
     }
-
+
     cout << "Calculating vocabularies..." << endl;
 
-//    input_vocabulary = create_vocabulary(input_tokens);
-//    target_vocabulary = create_vocabulary(target_tokens);
+    input_vocabulary = create_vocabulary(input_tokens);
+    target_vocabulary = create_vocabulary(target_tokens);
 
 //    completion_vocabulary = {"[PAD]", "[UNK]", "[START]", "[END]", "Good", "Bad"};
 //    completion_vocabulary = {"[START]", "[END]", "Good", "Bad"};
-
+
     constexpr size_t LIMIT = 126;
 
     maximum_input_length = min(get_maximum_size(input_tokens), LIMIT);
-
     maximum_target_length = min(get_maximum_size(target_tokens), LIMIT + 1);
-/*
-    // Output
-
-    cout << "Writing data file..." << endl;
-
-    string transformed_data_path = data_path.string();
-
-    replace(transformed_data_path,".txt","_data.txt");
-    replace(transformed_data_path,".csv","_data.csv");
-
-    //if (!imported_vocabulary) write_data_file_whitespace(file, context_tokens, completion_tokens);
-    //else
+    // Output
+    /*
     write_data_file_wordpiece(file, input_tokens, target_tokens);
 
     file.close();
 
     data_path = transformed_data_path;
-
     separator = Separator::Semicolon;
 
     set_raw_variable_types(RawVariableType::Numeric);
@@ -1828,7 +1252,7 @@ void LanguageDataSet::read_txt()
 
     for (Index i = 0; i < maximum_target_length + 1; i++)
         set_raw_variable_use(i + maximum_input_length + maximum_target_length + 3, VariableUse::Target);
-
+*/
 }

diff --git a/opennn/language_data_set.h b/opennn/language_data_set.h
index 984f1b48c..608882395 100644
--- a/opennn/language_data_set.h
+++ b/opennn/language_data_set.h
@@ -21,7 +21,7 @@ class LanguageDataSet : public DataSet
 
     explicit LanguageDataSet(const dimensions& = {0}, const dimensions& = {0});
 
-    explicit LanguageDataSet(const filesystem::path& = filesystem::path());
+    explicit LanguageDataSet(const filesystem::path&);
 
     const vector<string>& get_input_vocabulary() const;
     const vector<string>& get_target_vocabulary() const;
@@ -53,8 +53,7 @@ class LanguageDataSet : public DataSet
 
     void import_lengths(const filesystem::path&, Index&, Index&);
 
     vector<string> create_vocabulary(const vector<vector<string>>& tokens,
-                                     const Index& vocabulary_size,
-                                     const vector<string>& reserved_tokens,
+                                     const Index& vocabulary_size = 123,
                                      const Index& upper_threshold = 10000000,
                                      const Index& lower_threshold = 10,
                                      const Index& iterations_number = 4,
@@ -65,16 +64,8 @@ class LanguageDataSet : public DataSet
                                      const bool& include_joiner_token = true,
                                      const string& joiner = "##");
 
-    void load_documents();
-
-    void read_csv_1();
-
-    void read_csv_2_simple();
 
     void read_csv_3_language_model();
 
-    void read_csv() override;
-
-    // Start here.
 
     void read_txt();

diff --git a/opennn/strings_utilities.cpp b/opennn/strings_utilities.cpp
index f0a63abf5..e805b6144 100644
--- a/opennn/strings_utilities.cpp
+++ b/opennn/strings_utilities.cpp
@@ -835,24 +835,6 @@ vector<vector<string>> get_tokens(const vector<string>& documents, const string&
 }
 
 
-vector<vector<string>> preprocess_language_documents(const vector<string>& documents)
-{
-    vector<string> documents_copy(documents);
-
-    to_lower(documents_copy);
-
-    split_punctuation(documents_copy);
-
-    delete_non_printable_chars(documents_copy);
-
-    delete_extra_spaces(documents_copy);
-
-    delete_non_alphanumeric(documents_copy);
-
-    return get_tokens(documents_copy, " ");
-}
-
-
 vector<pair<string, Index>> count_words(const vector<string>& words)
 {
     unordered_map<string, Index> count;
@@ -913,8 +895,6 @@ void split_punctuation(vector<string>& documents)
 
     for (const auto& [symbol, replacement] : punctuations)
         replace_substring(documents, symbol, replacement);
-
-    delete_extra_spaces(documents);
 }

diff --git a/opennn/strings_utilities.h b/opennn/strings_utilities.h
index b54606eb4..d4d718ad0 100644
--- a/opennn/strings_utilities.h
+++ b/opennn/strings_utilities.h
@@ -79,8 +79,6 @@ namespace opennn
     void delete_non_alphanumeric(vector<string>&);
 
     vector<vector<string>> get_tokens(const vector<string>&, const string&);
-    vector<vector<string>> preprocess_language_documents(const vector<string>&);
-
     vector<pair<string, Index>> count_words(const vector<string>&);
 
     enum Language {ENG, SPA};
@@ -89,20 +87,20 @@ namespace opennn
 
     void set_language(const string&);
 
-    void append_documents(const vector<string>&);
-
     // Preprocess
 
     void delete_extra_spaces(vector<string>&);
 
     void delete_non_printable_chars(vector<string>&);
+
+    void split_punctuation(string&);
 
     void split_punctuation(vector<string>&);
 
     void delete_emails(vector<vector<string>>&);
 
     void delete_non_alphanumeric(vector<string>&);
+
+    void print_tokens(const vector<vector<string>>&);
 }
 
 #endif // OPENNNSTRINGS_H

diff --git a/opennn/tensors.cpp b/opennn/tensors.cpp
index 1a5e7cc51..3045957ba 100644
--- a/opennn/tensors.cpp
+++ b/opennn/tensors.cpp
@@ -1428,6 +1428,13 @@ TensorMap<Tensor<type, 4>> tensor_map_4(const pair<type*, dimensions>& x_pair)
                         x_pair.second[3]);
 }
 
+
+void print_pairs(const vector<pair<string, Index>>& pairs)
+{
+    for (size_t i = 0; i < pairs.size(); i++)
+        cout << pairs[i].first << ": " << pairs[i].second << endl;
+}
+
 }
 
 // OpenNN: Open Neural Networks Library.

diff --git a/opennn/tensors.h b/opennn/tensors.h
index 31edb7335..d9958fc39 100644
--- a/opennn/tensors.h
+++ b/opennn/tensors.h
@@ -200,6 +200,9 @@ void print_vector(const vector<T>& vec)
     cout << "]\n";
 }
 
+void print_pairs(const vector<pair<string, Index>>&);
+
+
 template <typename T, int N>
 Tensor<Index, 1> get_dimensions(const Tensor<T, N>& tensor)
 {

diff --git a/opennn/transformer.cpp b/opennn/transformer.cpp
index 9f29ef7ed..1d0ebea48 100644
--- a/opennn/transformer.cpp
+++ b/opennn/transformer.cpp
@@ -271,9 +271,6 @@ void Transformer::set_context_vocabulary(const vector<string>& new_context_vocab
 //    type end_indicator = 3;
 //    //}
 
-//    // @todo
-//    const vector<vector<string>> context_tokens = preprocess_language_documents(tensor_wrapper(context_string));
-
 //    const Index batch_samples_number = 1;
 
 //    Tensor<type, 2> context(batch_samples_number, context_length);
@@ -333,6 +330,7 @@ void Transformer::set_context_vocabulary(const vector<string>& new_context_vocab
 
 string Transformer::calculate_outputs(const vector<string>& context_documents)
 {
+/*
     //type start_indicator = 1;
     //type end_indicator = 2;
 
@@ -342,10 +340,6 @@ string Transformer::calculate_outputs(const vector<string>& context_documents)
     type end_indicator = 3;
     //}
 
-    // @todo
-
-    const vector<vector<string>> context_tokens = preprocess_language_documents(context_documents);
-
     const Index samples_number = 1;
 
     Tensor<type, 2> context(samples_number, context_length);
@@ -384,9 +378,9 @@ string Transformer::calculate_outputs(const vector<string>& context_documents)
     {
         forward_propagate(input_pairs, forward_propagation, false);
 
-        current_outputs/*.device(*thread_pool_device)*/ = outputs.chip(i - 1, 0);
+        current_outputs.device(*thread_pool_device) = outputs.chip(i - 1, 0);
 
-        prediction/*.device(*thread_pool_device)*/ = current_outputs.argmax();
+        prediction.device(*thread_pool_device) = current_outputs.argmax();
 
         input(i) = type(prediction(0));
 
@@ -403,6 +397,8 @@ string Transformer::calculate_outputs(const vector<string>& context_documents)
     detokenize_wordpiece(input, output_buffer);
 
     return output_buffer.str();
+*/
+
+    return string();
 }
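
Note (not part of the patch): a minimal sketch of how the cleaned-up API is meant to
be driven, mirroring the new examples/translation/main.cpp. It uses only members
visible in this diff; "C:/translation.csv" is a placeholder path, and it is an
assumption here that constructing a LanguageDataSet from a path runs read_txt(),
which tokenizes the two-column file and builds both vocabularies via
create_vocabulary().

    #include <iostream>

    #include "language_data_set.h"

    using namespace opennn;

    int main()
    {
        // Two-column text file: one source/target sentence pair per line.
        // read_txt() lower-cases each line, splits punctuation, strips
        // non-printable and non-alphanumeric characters, then builds the
        // input and target vocabularies (vocabulary_size defaults to 123
        // per the new header).
        LanguageDataSet language_data_set("C:/translation.csv");

        std::cout << "Input vocabulary size: "
                  << language_data_set.get_input_vocabulary().size() << std::endl;

        std::cout << "Target vocabulary size: "
                  << language_data_set.get_target_vocabulary().size() << std::endl;

        return 0;
    }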