diff --git a/mlrose/fitness.py b/mlrose/fitness.py index 94a8ffc2..250d43c1 100644 --- a/mlrose/fitness.py +++ b/mlrose/fitness.py @@ -904,9 +904,14 @@ class MaxKColor: def __init__(self, edges): # Remove any duplicates from list - edges = list({tuple(sorted(edge)) for edge in edges}) - - self.edges = edges + self.neighbors = {} + for v1, v2 in edges: + if v1 not in self.neighbors: + self.neighbors[v1] = set() + if v2 not in self.neighbors: + self.neighbors[v2] = set() + self.neighbors[v1].add(v2) + self.neighbors[v2].add(v1) self.prob_type = 'discrete' def evaluate(self, state): @@ -925,9 +930,9 @@ def evaluate(self, state): fitness = 0 - for i in range(len(self.edges)): - # Check for adjacent nodes of the same color - if state[self.edges[i][0]] == state[self.edges[i][1]]: + for v, color in enumerate(state): + # Check that all adjacent nodes are different color + if all(state[neighbor] != color for neighbor in self.neighbors.get(v, set())): fitness += 1 return fitness