Skip to content

Commit

Permalink
updated MCTSAgentBatch
Browse files Browse the repository at this point in the history
* removed <bits/stdint-uintn.h> include (to allow building on Windows
and Mac)
* small refactoring
  • Loading branch information
QueensGambit committed Aug 22, 2021
1 parent 3c5bdb7 commit 5207d88
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 23 deletions.
30 changes: 10 additions & 20 deletions engine/src/agents/mctsagentbatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
*
*/

#include <bits/stdint-uintn.h>
#include <string>
#include <thread>
#include <fstream>
Expand Down Expand Up @@ -60,13 +59,11 @@ string MCTSAgentBatch::get_name() const
if(splitNodes){
ret = "MCTSBatch-Split-" + std::to_string(numberOfAgents) + "-" + engineVersion + "-" + net->get_model_name();
}

return ret;
}

void MCTSAgentBatch::evaluate_board_state()
{

vector<EvalInfo> evals;
evalInfo->isChess960 = state->is_chess960();

Expand Down Expand Up @@ -154,60 +151,53 @@ void MCTSAgentBatch::evaluate_board_state()
eval.nodes = rootNode->get_node_count();
eval.tbHits = tbHits;



evals.push_back(eval);
tGCThread.join();
}

evalInfo->nodesPreSearch = init_root_node(state);
evalInfo->legalMoves = rootNode->get_legal_actions();


auto combinedPolicy = evals[0].policyProbSmall;
auto combinedChildVisits = evals[0].childNumberVisits;
auto combinedQValues = evals[0].qValues;

for (size_t i = 1; i < numberOfAgents; i++)
{
for(auto j = 0; j< combinedPolicy.size();++j){
for(auto j = 0; j < combinedPolicy.size(); ++j){
combinedPolicy[j] += evals[i].policyProbSmall[j];
}
for(auto j = 0; j< combinedChildVisits.size();++j){
for(auto j = 0; j < combinedChildVisits.size(); ++j){
combinedChildVisits[j] += evals[i].childNumberVisits[j];
}
for(auto j = 0; j< combinedQValues.size();++j){
for(auto j = 0; j < combinedQValues.size(); ++j){
combinedQValues[j] += evals[i].qValues[j];
}
}

for(auto j = 0; j< combinedPolicy.size();++j){
for(auto j = 0; j < combinedPolicy.size(); ++j){
combinedPolicy[j] += combinedPolicy[j]/numberOfAgents;
}
for(auto j = 0; j< combinedChildVisits.size();++j){
for(auto j = 0; j < combinedChildVisits.size(); ++j){
combinedChildVisits[j] += combinedChildVisits[j]/numberOfAgents;
}
for(auto j = 0; j< combinedQValues.size();++j){
for(auto j = 0; j < combinedQValues.size(); ++j){
combinedQValues[j] += combinedQValues[j]/numberOfAgents;
}

vector<float> diffs;
for (size_t i = 0; i < numberOfAgents; i++)
{
diffs.push_back(0.0);
for(auto j = 0; j< combinedPolicy.size();++j){
for(auto j = 0; j< combinedPolicy.size(); ++j){
diffs[i] += std::sqrt(std::pow(evals[i].policyProbSmall[j] - combinedPolicy[j],2));
}

}
std::vector<float>::iterator result = std::min_element(diffs.begin(), diffs.end());
int a = std::distance(diffs.begin(), result);

int stateIdx = std::distance(diffs.begin(), result);

*evalInfo = evals[a];
*evalInfo = evals[stateIdx];
update_nps_measurement(evalInfo->calculate_nps());

info_string("Selected State: " + std::to_string(a));


info_string("Selected State: " + std::to_string(stateIdx));
}
3 changes: 0 additions & 3 deletions engine/src/agents/mctsagentbatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ using namespace crazyara;
class MCTSAgentBatch : public MCTSAgent
{
public:

// how many trees should be generated
int numberOfAgents;
// boolean, deciding if the given nodes are played per tree or are split between the trees
Expand All @@ -66,8 +65,6 @@ class MCTSAgentBatch : public MCTSAgent

string get_name() const override;
void evaluate_board_state() override;


};


Expand Down

0 comments on commit 5207d88

Please sign in to comment.