Skip to content

Commit

Permalink
Make texel tuner multithreaded + re-tune PSTs and piece values
Browse files Browse the repository at this point in the history
  • Loading branch information
kelseyde committed May 11, 2024
1 parent b45fcf3 commit da2ec36
Show file tree
Hide file tree
Showing 3 changed files with 2,128 additions and 353 deletions.
55 changes: 10 additions & 45 deletions src/main/java/com/kelseyde/calvin/tuning/texel/TexelTuner.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,23 @@ public TexelTuner(String fileName) {
private final String positionsFileName;

private Map<Board, Double> positions;
private List<Map<Board, Double>> partitions;

private Delta[] deltas;

private double k = 1.26;

private int threadCount = 5;
private int threadCount = 10;

public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfigFunction)
throws IOException, ExecutionException, InterruptedException {

positions = loadPositions();
partitions = partitionPositions(positions, threadCount);
initDeltas(initialParams.length);
System.out.println("number of positions: " + positions.size());
int[] bestParams = initialParams;
double bestError = meanSquareErrorMultithreaded(bestParams, createConfigFunction);
double bestError = meanSquareError(bestParams, createConfigFunction);
int iterations = 0;

boolean improved = true;
Expand All @@ -56,7 +58,7 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi
int[] newParams = Arrays.copyOf(bestParams, bestParams.length);
int delta = deltas[i].delta;
newParams[i] += delta;
double newError = meanSquareErrorMultithreaded(newParams, createConfigFunction);
double newError = meanSquareError(newParams, createConfigFunction);
System.out.printf("tuning param %s of %s, error %s%n", i, bestParams.length, newError);

if (newError < bestError) {
Expand All @@ -71,7 +73,7 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi

} else {
newParams[i] -= delta * 2;
newError = meanSquareErrorMultithreaded(newParams, createConfigFunction);
newError = meanSquareError(newParams, createConfigFunction);
if (newError < bestError) {
improved = true;
modifiedParams++;
Expand All @@ -97,22 +99,15 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi

}

public double meanSquareErrorMultithreaded(int[] params, Function<int[], EngineConfig> createConfigFunction)
public double meanSquareError(int[] params, Function<int[], EngineConfig> createConfigFunction)
throws ExecutionException, InterruptedException {

int totalPositions = positions.size();
List<Map<Board, Double>> partitions = partitionPositions(positions, threadCount);

List<CompletableFuture<Double>> threads = partitions.stream()
.map(partition -> CompletableFuture.supplyAsync(() -> totalError(partition, params, createConfigFunction)))
.toList();

CompletableFuture<List<Double>> combined = CompletableFuture.allOf(threads.toArray(CompletableFuture[]::new))
.thenApply(future -> threads.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList()));

return combined.get().stream().reduce(Double::sum).orElse(0.0) / totalPositions;
CompletableFuture.allOf(threads.toArray(CompletableFuture[]::new)).get();
Double combinedResult = threads.stream().map(CompletableFuture::join).reduce(0.0, Double::sum);
return combinedResult / totalPositions;
}

public double totalError(Map<Board, Double> partitionedPositions, int[] params, Function<int[], EngineConfig> createConfigFunction) {
Expand Down Expand Up @@ -174,36 +169,6 @@ public List<String> loadFens() throws IOException {
return Files.readAllLines(path);
}

// public double tuneScalingConstant(Evaluator evaluator) throws IOException {
// List<String> positions = loadFens();
// System.out.println("number of positions: " + positions.size());
// double bestError = meanSquareError(evaluator, positions);
//
// boolean improved = true;
// while (improved) {
// improved = false;
// k += 0.01;
// System.out.println("1 new k = " + k);
// double newError = meanSquareError(evaluator, positions);
// if (newError < bestError) {
// improved = true;
// bestError = newError;
// System.out.println("improved k " + k + " " + bestError);
// } else {
// k -= 0.02;
// System.out.println("2 new k = " + k);
// newError = meanSquareError(evaluator, positions);
// if (newError < bestError) {
// improved = true;
// bestError = newError;
// System.out.println("improved k " + k + " " + bestError);
// }
// }
// }
// System.out.println("final k " + k);
// return k;
// }
//
private Map<Board, Double> loadPositions() throws IOException {
List<String> fens = loadFens();
Map<Board, Double> positions = new HashMap<>();
Expand Down
Loading

0 comments on commit da2ec36

Please # to comment.