Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input Representation Version Parsing #148

Merged
merged 3 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ option(BACKEND_MXNET "Build with MXNet backend (Blas/IntelMKL/CUDA/T
option(BACKEND_TORCH "Build with Torch backend (CPU/GPU) support" OFF)
option(USE_960 "Build with 960 variant support" OFF)
option(BUILD_TESTS "Build and run tests" OFF)
option(USE_DYNAMIC_NN_ARCH "Build with dynamic neural network architecture support" OFF)
option(USE_DYNAMIC_NN_ARCH "Build with dynamic neural network architecture support" ON)
# enable a single mode for different model input / outputs
option(MODE_CRAZYHOUSE "Build with crazyhouse only support" ON)
option(MODE_CHESS "Build with chess + chess960 only support" OFF)
Expand Down
38 changes: 38 additions & 0 deletions engine/src/nn/neuralnetapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include "neuralnetapi.h"
#include <string>
#include <regex>
#include "../stateobj.h"


Expand Down Expand Up @@ -69,6 +70,8 @@ void NeuralNetAPI::initialize_nn_design()
nbNNInputValues = nnDesign.inputShape.flatten() / batchSize;
nbNNAuxiliaryOutputs = nnDesign.auxiliaryOutputShape.flatten() / batchSize;
policyOutputLength = nnDesign.policyOutputShape.v[1] * batchSize;
version = read_version_from_string(modelName);
info_string("Input representation: ", version_to_string(version));
}

void NeuralNetAPI::initialize()
Expand Down Expand Up @@ -164,3 +167,38 @@ ostream& nn_api::operator<<(ostream &os, const nn_api::Shape &shape)
os << ")";
return os;
}

/**
 * @brief read_version_from_string Extracts the version encoded in a model file name.
 * Looks for the pattern "-v<major>.<minor>" (e.g. "model-v1.2.onnx") and returns
 * make_version(<major>, <minor>, 0). If no such pattern is found or the numbers
 * cannot be parsed, make_version<0,0,0>() is returned.
 * @param modelFileName Model file name to scan
 * @return Parsed version, or make_version<0,0,0>() on failure
 */
Version read_version_from_string(const string &modelFileName)
{
    // Pattern to detect "-v<major>.<minor>". The dot is escaped so only a
    // literal '.' separates major and minor (the original unescaped '.'
    // would have accepted any character, e.g. "-v1x2").
    // Capture group 1: major version, capture group 2: minor version.
    const regex versionRegex("-v([0-9]+)\\.([0-9]+)");

    smatch matches;
    if (regex_search(modelFileName, matches, versionRegex)) {
        try {
            // std::stoi may throw (e.g. out_of_range on absurdly long digit runs);
            // treat any failure as "no version information".
            return make_version(std::stoi(matches[1]), std::stoi(matches[2]), 0);
        } catch (const exception& e) {  // catch by const reference to avoid slicing
            info_string(e.what());
        }
    }
    // unsuccessful: no parsable "-v<major>.<minor>" tag in the file name
    return make_version<0,0,0>();
}
11 changes: 11 additions & 0 deletions engine/src/nn/neuralnetapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ vector<string> get_items_by_elment(const vector<string>& stringVector, const str
*/
string get_file_ending_with(const string& dir, const string& suffix);

/**
* @brief read_version_from_string Returns the corresponding version for a given model file name.
* The version identifier is expected to come after the substring "-v" in the format "-v<Major>.<Minor>", e.g. "-v1.2.onnx".
* If the information is missing or parsing failed, make_version<0,0,0>() will be returned.
* Versioning patch information is always set to 0.
* The version information is used to decide between different input representations for the neural network.
* @param modelFileName
* @return Version information
*/
Version read_version_from_string(const string& modelFileName);


template <typename T>
/**
Expand Down