diff --git a/dask_ctl/__init__.py b/dask_ctl/__init__.py index 035e332..c2a5089 100644 --- a/dask_ctl/__init__.py +++ b/dask_ctl/__init__.py @@ -1,5 +1,5 @@ from ._version import get_versions - +from .exceptions import DaskClusterConfigNotFound # noqa import os.path from dask.widgets import TEMPLATE_PATHS diff --git a/dask_ctl/ctl.yaml b/dask_ctl/ctl.yaml index 2e9e411..1dac4dc 100644 --- a/dask_ctl/ctl.yaml +++ b/dask_ctl/ctl.yaml @@ -1,2 +1,3 @@ ctl: disable_discovery: [] + cluster-spec: null diff --git a/dask_ctl/exceptions.py b/dask_ctl/exceptions.py new file mode 100644 index 0000000..14bfcd7 --- /dev/null +++ b/dask_ctl/exceptions.py @@ -0,0 +1,2 @@ +class DaskClusterConfigNotFound(FileNotFoundError): + """Unable to find the Dask cluster config.""" diff --git a/dask_ctl/lifecycle.py b/dask_ctl/lifecycle.py index 71c6c82..bca6616 100644 --- a/dask_ctl/lifecycle.py +++ b/dask_ctl/lifecycle.py @@ -1,21 +1,32 @@ import importlib from typing import List +import dask.config from dask.widgets import get_template from dask.utils import typename +from distributed.deploy import LocalCluster from distributed.deploy.cluster import Cluster from .discovery import discover_cluster_names, discover_clusters from .spec import load_spec from .utils import loop +from .exceptions import DaskClusterConfigNotFound -def create_cluster(spec_path: str) -> Cluster: +def create_cluster( + spec_path: str = None, + local_fallback: bool = False, + asynchronous: bool = False, +) -> Cluster: """Create a cluster from a spec file. Parameters ---------- spec_path - Path to a cluster spec file. + Path to a cluster spec file. Defaults to ``dask-cluster.yaml``. + local_fallback + Create a LocalCluster if spec file not found. + asynchronous + Start the cluster in asynchronous mode Returns ------- @@ -37,15 +48,25 @@ def create_cluster(spec_path: str) -> Cluster: LocalCluster(b3973c71, 'tcp://127.0.0.1:8786', workers=4, threads=12, memory=17.18 GB) """ + spec_path = ( + dask.config.get("ctl.cluster-spec", None, override_with=spec_path) + or "dask-cluster.yaml" + ) async def _create_cluster(): - cm_module, cm_class, args, kwargs = load_spec(spec_path) + try: + cm_module, cm_class, args, kwargs = load_spec(spec_path) + except FileNotFoundError as e: + if local_fallback: + return LocalCluster(asynchronous=asynchronous) + else: + raise DaskClusterConfigNotFound(f"Unable to find {spec_path}") from e module = importlib.import_module(cm_module) cluster_manager = getattr(module, cm_class) kwargs = {key.replace("-", "_"): entry for key, entry in kwargs.items()} - cluster = await cluster_manager(*args, **kwargs, asynchronous=True) + cluster = cluster_manager(*args, **kwargs, asynchronous=asynchronous) cluster.shutdown_on_close = False return cluster diff --git a/dask_ctl/tests/test_lifecycle.py b/dask_ctl/tests/test_lifecycle.py index 7eea396..21c2740 100644 --- a/dask_ctl/tests/test_lifecycle.py +++ b/dask_ctl/tests/test_lifecycle.py @@ -1,9 +1,11 @@ import pytest import ast +import dask.config from dask.distributed import LocalCluster from dask_ctl.lifecycle import create_cluster, get_snippet +from dask_ctl.exceptions import DaskClusterConfigNotFound def test_create_cluster(simple_spec_path): @@ -12,6 +14,18 @@ def test_create_cluster(simple_spec_path): assert isinstance(cluster, LocalCluster) +def test_create_cluster_fallback(): + with pytest.raises(DaskClusterConfigNotFound, match="dask-cluster.yaml"): + cluster = create_cluster() + + with dask.config.set({"ctl.cluster-spec": "foo.yaml"}): + with pytest.raises(DaskClusterConfigNotFound, match="foo.yaml"): + cluster = create_cluster() + + cluster = create_cluster(local_fallback=True) + assert isinstance(cluster, LocalCluster) + + @pytest.mark.xfail(reason="Proxy cluster discovery not working") def test_snippet(): with LocalCluster(scheduler_port=8786) as _: