added tests for parallel builds
nickeldan committed Sep 2, 2023
1 parent 56afe4d commit cef4537
Showing 8 changed files with 305 additions and 198 deletions.
4 changes: 3 additions & 1 deletion pytest.ini
@@ -1,3 +1,5 @@
 [pytest]
 testpaths =
-    tests
+    tests
+timeout = 10
+log_level = 10
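Two settings back the new parallel-build tests: timeout = 10 is consumed by the pytest-timeout plugin pinned in requirements-dev.txt below, presumably so a deadlocked parallel build aborts after ten seconds instead of hanging the suite, and log_level = 10 is the numeric value of logging.DEBUG, so debug output from sandworm's loggers is captured.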
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -4,3 +4,4 @@ flake8-annotations==3.0.1
 flake8-warnings==0.4.0
 mypy==1.4.1
 pytest==7.4.0
+pytest-timeout==2.1.0
30 changes: 20 additions & 10 deletions sandworm/_builder.py
@@ -8,7 +8,6 @@
 import sys
 import textwrap
 
-from . import parallel
 from . import core
 from . import target

@@ -31,7 +30,7 @@ def get_args() -> tuple[argparse.Namespace, list[str]]:
         dest="max_workers",
         type=int,
         nargs="?",
-        default=1,
+        const=-1,
         help="Build in parallel. Optionally, specify the number of workers to use.",
     )

@@ -101,10 +100,7 @@ def do_build(env: target.Environment, target_str: str, max_workers: int | None)
         return True
     target = env.main_target
 
-    if max_workers is None or max_workers > 1:
-        return parallel.root_parallel_build(target, max_workers)
-    else:
-        return core.root_build(target)
+    return core.root_build(target, max_workers=max_workers)
 
 
 def main() -> int:
@@ -115,19 +111,33 @@ def main() -> int:
         return 0
 
     wormfile = pathlib.Path.cwd() / "Wormfile.py"
-    if not wormfile.is_file():
-        if args.command != "init":
-            print("No Wormfile.py found.", file=sys.stderr)
 
+    if args.command == "init":
+        if wormfile.is_file():
+            print("Wormfile.py already exists.", file=sys.stderr)
+            return 1
+        make_template(wormfile)
+        return 0
 
+    if not wormfile.is_file():
+        print("No Wormfile.py found.", file=sys.stderr)
+        return 1
 
     core.init_logging(fmt=args.format, verbose=args.verbose)
    if (env := create_environment(args, extra_args)) is None:
         return 1
 
+    max_workers: int | None
+    match args.max_workers:
+        case None:
+            max_workers = 1
+        case n if n < 0:
+            max_workers = None
+        case n:
+            max_workers = n
 
     if args.command == "build":
-        ret = do_build(env, args.target, args.max_workers)
+        ret = do_build(env, args.target, max_workers)
     else:
         ret = core.make_clean(env)
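With default=1 replaced by const=-1, the flag now has three states, which the new match block in main() normalizes: flag absent gives None (sequential), a bare flag gives -1 (let the executor pick a worker count), and an explicit value is used as-is. A standalone sketch of just that behavior, not part of the commit (the option string, here -j, is an assumption; the diff does not show it):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-j",  # hypothetical option string; the real one lies outside the shown hunk
    dest="max_workers",
    type=int,
    nargs="?",
    const=-1,
    help="Build in parallel. Optionally, specify the number of workers to use.",
)

for argv in ([], ["-j"], ["-j", "4"]):
    args = parser.parse_args(argv)
    match args.max_workers:
        case None:        # flag absent: build sequentially
            max_workers = 1
        case n if n < 0:  # bare flag: let ProcessPoolExecutor choose
            max_workers = None
        case n:           # explicit worker count
            max_workers = n
    print(argv, "->", max_workers)
# [] -> 1, ['-j'] -> None, ['-j', '4'] -> 4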
212 changes: 212 additions & 0 deletions sandworm/_parallel.py
@@ -0,0 +1,212 @@
from __future__ import annotations
import concurrent.futures
import dataclasses
import logging
import logging.handlers
import multiprocessing
import multiprocessing.connection
import multiprocessing.queues
import threading
import typing

from . import target

Connection = multiprocessing.connection.Connection
ChildWaiter = Connection | set[Connection] | None
JobPreContext = tuple[ChildWaiter, Connection | None, Connection | None]
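# How to read these aliases: a job's ChildWaiter holds the dependency
# connections it is still waiting on -- None when nothing is pending, a bare
# Connection for exactly one, a set for several.  A JobPreContext bundles
# (waiter, read_end, write_end) for one target before Job objects exist.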

logger = logging.getLogger("sandworm.parallel")


@dataclasses.dataclass(slots=True, repr=False, eq=False)
class Job:
    targ: target.Target
    waiter: ChildWaiter
    read_end: Connection
    write_end: Connection


@dataclasses.dataclass(slots=True, repr=False, eq=False)
class ReducedJob:
    targ: target.Target
    read_end: Connection
    write_end: Connection

    @staticmethod
    def from_job(job: Job) -> ReducedJob:
        return ReducedJob(targ=job.targ, read_end=job.read_end, write_end=job.write_end)


def init_job_process(log_queue: multiprocessing.queues.Queue) -> None:
    logging.getLogger().handlers = [logging.handlers.QueueHandler(log_queue)]


def send_job_status(
    fileno_connection: Connection, job: Job | ReducedJob, status: bool, *, fileno: int | None = None
) -> None:
    if fileno is None:
        fileno = job.read_end.fileno()
    job.write_end.send(status)
    job.write_end.close()
    fileno_connection.send(fileno)


def run_job(fileno_connection: Connection, job: ReducedJob, fileno: int) -> None:
    try:
        ret = job.targ.build()
    except Exception:
        logger.exception(f"Job for {job.targ.fullname()} crashed:")
        ret = False
    send_job_status(fileno_connection, job, ret, fileno=fileno)


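# JobPool couples a ProcessPoolExecutor with two side channels: a
# multiprocessing.Queue over which worker processes ship LogRecords back to
# the parent (drained by _log_thread), and a pipe on which each finished job
# announces the fileno of its status connection so run() knows what to reap.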
class JobPool(concurrent.futures.ProcessPoolExecutor):
    def __init__(self, max_workers: int | None, jobs: list[Job]) -> None:
        self._log_queue: multiprocessing.queues.Queue = multiprocessing.Queue()
        super().__init__(max_workers=max_workers, initializer=init_job_process, initargs=(self._log_queue,))
        self._fileno_conn_read, self._fileno_conn_write = multiprocessing.Pipe()
        self._jobs = jobs
        self._pending_connections: dict[int, Connection] = {}
        self._any_failures = False

        for job in jobs:
            self._add_pending_connection(job)

        self._log_thread = threading.Thread(target=self._thread_func)

    def _add_pending_connection(self, job: Job) -> None:
        self._pending_connections[job.read_end.fileno()] = job.read_end

    def _thread_func(self) -> None:
        while (record := self._log_queue.get()) is not None:
            assert isinstance(record, logging.LogRecord)
            logger.handle(record)

    def _handle_job(self, job: Job) -> None:
        if job.targ.builder is None:
            send_job_status(self._fileno_conn_write, job, job.targ.exists)
        else:
            fileno = job.read_end.fileno()
            logger.debug(f"Starting job for target {job.targ.fullname()}")
            self.submit(run_job, self._fileno_conn_write, ReducedJob.from_job(job), fileno)

    def _handle_job_status(self, job: Job, dep_success: bool) -> None:
        if dep_success:
            self._handle_job(job)
        else:
            send_job_status(self._fileno_conn_write, job, False)

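    # A dependency's status pipe became readable: propagate its result to
    # every job waiting on it -- start jobs whose last pending dependency
    # just succeeded, and fail the rest without running them.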
    def _handle_ready_connection(self, conn: Connection) -> None:
        success: bool = conn.recv()
        if not success:
            self._any_failures = True
        conn.close()
        indices_to_remove: set[int] = set()
        for k, job in enumerate(self._jobs):
            assert job.waiter is not None
            job_finished = False
            if isinstance(job.waiter, Connection):
                if job.waiter is conn:
                    job_finished = True
            elif conn in job.waiter:
                job.waiter.remove(conn)
                if not job.waiter:
                    job_finished = True

            if job_finished:
                self._handle_job_status(job, success)
                indices_to_remove.add(k)

        if indices_to_remove:
            self._jobs = [job for k, job in enumerate(self._jobs) if k not in indices_to_remove]

    def run(self, leaves: list[Job]) -> bool:
        logger.debug("Starting job pool")

        for leaf in leaves:
            self._add_pending_connection(leaf)
            self._handle_job(leaf)

        while self._pending_connections:
            fileno: int = self._fileno_conn_read.recv()
            if (conn := self._pending_connections.pop(fileno, None)) is not None:
                self._handle_ready_connection(conn)

        logger.debug("Job pool finished")

        return not self._any_failures

    def __enter__(self) -> JobPool:
        self._log_thread.start()
        super().__enter__()
        return self

    def __exit__(self, *args: typing.Any) -> typing.Any:
        ret = super().__exit__(*args)

        self._log_queue.put(None)
        self._log_thread.join()

        return ret


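# Walks the dependency graph (memoized in job_pre_map) and decides which
# targets get a status pipe: a target with no builder but with dependencies
# needs none, since its status is exactly that of its dependencies, whose
# connections bubble up into the parent's waiter set instead.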
def populate_job_pre_map(
    job_pre_map: dict[target.Target, JobPreContext], targ: target.Target
) -> JobPreContext:
    if (ctx := job_pre_map.get(targ)) is not None:
        return ctx

    child_waiter_set: set[Connection] = set()
    for dep in targ.dependencies:
        dep_ctx = populate_job_pre_map(job_pre_map, dep)
        if (second_slot := dep_ctx[1]) is None:
            if isinstance(first_slot := dep_ctx[0], Connection):
                child_waiter_set.add(first_slot)
            elif first_slot is not None:
                child_waiter_set |= first_slot
        else:
            child_waiter_set.add(second_slot)

    child_waiter: ChildWaiter
    match len(child_waiter_set):
        case 0:
            child_waiter = None
        case 1:
            child_waiter = next(iter(child_waiter_set))
        case _:
            child_waiter = child_waiter_set

    read_end: Connection | None
    write_end: Connection | None
    if targ.builder is None and targ.dependencies:
        read_end = write_end = None
    else:
        read_end, write_end = multiprocessing.Pipe()

    ctx = (child_waiter, read_end, write_end)
    job_pre_map[targ] = ctx
    return ctx


def parallel_root_build(main: target.Target, max_workers: int | None) -> bool:
    logger.debug("Determining target dependencies")

    job_pre_map: dict[target.Target, JobPreContext] = {}
    populate_job_pre_map(job_pre_map, main)

    jobs: list[Job] = []
    leaves: list[Job] = []
    for targ, (waiter, read_end, write_end) in job_pre_map.items():
        if write_end is None:
            continue
        assert read_end is not None
        job = Job(targ=targ, waiter=waiter, read_end=read_end, write_end=write_end)
        if waiter is None:
            leaves.append(job)
        else:
            jobs.append(job)
    del job_pre_map

    with JobPool(max_workers, jobs) as pool:
        del jobs
        return pool.run(leaves)
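
The whole module rests on one handshake: each job owns a pipe carrying its boolean status, and a shared pipe carries filenos so the pool can tell which job just finished. A standalone sketch of that handshake, not part of the commit:

import multiprocessing

# Per-job status pipe, plus the shared fileno pipe that JobPool multiplexes on.
read_end, write_end = multiprocessing.Pipe()
fileno_read, fileno_write = multiprocessing.Pipe()

write_end.send(True)                   # what send_job_status() does on success...
write_end.close()
fileno_write.send(read_end.fileno())   # ...before announcing which job finished

ready_fileno = fileno_read.recv()      # run() learns which connection is ready
assert ready_fileno == read_end.fileno()
assert read_end.recv() is True         # and collects that job's status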
16 changes: 12 additions & 4 deletions sandworm/core.py
@@ -4,6 +4,7 @@
 import typing
 
 from . import _graph
+from . import _parallel
 from . import target
 
 _logger = logging.getLogger("sandworm.core")
@@ -21,8 +22,10 @@ def format(self, record: logging.LogRecord) -> str:
         return msg
 
 
-def init_logging(*, fmt: str = "%(message)s", verbose: bool = False) -> None:
-    handler = logging.StreamHandler(stream=sys.stdout)
+def init_logging(
+    *, fmt: str = "%(message)s", verbose: bool = False, stream: typing.TextIO = sys.stdout
+) -> None:
+    handler = logging.StreamHandler(stream=stream)
     handler.setFormatter(_ColorFormatter(fmt=fmt))
 
     logger = logging.getLogger("sandworm")
@@ -38,15 +41,20 @@ def _display_cycle(cycle: list[target.Target]) -> None:
     _logger.error(f"\t{cycle[0].fullname()} from .")
 
 
-def root_build(main: target.Target) -> bool:
+def root_build(main: target.Target, max_workers: int | None = 1) -> bool:
     if (cycle := _graph.Graph(main).find_cycle()) is not None:
         _display_cycle(cycle)
         return False
 
     if not main.out_of_date:
         return True
 
-    if ret := _build_sequence(_linearize(main)):
+    if max_workers == 1:
+        ret = _build_sequence(_linearize(main))
+    else:
+        ret = _parallel.parallel_root_build(main, max_workers)
+
+    if ret:
         _logger.info("Build successful")
     return ret

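root_build() now dispatches on the worker count: max_workers == 1 keeps the old linearized sequential path, and anything else (including None, which lets ProcessPoolExecutor choose) goes through _parallel.parallel_root_build(). The new stream parameter lets callers point the log handler somewhere other than sys.stdout; presumably the new tests use it to capture build output. A minimal sketch of that use, not part of the commit:

import io

import sandworm.core

buf = io.StringIO()
sandworm.core.init_logging(verbose=True, stream=buf)
# ... run a build; log output accumulates in buf.getvalue()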