From 1e40981b27f6c8a3384d19cb0a8a338d5063a70c Mon Sep 17 00:00:00 2001
From: Rogier van der Geer <rogiervandergeer@godatadriven.com>
Date: Fri, 15 May 2020 17:08:19 +0200
Subject: [PATCH] Change the signature of pick_random

Parent pickers are no longer passed any kwargs. The pick_random must
now be initialised before use, the number of parents passed to it
upon initialization. In addition, pickers must always return a sequence
of picked parents - even if it is only one.

These changes make it much easier to implement more complex picking
algorithms, and in addition they remove the requirement for the
select_arguments() decorator, which hurts my eyes.
---
 evol/helpers/pickers.py         | 19 ++++++++++++-------
 evol/population.py              |  6 +++---
 evol/utils.py                   | 11 ++++-------
 examples/number_of_parents.py   |  4 ++--
 examples/population_demo.py     |  2 +-
 examples/rock_paper_scissors.py |  2 +-
 examples/travelling_salesman.py |  2 +-
 tests/conftest.py               |  2 +-
 tests/test_evolution.py         |  2 +-
 tests/test_logging.py           | 20 ++++++++------------
 tests/test_population.py        | 13 ++++++-------
 tests/test_utils.py             | 12 ++++++------
 12 files changed, 46 insertions(+), 49 deletions(-)

diff --git a/evol/helpers/pickers.py b/evol/helpers/pickers.py
index 14bce74..53ef20a 100644
--- a/evol/helpers/pickers.py
+++ b/evol/helpers/pickers.py
@@ -1,14 +1,19 @@
-from typing import Sequence, Tuple
-
 from random import choice
+from typing import Callable, Sequence, Tuple
 
 from evol import Individual
 
 
-def pick_random(parents: Sequence[Individual], n_parents: int = 2) -> Tuple:
-    """Randomly selects parents with replacement
+def pick_random(n_parents: int = 2) -> Callable[[Sequence[Individual]], Tuple[Individual, ...]]:
+    """Returns a parent-picker that randomly samples parents with replacement.
+
+    Typical usage:
+        Evolution().breed(parent_picker=pick_random(n_parents=2), combiner=some_combiner)
 
-    Accepted arguments:
-      n_parents: Number of parents to select. Defaults to 2.
+    :param n_parents: The number of parents the picker should return.
+    :return: Callable
     """
-    return tuple(choice(parents) for _ in range(n_parents))
+    def picker(parents: Sequence[Individual]) -> Tuple[Individual, ...]:
+        return tuple(choice(parents) for _ in range(n_parents))
+
+    return picker
diff --git a/evol/population.py b/evol/population.py
index cf81419..b7f2784 100644
--- a/evol/population.py
+++ b/evol/population.py
@@ -18,8 +18,8 @@
 from evol.conditions import Condition
 from evol.exceptions import StopEvolution
 from evol.helpers.groups import group_random
-from evol.utils import offspring_generator, select_arguments
 from evol.serialization import SimpleSerializer
+from evol.utils import offspring_generator, select_arguments
 
 if TYPE_CHECKING:
     from .evolution import Evolution
@@ -160,7 +160,7 @@ def evaluate(self, lazy: bool = False) -> 'BasePopulation':
         pass
 
     def breed(self,
-              parent_picker: Callable[..., Sequence[Individual]],
+              parent_picker: Callable[[Sequence[Individual]], Sequence[Individual]],
               combiner: Callable,
               population_size: Optional[int] = None,
               **kwargs) -> 'BasePopulation':
@@ -182,7 +182,7 @@ def breed(self,
         if population_size:
             self.intended_size = population_size
         offspring = offspring_generator(parents=self.individuals,
-                                        parent_picker=select_arguments(parent_picker),
+                                        parent_picker=parent_picker,
                                         combiner=select_arguments(combiner),
                                         **kwargs)
         self.individuals += list(islice(offspring, self.intended_size - len(self.individuals)))
diff --git a/evol/utils.py b/evol/utils.py
index 08d8759..1d4fd8a 100644
--- a/evol/utils.py
+++ b/evol/utils.py
@@ -1,11 +1,11 @@
 from inspect import signature
-from typing import List, Callable, Union, Sequence, Any, Generator
+from typing import List, Callable, Sequence, Any, Generator
 
 from evol import Individual
 
 
 def offspring_generator(parents: List[Individual],
-                        parent_picker: Callable[..., Union[Individual, Sequence]],
+                        parent_picker: Callable[[Sequence[Individual]], Sequence[Individual]],
                         combiner: Callable[..., Any],
                         **kwargs) -> Generator[Individual, None, None]:
     """Generator for offspring.
@@ -25,11 +25,8 @@ def offspring_generator(parents: List[Individual],
     """
     while True:
         # Obtain parent chromosomes
-        selected_parents = parent_picker(parents, **kwargs)
-        if isinstance(selected_parents, Individual):
-            chromosomes = (selected_parents.chromosome,)
-        else:
-            chromosomes = tuple(individual.chromosome for individual in selected_parents)
+        selected_parents = parent_picker(parents)
+        chromosomes = tuple(individual.chromosome for individual in selected_parents)
         # Create children
         combined = combiner(*chromosomes, **kwargs)
         if isinstance(combined, Generator):
diff --git a/examples/number_of_parents.py b/examples/number_of_parents.py
index 1ad6ab4..18f930e 100644
--- a/examples/number_of_parents.py
+++ b/examples/number_of_parents.py
@@ -22,7 +22,7 @@ def init_func():
     def eval_func(x, opt_value=opt_value):
         return -((x - opt_value) ** 2) + math.cos(x - opt_value)
 
-    def random_parent_picker(pop, n_parents):
+    def random_parent_picker(pop):
         return [random.choice(pop) for i in range(n_parents)]
 
     def mean_parents(*parents):
@@ -36,7 +36,7 @@ def add_noise(chromosome, sigma):
 
     evo = (Evolution()
            .survive(fraction=survival)
-           .breed(parent_picker=random_parent_picker, combiner=mean_parents, n_parents=n_parents)
+           .breed(parent_picker=random_parent_picker, combiner=mean_parents)
            .mutate(mutate_function=add_noise, sigma=noise)
            .evaluate())
 
diff --git a/examples/population_demo.py b/examples/population_demo.py
index 8cb2466..20839be 100644
--- a/examples/population_demo.py
+++ b/examples/population_demo.py
@@ -11,7 +11,7 @@ def func_to_optimise(x):
 
 
 def pick_random_parents(pop):
-    return random.choice(pop)
+    return random.choice(pop),
 
 
 random.seed(42)
diff --git a/examples/rock_paper_scissors.py b/examples/rock_paper_scissors.py
index 2a98dda..0bf6dfb 100755
--- a/examples/rock_paper_scissors.py
+++ b/examples/rock_paper_scissors.py
@@ -120,7 +120,7 @@ def run_rock_paper_scissors(population_size: int = 100,
     evo = Evolution().repeat(
         evolution=(Evolution()
                    .survive(fraction=survive_fraction)
-                   .breed(parent_picker=pick_random, combiner=lambda x, y: x.combine(y), n_parents=2)
+                   .breed(parent_picker=pick_random(n_parents=2), combiner=lambda x, y: x.combine(y))
                    .mutate(lambda x: x.mutate())
                    .evaluate()
                    .callback(history.log)),
diff --git a/examples/travelling_salesman.py b/examples/travelling_salesman.py
index 213b305..bb00254 100755
--- a/examples/travelling_salesman.py
+++ b/examples/travelling_salesman.py
@@ -47,7 +47,7 @@ def print_function(population: Population):
 
     island_evo = (Evolution()
                   .survive(fraction=0.5)
-                  .breed(parent_picker=pick_random, combiner=cycle_crossover)
+                  .breed(parent_picker=pick_random(n_parents=2), combiner=cycle_crossover)
                   .mutate(swap_elements))
 
     evo = (Evolution()
diff --git a/tests/conftest.py b/tests/conftest.py
index e2e70f3..665b1f8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -54,7 +54,7 @@ def simple_evolution():
     return (
         Evolution()
         .survive(fraction=0.5)
-        .breed(parent_picker=pick_random, n_parents=2, combiner=lambda x, y: x + y)
+        .breed(parent_picker=pick_random(n_parents=2), combiner=lambda x, y: x + y)
         .mutate(lambda x: x + 1, probability=0.1)
     )
 
diff --git a/tests/test_evolution.py b/tests/test_evolution.py
index 9c53633..c030522 100644
--- a/tests/test_evolution.py
+++ b/tests/test_evolution.py
@@ -41,7 +41,7 @@ def callback(pop):
         sub_evo = (
             Evolution()
             .survive(fraction=0.5)
-            .breed(parent_picker=pick_random,
+            .breed(parent_picker=pick_random(n_parents=2),
                    combiner=lambda x, y: x + y)
             .callback(callback_function=callback)
         )
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 4749e85..a71f5ae 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -56,9 +56,8 @@ def test_baselogger_works_via_evolution_callback(self, tmpdir, capsys):
         pop = Population(chromosomes=range(10), eval_function=lambda x: x)
         evo = (Evolution()
                .survive(fraction=0.5)
-               .breed(parent_picker=pick_random,
-                      combiner=lambda mom, dad: (mom + dad) / 2 + (random.random() - 0.5),
-                      n_parents=2)
+               .breed(parent_picker=pick_random(n_parents=2),
+                      combiner=lambda mom, dad: (mom + dad) / 2 + (random.random() - 0.5))
                .callback(logger.log, foo='bar'))
         pop.evolve(evolution=evo, n=2)
         # check characteristics of the file
@@ -132,9 +131,8 @@ def test_summarylogger_works_via_evolution(self, tmpdir, capsys):
         pop = Population(chromosomes=list(range(10)), eval_function=lambda x: x)
         evo = (Evolution()
                .survive(fraction=0.5)
-               .breed(parent_picker=pick_random,
-                      combiner=lambda mom, dad: (mom + dad) / 2 + (random.random() - 0.5),
-                      n_parents=2)
+               .breed(parent_picker=pick_random(n_parents=2),
+                      combiner=lambda mom, dad: (mom + dad) / 2 + (random.random() - 0.5))
                .evaluate()
                .callback(logger.log, foo='bar'))
         pop.evolve(evolution=evo, n=5)
@@ -157,9 +155,8 @@ def test_two_populations_can_use_same_logger(self, tmpdir, capsys):
         pop2 = Population(chromosomes=list(range(10)), eval_function=lambda x: x)
         evo = (Evolution()
                .survive(fraction=0.5)
-               .breed(parent_picker=pick_random,
-                      combiner=lambda mom, dad: (mom + dad) + 1,
-                      n_parents=2)
+               .breed(parent_picker=pick_random(n_parents=2),
+                      combiner=lambda mom, dad: (mom + dad) + 1)
                .evaluate()
                .callback(logger.log, foo="dino"))
         pop1.evolve(evolution=evo, n=5)
@@ -183,9 +180,8 @@ def test_every_mechanic_in_evolution_log(self, tmpdir, capsys):
         pop = Population(chromosomes=list(range(10)), eval_function=lambda x: x)
         evo = (Evolution()
                .survive(fraction=0.5)
-               .breed(parent_picker=pick_random,
-                      combiner=lambda mom, dad: (mom + dad) + 1,
-                      n_parents=2)
+               .breed(parent_picker=pick_random(n_parents=2),
+                      combiner=lambda mom, dad: (mom + dad) + 1)
                .evaluate()
                .callback(logger.log, every=2))
         pop.evolve(evolution=evo, n=100)
diff --git a/tests/test_population.py b/tests/test_population.py
index 636ecdc..fe457d2 100644
--- a/tests/test_population.py
+++ b/tests/test_population.py
@@ -172,14 +172,13 @@ def test_breed_amount_works(self, simple_chromosomes, simple_evaluation_function
 
     def test_breed_works_with_kwargs(self, simple_chromosomes, simple_evaluation_function):
         pop1 = Population(chromosomes=simple_chromosomes, eval_function=simple_evaluation_function)
-        pop1.survive(n=50).breed(parent_picker=pick_random,
-                                 combiner=lambda mom, dad: (mom + dad) / 2,
-                                 n_parents=2)
+        pop1.survive(n=50).breed(parent_picker=pick_random(),
+                                 combiner=lambda mom, dad: (mom + dad) / 2)
         assert len(pop1) == len(simple_chromosomes)
         pop2 = Population(chromosomes=simple_chromosomes, eval_function=simple_evaluation_function)
-        pop2.survive(n=50).breed(parent_picker=pick_random,
+        pop2.survive(n=50).breed(parent_picker=pick_random(n_parents=3),
                                  combiner=lambda *parents: sum(parents)/len(parents),
-                                 population_size=400, n_parents=3)
+                                 population_size=400)
         assert len(pop2) == 400
         assert pop2.intended_size == 400
 
@@ -187,13 +186,13 @@ def test_breed_raises_with_multiple_values_for_kwarg(self, simple_population):
 
         (simple_population
             .survive(fraction=0.5)
-            .breed(parent_picker=pick_random,
+            .breed(parent_picker=pick_random(n_parents=2),
                    combiner=lambda x, y: x + y))
 
         with raises(TypeError):
             (simple_population
                 .survive(fraction=0.5)
-                .breed(parent_picker=pick_random,
+                .breed(parent_picker=pick_random(n_parents=2),
                        combiner=lambda x, y: x + y, y=2))
 
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d98f659..0d2ccce 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -12,18 +12,18 @@ def combiner(x, y):
             return 1
 
         result = offspring_generator(parents=simple_population.individuals,
-                                     parent_picker=pick_random, combiner=combiner)
+                                     parent_picker=pick_random(), combiner=combiner)
         assert isinstance(next(result), Individual)
         assert next(result).chromosome == 1
 
     @mark.parametrize('n_parents', [1, 2, 3, 4])
     def test_args(self, n_parents: int, simple_population: Population):
-        def combiner(*parents, n_parents):
+        def combiner(*parents):
             assert len(parents) == n_parents
             return 1
 
-        result = offspring_generator(parents=simple_population.individuals, n_parents=n_parents,
-                                     parent_picker=pick_random, combiner=combiner)
+        result = offspring_generator(parents=simple_population.individuals,
+                                     parent_picker=pick_random(n_parents=n_parents), combiner=combiner)
         assert isinstance(next(result), Individual)
         assert next(result).chromosome == 1
 
@@ -32,7 +32,7 @@ def combiner(x):
             return 1
 
         def picker(parents):
-            return parents[0]
+            return parents[0],
 
         result = offspring_generator(parents=simple_population.individuals, parent_picker=picker, combiner=combiner)
         assert isinstance(next(result), Individual)
@@ -44,7 +44,7 @@ def combiner(x, y):
             yield 2
 
         result = offspring_generator(parents=simple_population.individuals,
-                                     parent_picker=pick_random, combiner=combiner)
+                                     parent_picker=pick_random(), combiner=combiner)
         for _ in range(10):
             assert next(result).chromosome == 1
             assert next(result).chromosome == 2