Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kelseyde committed May 11, 2024
1 parent 77ae540 commit b45fcf3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 23 deletions.
63 changes: 47 additions & 16 deletions src/main/java/com/kelseyde/calvin/tuning/texel/TexelTuner.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -35,14 +34,16 @@ public TexelTuner(String fileName) {

private double k = 1.26;

public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfigFunction) throws IOException {
private int threadCount = 5;

public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfigFunction)
throws IOException, ExecutionException, InterruptedException {

positions = loadPositions();
initDeltas(initialParams.length);
System.out.println("number of positions: " + positions.size());
Evaluator evaluator = new Evaluator(createConfigFunction.apply(initialParams));
int[] bestParams = initialParams;
double bestError = meanSquareError(evaluator);
double bestError = meanSquareErrorMultithreaded(bestParams, createConfigFunction);
int iterations = 0;

boolean improved = true;
Expand All @@ -55,8 +56,7 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi
int[] newParams = Arrays.copyOf(bestParams, bestParams.length);
int delta = deltas[i].delta;
newParams[i] += delta;
evaluator = new Evaluator(createConfigFunction.apply(newParams));
double newError = meanSquareError(evaluator);
double newError = meanSquareErrorMultithreaded(newParams, createConfigFunction);
System.out.printf("tuning param %s of %s, error %s%n", i, bestParams.length, newError);

if (newError < bestError) {
Expand All @@ -71,8 +71,7 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi

} else {
newParams[i] -= delta * 2;
evaluator = new Evaluator(createConfigFunction.apply(newParams));
newError = meanSquareError(evaluator);
newError = meanSquareErrorMultithreaded(newParams, createConfigFunction);
if (newError < bestError) {
improved = true;
modifiedParams++;
Expand All @@ -98,11 +97,28 @@ public int[] tune(int[] initialParams, Function<int[], EngineConfig> createConfi

}

public double meanSquareError(Evaluator evaluator) throws IOException {
/**
 * Computes the mean error over all loaded positions for the given parameter set,
 * evaluating the positions in parallel.
 *
 * <p>The positions are split into {@code threadCount} partitions; the total error of
 * each partition is computed asynchronously with {@link CompletableFuture#supplyAsync}
 * (no explicit executor is supplied). Once every partition has finished, the partial
 * totals are added together and divided by the overall number of positions.
 *
 * @param params               the parameter values to evaluate
 * @param createConfigFunction builds the engine configuration for a parameter array
 * @return the summed error of all partitions divided by {@code positions.size()}
 * @throws ExecutionException   if any partition task completes exceptionally
 * @throws InterruptedException if the calling thread is interrupted while waiting
 */
public double meanSquareErrorMultithreaded(int[] params, Function<int[], EngineConfig> createConfigFunction)
        throws ExecutionException, InterruptedException {

    int positionCount = positions.size();

    // Kick off one asynchronous task per partition.
    List<CompletableFuture<Double>> tasks = new ArrayList<>();
    for (Map<Board, Double> chunk : partitionPositions(positions, threadCount)) {
        tasks.add(CompletableFuture.supplyAsync(() -> totalError(chunk, params, createConfigFunction)));
    }

    // Block until every task is done; a failed task surfaces here as an ExecutionException.
    CompletableFuture.allOf(tasks.toArray(CompletableFuture[]::new)).get();

    // All tasks have completed successfully at this point, so join() returns immediately.
    double errorSum = 0.0;
    for (CompletableFuture<Double> task : tasks) {
        errorSum += task.join();
    }
    return errorSum / positionCount;
}

int numberOfPositions = positions.size();
public double totalError(Map<Board, Double> partitionedPositions, int[] params, Function<int[], EngineConfig> createConfigFunction) {
Evaluator evaluator = new Evaluator(createConfigFunction.apply(params));
double totalError = 0.0;
for (Map.Entry<Board, Double> entry : positions.entrySet()) {
for (Map.Entry<Board, Double> entry : partitionedPositions.entrySet()) {
Board board = entry.getKey();
int eval = evaluator.evaluate(board);
if (!board.isWhiteToMove()) eval = -eval;
Expand All @@ -111,8 +127,7 @@ public double meanSquareError(Evaluator evaluator) throws IOException {
double error = error(prediction, actual);
totalError += error;
}
return totalError / numberOfPositions;

return totalError;
}

/**
Expand All @@ -137,6 +152,22 @@ private double result(String position) {
};
}

/**
 * Splits the given positions into {@code partitions} disjoint maps whose sizes differ
 * by at most one entry.
 *
 * <p>Every entry of {@code positions} ends up in exactly one partition. When the number
 * of positions is not a multiple of {@code partitions}, the leftover entries are handed
 * out one each to the leading partitions. When there are fewer positions than partitions,
 * the trailing partitions are empty.
 *
 * @param positions  the positions (board mapped to its expected result) to split
 * @param partitions the number of partitions to produce; must be at least 1
 * @return a list of exactly {@code partitions} maps that together contain every entry
 *         of {@code positions}
 * @throws IllegalArgumentException if {@code partitions} is less than 1
 */
private List<Map<Board, Double>> partitionPositions(Map<Board, Double> positions, int partitions) {
    if (partitions < 1) {
        throw new IllegalArgumentException("partitions must be at least 1, but was " + partitions);
    }

    List<Map.Entry<Board, Double>> positionEntries = new ArrayList<>(positions.entrySet());
    int total = positionEntries.size();
    int baseSize = total / partitions;
    // Integer division truncates: without handing out the remainder, the last
    // (total % partitions) positions would never be evaluated even though the mean
    // error is computed over positions.size().
    int remainder = total % partitions;

    List<Map<Board, Double>> partitionedPositions = new ArrayList<>(partitions);
    int startIndex = 0;
    for (int i = 0; i < partitions; i++) {
        int endIndex = startIndex + baseSize + (i < remainder ? 1 : 0);
        Map<Board, Double> partition = new HashMap<>();
        for (Map.Entry<Board, Double> entry : positionEntries.subList(startIndex, endIndex)) {
            partition.put(entry.getKey(), entry.getValue());
        }
        partitionedPositions.add(partition);
        startIndex = endIndex;
    }
    return partitionedPositions;
}

public List<String> loadFens() throws IOException {
String fileName = String.format("src/test/resources/texel/" + positionsFileName);
Path path = Paths.get(fileName);
Expand Down
34 changes: 27 additions & 7 deletions src/test/java/com/kelseyde/calvin/tuning/texel/TexelTunerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.List;


@Disabled
public class TexelTunerTest {
Expand All @@ -18,9 +22,25 @@ public class TexelTunerTest {
private final ObjectMapper objectMapper = new ObjectMapper();

@Test
public void tunePieceValuesAndPSTs() throws IOException {
public void tunePieceValuesAndPSTs() throws IOException, ExecutionException, InterruptedException {
List<Integer> weights = new ArrayList<>();
EngineConfig initialConfig = EngineInitializer.loadDefaultConfig();
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[0]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[1]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[2]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[3]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[4]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getMiddlegameTables()[5]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[0]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[1]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[2]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[3]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[4]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getEndgameTables()[5]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getPieceValues()[0]).boxed().toList());
weights.addAll(Arrays.stream(initialConfig.getPieceValues()[1]).boxed().toList());
tune(
new int[] { 0, 0, 0, 0, 0, 0, 0, 0, 91, 127, 54, 88, 61, 119, 27, -18, -13, 0, 19, 24, 58, 63, 18, -27, -13, 10, 5, 14, 24, 19, 24, -20, -34, -9, -5, 9, 10, 13, 3, -32, -29, -11, -3, -9, 10, -4, 33, -12, -32, -1, -21, -16, -8, 31, 38, -15, 0, 0, 0, 0, 0, 0, 0, 0, -174, -89, -41, -42, 61, -96, -22, -100, -70, -34, 79, 29, 30, 69, 14, -10, -40, 53, 37, 72, 91, 136, 80, 44, -2, 24, 19, 60, 40, 69, 15, 29, -6, 11, 23, 20, 35, 26, 28, -9, -21, -6, 15, 17, 26, 17, 32, -15, -22, -46, -5, 4, 6, 20, -7, -12, -98, -19, -51, -26, -24, -21, -20, -30, -22, -3, -75, -44, -32, -35, 6, -1, -26, 9, -21, -13, 23, 66, 11, -40, -23, 30, 50, 33, 42, 50, 44, 5, -11, -2, 12, 43, 32, 30, 0, -9, -9, 6, 6, 23, 37, 5, 7, -3, 7, 22, 8, 8, 11, 20, 15, 10, -3, 22, 9, -7, 0, 21, 40, 8, -26, -10, -7, -21, -20, -15, -36, -18, 39, 49, 35, 51, 61, 16, 38, 50, 25, 25, 51, 62, 73, 74, 33, 51, 2, 19, 26, 35, 24, 52, 68, 23, -24, -11, 4, 26, 17, 28, -1, -13, -31, -25, -19, -8, 2, -5, 13, -24, -45, -24, -23, -17, -4, 0, -5, -26, -37, -17, -27, -16, -6, 4, 1, -70, -12, -12, -6, 10, 9, 10, -44, -19, -21, -1, 36, 19, 61, 51, 38, 52, -17, -32, -12, -6, -23, 64, 35, 61, -20, -19, 14, 15, 36, 63, 54, 64, -34, -20, -13, -9, 6, 20, -7, 8, -16, -26, -8, -10, 5, 3, 10, -4, -21, 1, -11, -2, -2, 4, 21, -2, -28, -1, 8, 5, 11, 16, 4, 2, -8, -19, -9, 3, -18, -32, -38, -43, -63, 24, 17, -13, -49, -34, 5, 8, 29, -1, -19, -7, -7, 3, -37, -36, -16, 24, 3, -16, -20, 6, 22, -27, -24, -27, -19, -28, -37, -25, -14, -43, -52, -2, -28, -46, -45, -49, -40, -58, -19, -13, -29, -53, -49, -31, -15, -34, 8, 0, -15, -57, -36, -23, 10, 1, -22, 29, 13, -47, 9, -31, 31, 7, 0, 0, 0, 0, 0, 0, 0, 0, 171, 166, 151, 127, 140, 125, 158, 180, 87, 93, 78, 60, 49, 46, 75, 77, 32, 17, 6, -2, -9, -3, 14, 10, 16, 6, -4, -14, -6, -1, 0, 0, 4, 0, 1, 1, 3, 2, -8, -7, 16, 1, 15, 9, 20, 7, -5, -4, 0, 0, 0, 0, 0, 0, 0, 0, -56, -45, -20, -21, -38, -27, -63, -97, -24, -15, -22, -9, -6, -24, -31, -49, -24, -27, 3, 2, -8, -16, -26, -40, -10, 10, 15, 29, 17, 
11, 8, -18, -11, -6, 15, 25, 23, 10, -3, -15, -20, -4, -6, 14, 9, -5, -17, -29, -35, -13, -17, -2, -2, -17, -16, -37, -29, -51, -20, -8, -25, -11, -50, -63, -7, -28, -4, -7, -14, -9, -17, -17, -15, -11, 0, -15, -10, -20, -11, -14, 1, -15, 0, -8, -3, 7, -3, 7, -3, 2, 5, 2, 7, 3, -4, -5, -13, -4, 6, 12, 0, 3, -10, -16, -12, 0, 5, 9, 13, -4, -14, -20, -14, -19, -14, 0, 5, -10, -16, -27, -16, -10, -23, -5, -6, -23, -12, -14, 20, 13, 13, 8, 13, 19, 15, 12, 14, 20, 12, 8, -10, 10, 15, 10, 14, 14, 7, 3, 11, 4, 2, 4, 11, 10, 12, 1, 2, 8, 6, 9, 10, 7, 9, 4, 2, 1, -1, -4, 3, 2, -3, 2, -7, -10, -8, -9, 1, 1, 7, 1, -6, -6, -4, 4, -2, 9, 3, -8, -8, -6, 11, -13, -2, 29, 29, 34, 30, 26, 10, 27, -10, 20, 25, 43, 65, 32, 37, -2, -27, 3, 16, 42, 54, 42, 26, 9, -4, 25, 25, 52, 64, 47, 57, 43, -15, 35, 19, 47, 38, 34, 39, 20, -23, -34, 22, 6, 16, 14, 17, 5, -15, -16, -33, -9, -9, -30, -29, -30, -32, -21, -22, -44, -5, -37, -27, -34, -73, -42, -25, -25, -18, 8, -3, -24, -19, 10, 7, 10, 10, 38, 23, 4, 3, 10, 16, 12, 20, 38, 37, 13, -15, 15, 17, 20, 26, 31, 23, 0, -25, -11, 14, 25, 29, 30, 9, -10, -26, -6, 14, 28, 30, 23, 10, -8, -29, -4, 7, 20, 21, 11, 2, -20, -50, -34, -14, -4, -28, -12, -24, -50, 82, 344, 365, 484, 1032, 0, 87, 280, 290, 519, 936, 0 },
weights.stream().mapToInt(i -> i).toArray(),
(params) -> {
if (params.length != 780) return null;
EngineConfig config = EngineInitializer.loadDefaultConfig();
Expand All @@ -44,7 +64,7 @@ public void tunePieceValuesAndPSTs() throws IOException {
}

@Test
public void tuneMobilityWeights() throws IOException {
public void tuneMobilityWeights() throws IOException, ExecutionException, InterruptedException {
tune(
new int[] { -18, -14, -8, -4, 0, 4, 8, 12, 16, -26, -21, -16, -12, -8, -4, 0, 4, 8, 12, 16, 16, 16, 16, -14, -12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, 12, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, -18, -14, -8, -4, 0, 4, 8, 12, 16, -26, -21, -16, -12, -8, -4, 0, 4, 8, 12, 16, 16, 16, 16, -14, -12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, 12, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12 },
(params) -> {
Expand All @@ -63,7 +83,7 @@ public void tuneMobilityWeights() throws IOException {
}

@Test
public void tuneBishopAndRookWeights() throws IOException {
public void tuneBishopAndRookWeights() throws IOException, ExecutionException, InterruptedException {
tune(
new int[] { 50, 42, 22, 18, 16 },
(params) -> {
Expand All @@ -79,7 +99,7 @@ public void tuneBishopAndRookWeights() throws IOException {
}

@Test
public void tuneKingSafetyWeights() throws IOException {
public void tuneKingSafetyWeights() throws IOException, ExecutionException, InterruptedException {
tune(
new int[] { 0, 0, 10, 25, 50, 50, 50, 15, 10, 25, 15, 120 },
(params) -> {
Expand All @@ -102,7 +122,7 @@ public void tuneKingSafetyWeights() throws IOException {
}

@Test
public void tunePieceValues() throws IOException {
public void tunePieceValues() throws IOException, ExecutionException, InterruptedException {
tune(
new int[] { 92, 393, 400, 544, 1119, 0, 78, 254, 280, 535, 1072, 0 },
(params) -> {
Expand All @@ -115,7 +135,7 @@ public void tunePieceValues() throws IOException {

}

private void tune(int[] initialParams, Function<int[], EngineConfig> createConfigFunction) throws IOException {
private void tune(int[] initialParams, Function<int[], EngineConfig> createConfigFunction) throws IOException, ExecutionException, InterruptedException {
EngineConfig initialConfig = createConfigFunction.apply(initialParams);
System.out.println("Initial config: " + objectMapper.writeValueAsString(initialConfig));
int[] bestParams = tuner.tune(initialParams, createConfigFunction);
Expand Down

0 comments on commit b45fcf3

Please sign in to comment.