From cef4537e2b23789155805d86377a38ae8ceb3f8c Mon Sep 17 00:00:00 2001 From: Daniel Walker Date: Sat, 2 Sep 2023 00:18:55 -0400 Subject: [PATCH] added tests for parallel builds --- pytest.ini | 4 +- requirements-dev.txt | 1 + sandworm/_builder.py | 30 ++++-- sandworm/_parallel.py | 212 ++++++++++++++++++++++++++++++++++++++++++ sandworm/core.py | 16 +++- sandworm/parallel.py | 163 -------------------------------- sandworm/target.py | 15 ++- tests/test_build.py | 62 ++++++++---- 8 files changed, 305 insertions(+), 198 deletions(-) create mode 100644 sandworm/_parallel.py delete mode 100644 sandworm/parallel.py diff --git a/pytest.ini b/pytest.ini index 7f369af..4a1a642 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,5 @@ [pytest] testpaths = - tests \ No newline at end of file + tests +timeout = 10 +log_level = 10 diff --git a/requirements-dev.txt b/requirements-dev.txt index 5c4110e..714da35 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,3 +4,4 @@ flake8-annotations==3.0.1 flake8-warnings==0.4.0 mypy==1.4.1 pytest==7.4.0 +pytest-timeout==2.1.0 diff --git a/sandworm/_builder.py b/sandworm/_builder.py index 454d7a4..4667b9f 100755 --- a/sandworm/_builder.py +++ b/sandworm/_builder.py @@ -8,7 +8,6 @@ import sys import textwrap -from . import parallel from . import core from . import target @@ -31,7 +30,7 @@ def get_args() -> tuple[argparse.Namespace, list[str]]: dest="max_workers", type=int, nargs="?", - default=1, + const=-1, help="Build in parallel. Optionally, specify the number of workers to use.", ) @@ -101,10 +100,7 @@ def do_build(env: target.Environment, target_str: str, max_workers: int | None) return True target = env.main_target - if max_workers is None or max_workers > 1: - return parallel.root_parallel_build(target, max_workers) - else: - return core.root_build(target) + return core.root_build(target, max_workers=max_workers) def main() -> int: @@ -115,19 +111,33 @@ def main() -> int: return 0 wormfile = pathlib.Path.cwd() / "Wormfile.py" - if not wormfile.is_file(): - if args.command != "init": - print("No Wormfile.py found.", file=sys.stderr) + + if args.command == "init": + if wormfile.is_file(): + print("Wormfile.py already exists.", file=sys.stderr) return 1 make_template(wormfile) return 0 + if not wormfile.is_file(): + print("No Wormfile.py found.", file=sys.stderr) + return 1 + core.init_logging(fmt=args.format, verbose=args.verbose) if (env := create_environment(args, extra_args)) is None: return 1 + max_workers: int | None + match args.max_workers: + case None: + max_workers = 1 + case n if n < 0: + max_workers = None + case n: + max_workers = n + if args.command == "build": - ret = do_build(env, args.target, args.max_workers) + ret = do_build(env, args.target, max_workers) else: ret = core.make_clean(env) diff --git a/sandworm/_parallel.py b/sandworm/_parallel.py new file mode 100644 index 0000000..5e99430 --- /dev/null +++ b/sandworm/_parallel.py @@ -0,0 +1,212 @@ +from __future__ import annotations +import concurrent.futures +import dataclasses +import logging +import logging.handlers +import multiprocessing +import multiprocessing.connection +import multiprocessing.queues +import threading +import typing + +from . 
import target + +Connection = multiprocessing.connection.Connection +ChildWaiter = Connection | set[Connection] | None +JobPreContext = tuple[ChildWaiter, Connection | None, Connection | None] + +logger = logging.getLogger("sandworm.parallel") + + +@dataclasses.dataclass(slots=True, repr=False, eq=False) +class Job: + targ: target.Target + waiter: ChildWaiter + read_end: Connection + write_end: Connection + + +@dataclasses.dataclass(slots=True, repr=False, eq=False) +class ReducedJob: + targ: target.Target + read_end: Connection + write_end: Connection + + @staticmethod + def from_job(job: Job) -> ReducedJob: + return ReducedJob(targ=job.targ, read_end=job.read_end, write_end=job.write_end) + + +def init_job_process(log_queue: multiprocessing.queues.Queue) -> None: + logging.getLogger().handlers = [logging.handlers.QueueHandler(log_queue)] + + +def send_job_status( + fileno_connection: Connection, job: Job | ReducedJob, status: bool, *, fileno: int | None = None +) -> None: + if fileno is None: + fileno = job.read_end.fileno() + job.write_end.send(status) + job.write_end.close() + fileno_connection.send(fileno) + + +def run_job(fileno_connection: Connection, job: ReducedJob, fileno: int) -> None: + try: + ret = job.targ.build() + except Exception: + logger.exception(f"Job for {job.targ.fullname()} crashed:") + ret = False + send_job_status(fileno_connection, job, ret, fileno=fileno) + + +class JobPool(concurrent.futures.ProcessPoolExecutor): + def __init__(self, max_workers: int | None, jobs: list[Job]) -> None: + self._log_queue: multiprocessing.queues.Queue = multiprocessing.Queue() + super().__init__(max_workers=max_workers, initializer=init_job_process, initargs=(self._log_queue,)) + self._fileno_conn_read, self._fileno_conn_write = multiprocessing.Pipe() + self._jobs = jobs + self._pending_connections: dict[int, Connection] = {} + self._any_failures = False + + for job in jobs: + self._add_pending_connection(job) + + self._log_thread = threading.Thread(target=self._thread_func) + + def _add_pending_connection(self, job: Job) -> None: + self._pending_connections[job.read_end.fileno()] = job.read_end + + def _thread_func(self) -> None: + while (record := self._log_queue.get()) is not None: + assert isinstance(record, logging.LogRecord) + logger.handle(record) + + def _handle_job(self, job: Job) -> None: + if job.targ.builder is None: + send_job_status(self._fileno_conn_write, job, job.targ.exists) + else: + fileno = job.read_end.fileno() + logger.debug(f"Starting job for target {job.targ.fullname()}") + self.submit(run_job, self._fileno_conn_write, ReducedJob.from_job(job), fileno) + + def _handle_job_status(self, job: Job, dep_success: bool) -> None: + if dep_success: + self._handle_job(job) + else: + send_job_status(self._fileno_conn_write, job, False) + + def _handle_ready_connection(self, conn: Connection) -> None: + success: bool = conn.recv() + if not success: + self._any_failures = True + conn.close() + indices_to_remove: set[int] = set() + for k, job in enumerate(self._jobs): + assert job.waiter is not None + job_finished = False + if isinstance(job.waiter, Connection): + if job.waiter is conn: + job_finished = True + elif conn in job.waiter: + job.waiter.remove(conn) + if not job.waiter: + job_finished = True + + if job_finished: + self._handle_job_status(job, success) + indices_to_remove.add(k) + + if indices_to_remove: + self._jobs = [job for k, job in enumerate(self._jobs) if k not in indices_to_remove] + + def run(self, leaves: list[Job]) -> bool: + logger.debug("Starting 
job pool") + + for leaf in leaves: + self._add_pending_connection(leaf) + self._handle_job(leaf) + + while self._pending_connections: + fileno: int = self._fileno_conn_read.recv() + if (conn := self._pending_connections.pop(fileno, None)) is not None: + self._handle_ready_connection(conn) + + logger.debug("Job pool finished") + + return not self._any_failures + + def __enter__(self) -> JobPool: + self._log_thread.start() + super().__enter__() + return self + + def __exit__(self, *args: typing.Any) -> typing.Any: + ret = super().__exit__(*args) + + self._log_queue.put(None) + self._log_thread.join() + + return ret + + +def populate_job_pre_map( + job_pre_map: dict[target.Target, JobPreContext], targ: target.Target +) -> JobPreContext: + if (ctx := job_pre_map.get(targ)) is not None: + return ctx + + child_waiter_set: set[Connection] = set() + for dep in targ.dependencies: + dep_ctx = populate_job_pre_map(job_pre_map, dep) + if (second_slot := dep_ctx[1]) is None: + if isinstance(first_slot := dep_ctx[0], Connection): + child_waiter_set.add(first_slot) + elif first_slot is not None: + child_waiter_set |= first_slot + else: + child_waiter_set.add(second_slot) + + child_waiter: ChildWaiter + match len(child_waiter_set): + case 0: + child_waiter = None + case 1: + child_waiter = next(iter(child_waiter_set)) + case _: + child_waiter = child_waiter_set + + read_end: Connection | None + write_end: Connection | None + if targ.builder is None and targ.dependencies: + read_end = write_end = None + else: + read_end, write_end = multiprocessing.Pipe() + + ctx = (child_waiter, read_end, write_end) + job_pre_map[targ] = ctx + return ctx + + +def parallel_root_build(main: target.Target, max_workers: int | None) -> bool: + logger.debug("Determining target dependencies") + + job_pre_map: dict[target.Target, JobPreContext] = {} + populate_job_pre_map(job_pre_map, main) + + jobs: list[Job] = [] + leaves: list[Job] = [] + for targ, (waiter, read_end, write_end) in job_pre_map.items(): + if write_end is None: + continue + assert read_end is not None + job = Job(targ=targ, waiter=waiter, read_end=read_end, write_end=write_end) + if waiter is None: + leaves.append(job) + else: + jobs.append(job) + del job_pre_map + + with JobPool(max_workers, jobs) as pool: + del jobs + return pool.run(leaves) diff --git a/sandworm/core.py b/sandworm/core.py index 12246e3..7c325a3 100644 --- a/sandworm/core.py +++ b/sandworm/core.py @@ -4,6 +4,7 @@ import typing from . import _graph +from . import _parallel from . 
import target _logger = logging.getLogger("sandworm.core") @@ -21,8 +22,10 @@ def format(self, record: logging.LogRecord) -> str: return msg -def init_logging(*, fmt: str = "%(message)s", verbose: bool = False) -> None: - handler = logging.StreamHandler(stream=sys.stdout) +def init_logging( + *, fmt: str = "%(message)s", verbose: bool = False, stream: typing.TextIO = sys.stdout +) -> None: + handler = logging.StreamHandler(stream=stream) handler.setFormatter(_ColorFormatter(fmt=fmt)) logger = logging.getLogger("sandworm") @@ -38,7 +41,7 @@ def _display_cycle(cycle: list[target.Target]) -> None: _logger.error(f"\t{cycle[0].fullname()} from .") -def root_build(main: target.Target) -> bool: +def root_build(main: target.Target, max_workers: int | None = 1) -> bool: if (cycle := _graph.Graph(main).find_cycle()) is not None: _display_cycle(cycle) return False @@ -46,7 +49,12 @@ def root_build(main: target.Target) -> bool: if not main.out_of_date: return True - if ret := _build_sequence(_linearize(main)): + if max_workers == 1: + ret = _build_sequence(_linearize(main)) + else: + ret = _parallel.parallel_root_build(main, max_workers) + + if ret: _logger.info("Build successful") return ret diff --git a/sandworm/parallel.py b/sandworm/parallel.py deleted file mode 100644 index 313e149..0000000 --- a/sandworm/parallel.py +++ /dev/null @@ -1,163 +0,0 @@ -import concurrent.futures -import dataclasses -import logging -import multiprocessing -import multiprocessing.connection -import multiprocessing.queues -import typing - -from . import target - -_Connection = multiprocessing.connection.Connection -_ChildWaiter = _Connection | set[_Connection] | None -_JobPreContext = tuple[_ChildWaiter, _Connection | None, _Connection | None] - -_logger = logging.getLogger("sandworm.parallel") - - -@dataclasses.dataclass(slots=True, repr=False, eq=False) -class _Job: - targ: target.Target - waiter: _ChildWaiter - read_end: _Connection - write_end: _Connection - - -def _send_job_status(fileno_queue: multiprocessing.queues.Queue, job: _Job, status: bool) -> None: - fileno = job.write_end.fileno() - job.write_end.send(status) - job.write_end.close() - fileno_queue.put(fileno) - - -def _run_job(fileno_queue: multiprocessing.queues.Queue, job: _Job) -> None: - job.read_end.close() - _send_job_status(fileno_queue, job, job.targ.build()) - - -class _JobPool(concurrent.futures.ProcessPoolExecutor): - def __init__(self, max_workers: int | None, jobs: list[_Job]) -> None: - super().__init__(max_workers=max_workers) - self._jobs = jobs - self._pending_connections: dict[int, _Connection] = {} - self._fileno_queue: multiprocessing.queues.Queue = multiprocessing.Queue() - self._any_failures = False - - for job in jobs: - if isinstance(job.waiter, _Connection): - self._pending_connections[job.waiter.fileno()] = job.waiter - else: - assert job.waiter is not None - for conn in job.waiter: - self._pending_connections[conn.fileno()] = conn - - def shutdown(self, *args: typing.Any, **kwargs: typing.Any) -> None: - kwargs["cancel_futures"] = True - super().shutdown(*args, **kwargs) - - def _handle_job(self, job: _Job) -> None: - if job.targ.builder is None: - _send_job_status(self._fileno_queue, job, job.targ.exists) - else: - self.submit(_run_job, self._fileno_queue, job) - - def _handle_job_status(self, job: _Job, dep_success: bool) -> None: - if dep_success: - self._handle_job(job) - else: - _send_job_status(self._fileno_queue, job, False) - self._any_failures = True - - def _handle_ready_connection(self, conn: _Connection) -> None: - 
success: bool = conn.recv() - conn.close() - indices_to_remove: set[int] = set() - for k, job in enumerate(self._jobs): - assert job.waiter is not None - if isinstance(job.waiter, _Connection): - if job.waiter is conn: - self._handle_job_status(job, success) - indices_to_remove.add(k) - elif conn in job.waiter: - job.waiter.remove(conn) - if not job.waiter: - self._handle_job_status(job, success) - indices_to_remove.add(k) - - if indices_to_remove: - self._jobs = [job for k, job in enumerate(self._jobs) if k not in indices_to_remove] - - def run(self, leaves: list[_Job]) -> bool: - for leaf in leaves: - self._handle_job(leaf) - - while self._jobs: - fileno: int = self._fileno_queue.get() - conn = self._pending_connections.pop(fileno) - self._handle_ready_connection(conn) - - return self._any_failures - - -def _populate_job_pre_map( - job_pre_map: dict[target.Target, _JobPreContext], targ: target.Target -) -> _JobPreContext: - if (ctx := job_pre_map.get(targ)) is not None: - return ctx - - child_waiter_set: set[_Connection] = set() - for dep in targ.dependencies: - dep_ctx = _populate_job_pre_map(job_pre_map, dep) - if (second_slot := dep_ctx[1]) is None: - if isinstance(first_slot := dep_ctx[0], _Connection): - child_waiter_set.add(first_slot) - elif first_slot is not None: - child_waiter_set |= first_slot - else: - child_waiter_set.add(second_slot) - - child_waiter: _ChildWaiter - match len(child_waiter_set): - case 0: - child_waiter = None - case 1: - child_waiter = next(iter(child_waiter_set)) - case _: - child_waiter = child_waiter_set - - read_end: _Connection | None - write_end: _Connection | None - if targ.builder is None and targ.dependencies: - read_end = write_end = None - else: - read_end, write_end = multiprocessing.Pipe() - - ctx = (child_waiter, read_end, write_end) - job_pre_map[targ] = ctx - return ctx - - -def root_parallel_build(main: target.Target, max_workers: int | None) -> bool: - job_pre_map: dict[target.Target, _JobPreContext] = {} - _populate_job_pre_map(job_pre_map, main) - - jobs: list[_Job] = [] - leaves: list[_Job] = [] - for targ, (waiter, read_end, write_end) in job_pre_map.items(): - if write_end is None: - continue - assert read_end is not None - job = _Job(targ=targ, waiter=waiter, read_end=read_end, write_end=write_end) - if waiter is None: - leaves.append(job) - else: - jobs.append(job) - del job_pre_map - - with _JobPool(max_workers, jobs) as pool: - del jobs - ret = pool.run(leaves) - - if ret: - _logger.info("Build successful") - return ret diff --git a/sandworm/target.py b/sandworm/target.py index 80f6a2e..ec4167b 100644 --- a/sandworm/target.py +++ b/sandworm/target.py @@ -5,6 +5,7 @@ import logging import os import pathlib +import pickle import typing _T = typing.TypeVar("_T", bound="Target") @@ -15,6 +16,10 @@ _sentinel = object() +def _dummy_builder(targ: _T) -> bool: + return True + + class Target: def __init__( self: _T, @@ -29,6 +34,12 @@ def __init__( self._env: Environment | None = None self._built = False + if builder is not None: + try: + pickle.dumps(builder) + except Exception as e: + raise TypeError("Builders must be picklable.") from e + @typing.final def __eq__(self, other: typing.Any) -> bool: return type(self) is type(other) and self.fullname() == other.fullname() @@ -86,7 +97,7 @@ def build(self: _T) -> bool: _logger.debug(f"Build for {self.fullname()} succeeded") self._built = True else: - _logger.error(f"Build for {self.fullname()} suceeded") + _logger.error(f"Build for {self.fullname()} failed") return ret @@ -165,7 +176,7 @@ 
def __init__(self, file: pathlib.Path | str, prev: Environment | None = None) -> self._clean_targets: list[Target] = [] self._main_target: Target - self.add_target(Target("", builder=lambda x: True), main=True) + self.add_target(Target("", builder=_dummy_builder), main=True) def __repr__(self) -> str: return f"Environment(basedir={self.basedir}, {self._map})" diff --git a/tests/test_build.py b/tests/test_build.py index c8e7321..e184bf6 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,39 +1,58 @@ +import enum import os import pathlib import time +import pytest + import sandworm +class MaxWorkers(enum.IntEnum): + SERIAL = 1 + PARALLEL = 4 + + +parametrize_workers = pytest.mark.parametrize("max_workers", MaxWorkers) + + def check_builder(targ: sandworm.FileTarget) -> bool: with open(targ.name, "w") as f: f.write("check\n") return True -def test_no_targets(env: sandworm.Environment) -> None: - assert sandworm.root_build(env.main_target) +@parametrize_workers +def test_no_targets(env: sandworm.Environment, max_workers: int) -> None: + assert sandworm.root_build(env.main_target, max_workers=max_workers) -def test_single_target(env: sandworm.Environment) -> None: +@parametrize_workers +def test_single_target(env: sandworm.Environment, max_workers: int) -> None: env.add_target(sandworm.FileTarget("foo.txt", builder=check_builder), main=True) assert env.main_target is not None - assert sandworm.root_build(env.main_target) + assert sandworm.root_build(env.main_target, max_workers=max_workers) path = pathlib.Path("foo.txt") assert path.is_file() with path.open() as f: assert f.read() == "check\n" -def test_fail_build(env: sandworm.Environment) -> None: - foo_target = sandworm.FileTarget("foo.txt", builder=lambda x: False) +def false_builder(targ: sandworm.Target) -> bool: + return False + + +@parametrize_workers +def test_fail_build(env: sandworm.Environment, max_workers: int) -> None: + foo_target = sandworm.FileTarget("foo.txt", builder=false_builder) env.add_target(foo_target) - assert not sandworm.root_build(foo_target) + assert not sandworm.root_build(foo_target, max_workers=max_workers) -def test_target_out_of_date(env: sandworm.Environment) -> None: +@parametrize_workers +def test_target_out_of_date(env: sandworm.Environment, max_workers: int) -> None: bar_target = sandworm.FileTarget("bar.txt") foo_target = sandworm.FileTarget("foo.txt", dependencies=[bar_target], builder=check_builder) env.add_target(bar_target) @@ -45,14 +64,15 @@ def test_target_out_of_date(env: sandworm.Environment) -> None: later = int(time.time()) + 5 os.utime("bar.txt", (later, later)) - assert sandworm.root_build(foo_target) + assert sandworm.root_build(foo_target, max_workers=max_workers) path = pathlib.Path("foo.txt") assert path.is_file() with path.open() as f: assert f.read() == "check\n" -def test_target_not_out_of_date(env: sandworm.Environment) -> None: +@parametrize_workers +def test_target_not_out_of_date(env: sandworm.Environment, max_workers: int) -> None: bar_target = sandworm.FileTarget("bar.txt") foo_target = sandworm.FileTarget("foo.txt", dependencies=[bar_target], builder=check_builder) env.add_target(bar_target) @@ -61,7 +81,7 @@ def test_target_not_out_of_date(env: sandworm.Environment) -> None: for name in ("bar", "foo"): pathlib.Path(f"{name}.txt").touch() - assert sandworm.root_build(foo_target) + assert sandworm.root_build(foo_target, max_workers=max_workers) path = pathlib.Path("foo.txt") assert path.is_file() with path.open() as f: @@ -95,22 +115,28 @@ def 
test_clean_targets_in_reverse_order(env: sandworm.Environment) -> None: assert f.read() == "bar\nfoo\n" +def true_builder(targ: sandworm.Target) -> bool: + return True + + def test_fail_cyclic_dependency(env: sandworm.Environment) -> None: - foo_target = sandworm.Target("foo", builder=lambda x: True) - bar_target = sandworm.Target("bar", builder=lambda x: True, dependencies=[foo_target]) - env.add_target(sandworm.Target("foo", builder=lambda x: True, dependencies=[bar_target]), main=True) + foo_target = sandworm.Target("foo", builder=true_builder) + bar_target = sandworm.Target("bar", builder=true_builder, dependencies=[foo_target]) + env.add_target(sandworm.Target("foo", builder=true_builder, dependencies=[bar_target]), main=True) assert not sandworm.root_build(env.main_target) -def test_fail_no_rule_to_build(env: sandworm.Environment) -> None: +@parametrize_workers +def test_fail_no_rule_to_build(env: sandworm.Environment, max_workers: int) -> None: env.add_target(sandworm.Target("foo"), main=True) - assert not sandworm.root_build(env.main_target) + assert not sandworm.root_build(env.main_target, max_workers=max_workers) -def test_no_rule_but_dependencies(env: sandworm.Environment) -> None: +@parametrize_workers +def test_no_rule_but_dependencies(env: sandworm.Environment, max_workers: int) -> None: bar_target = sandworm.FileTarget("bar.txt", builder=check_builder) env.add_target(sandworm.Target("foo", dependencies=[bar_target]), main=True) - assert sandworm.root_build(env.main_target) + assert sandworm.root_build(env.main_target, max_workers=max_workers)
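
A minimal usage sketch of the entry point this patch reworks, following the same public API the updated tests exercise (sandworm.Environment, sandworm.FileTarget, sandworm.root_build). The file names, the Environment constructor argument, and the builder body are illustrative assumptions, not taken from the patch; the one hard requirement the patch does introduce is that builders be module-level functions, since Target.__init__ now rejects unpicklable builders so targets can be shipped to worker processes.

    # usage_sketch.py -- illustrative only; file names and the Environment argument are assumptions
    import pathlib

    import sandworm


    def write_output(targ: sandworm.FileTarget) -> bool:
        # Builders must be module-level (picklable) so the parallel job pool can run them.
        with open(targ.name, "w") as f:
            f.write("built\n")
        return True


    env = sandworm.Environment(__file__)
    dep = sandworm.FileTarget("input.txt")
    out = sandworm.FileTarget("output.txt", dependencies=[dep], builder=write_output)
    env.add_target(dep)
    env.add_target(out, main=True)

    pathlib.Path("input.txt").touch()  # leaf targets without a builder must already exist

    # max_workers=1 keeps the serial _build_sequence path; any other value (or None,
    # which lets the process pool pick its default worker count) routes the build
    # through the job pool in sandworm/_parallel.py.
    ok = sandworm.root_build(env.main_target, max_workers=4)

This mirrors the parametrized tests above: MaxWorkers.SERIAL (1) covers the existing sequential path, while MaxWorkers.PARALLEL (4) drives the same assertions through parallel_root_build.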