diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py
index 2883f063..ced8b171 100644
--- a/rubicon_ml/client/rubicon.py
+++ b/rubicon_ml/client/rubicon.py
@@ -247,7 +247,7 @@ def get_project_as_df(self, name, df_type="pandas", group_by=None):
         return project.to_df(df_type=df_type, group_by=group_by)
 
     @failsafe
-    def get_or_create_project(self, name: str, **kwargs):
+    def get_or_create_project(self, name: str, **kwargs) -> Project:
         """Get or create a project.
 
         Parameters
diff --git a/rubicon_ml/domain/__init__.py b/rubicon_ml/domain/__init__.py
index 204b48ee..ab8cab65 100644
--- a/rubicon_ml/domain/__init__.py
+++ b/rubicon_ml/domain/__init__.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Union
+from typing import TypeVar, Union
 
 from rubicon_ml.domain import utils
 from rubicon_ml.domain.artifact import Artifact
@@ -13,6 +13,11 @@
 
 DOMAIN_TYPES = Union[Artifact, Dataframe, Experiment, Feature, Metric, Parameter, Project]
 
+DomainsVar = TypeVar(
+    "DomainsVar", Artifact, Dataframe, Experiment, Feature, Metric, Parameter, Project
+)
+
+
 __all__ = [
     "Artifact",
     "Dataframe",
diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py
index 160231a0..c9058482 100644
--- a/rubicon_ml/repository/base.py
+++ b/rubicon_ml/repository/base.py
@@ -3,7 +3,8 @@
 import tempfile
 import warnings
 from datetime import datetime
-from typing import List, Optional
+from json import JSONDecodeError
+from typing import Any, Dict, List, Optional, Type
 from zipfile import ZipFile
 
 import fsspec
@@ -42,34 +43,40 @@ def __init__(self, root_dir: str, **storage_options):
 
     # --- Filesystem Helpers ---
 
-    def _cat(self, path):
+    def _cat(self, path: str):
        """Returns the contents of the file at `path`."""
         return self.filesystem.cat(path)
 
-    def _cat_paths(self, metadata_paths):
+    def _cat_paths(self, metadata_paths: List[str]) -> Dict[str, Any]:
         """Cat `metadata_paths` to get the list of files to include.
 
         Ignore FileNotFoundErrors to avoid misc file errors, like
         hidden dotfiles.
         """
-        files = []
+        if not isinstance(metadata_paths, list):
+            metadata_paths = [metadata_paths]
+        if not metadata_paths:
+            return {}
+
+        files = {}
+
         for path, metadata in self.filesystem.cat(metadata_paths, on_error="return").items():
             if isinstance(metadata, FileNotFoundError):
                 warning = f"{path} not found. Was this file unintentionally created?"
                 warnings.warn(warning)
             else:
-                files.append(metadata)
+                files[path] = metadata
 
         return files
 
-    def _exists(self, path):
+    def _exists(self, path: str) -> bool:
         """Returns True if a file exists at `path`, False otherwise."""
         return self.filesystem.exists(path)
 
-    def _glob(self, globstring):
+    def _glob(self, globstring: str):
         """Returns the names of the files matching `globstring`."""
         return self.filesystem.glob(globstring, detail=True)
 
-    def _ls_directories_only(self, path):
+    def _ls_directories_only(self, path: str) -> List[str]:
         """Returns the names of all the directories at path `path`."""
         directories = [
             os.path.join(p.get("name"), "metadata.json")
@@ -79,14 +86,14 @@ def _ls_directories_only(self, path):
 
         return directories
 
-    def _ls(self, path):
+    def _ls(self, path: str):
         return self.filesystem.ls(path)
 
-    def _mkdir(self, dirpath):
+    def _mkdir(self, dirpath: str):
         """Creates a directory `dirpath` with parents."""
         return self.filesystem.mkdirs(dirpath, exist_ok=True)
 
-    def _modified(self, path):
+    def _modified(self, path: str):
         return self.filesystem.modified(path)
 
     def _persist_bytes(self, bytes_data, path):
@@ -125,6 +132,39 @@ def _rm(self, path):
         """Recursively remove all files at `path`."""
         return self.filesystem.rm(path, recursive=True)
 
+    def _load_metadata_files(
+        self, metadata_root: str, domain_type: Type[domain.DomainsVar]
+    ) -> List[domain.DomainsVar]:
+        """Load metadata files from the given root directory and return a list of domain objects."""
+        # find all directories, prepare a list of those plus `metadata.json`
+        try:
+            metadata_paths = self._ls_directories_only(metadata_root)
+        except FileNotFoundError:
+            return []
+
+        loaded_domains = []
+        # `_cat_paths` checks for FileNotFoundErrors and skips any missing
+        # files; it loads the contents of the files it finds
+        for path, metadata in self._cat_paths(metadata_paths).items():
+            try:
+                metadata_contents = json.loads(metadata)
+            except JSONDecodeError:
+                warnings.warn(f"Failed to load metadata for {domain_type.__name__} at {path}")
+                continue
+
+            try:
+                loaded_domain = domain_type(**metadata_contents)
+            except TypeError:
+                warnings.warn(f"Failed to load {domain_type.__name__} from metadata at {path}")
+                continue
+
+            loaded_domains.append(loaded_domain)
+
+        if loaded_domains:
+            loaded_domains.sort(key=lambda d: d.created_at)
+
+        return loaded_domains
+
     # -------- Projects --------
 
     def _get_project_metadata_path(self, project_name):
@@ -148,7 +188,7 @@ def create_project(self, project):
 
         self._persist_domain(project, project_metadata_path)
 
-    def get_project(self, project_name):
+    def get_project(self, project_name: str) -> domain.Project:
         """Retrieve a project from the configured filesystem.
 
         Parameters
@@ -170,7 +210,7 @@ def get_project(self, project_name):
 
         return domain.Project(**project)
 
-    def get_projects(self):
+    def get_projects(self) -> List[domain.Project]:
         """Get the list of projects from the filesystem.
 
         Returns
@@ -179,17 +219,10 @@ def get_projects(self):
         -------
         list of rubicon.domain.Project
             The list of projects from the filesystem.
""" try: - project_metadata_paths = self._ls_directories_only(self.root_dir) - projects = [ - domain.Project(**json.loads(metadata)) - for metadata in self._cat_paths(project_metadata_paths) - ] - projects.sort(key=lambda p: p.created_at) + return self._load_metadata_files(self.root_dir, domain.Project) except FileNotFoundError: return [] - return projects - # ------ Experiments ------- def _get_experiment_metadata_root(self, project_name): @@ -220,7 +253,7 @@ def create_experiment(self, experiment): self._persist_domain(experiment, experiment_metadata_path) - def get_experiment(self, project_name, experiment_id): + def get_experiment(self, project_name: str, experiment_id: str) -> domain.Experiment: """Retrieve an experiment from the configured filesystem. Parameters @@ -244,7 +277,7 @@ def get_experiment(self, project_name, experiment_id): return domain.Experiment(**experiment) - def get_experiments(self, project_name): + def get_experiments(self, project_name: str) -> List[domain.Experiment]: """Retrieve all experiments from the configured filesystem that belong to the project with name `project_name`. @@ -262,17 +295,7 @@ def get_experiments(self, project_name): """ experiment_metadata_root = self._get_experiment_metadata_root(project_name) - try: - experiment_metadata_paths = self._ls_directories_only(experiment_metadata_root) - experiments = [ - domain.Experiment(**json.loads(metadata)) - for metadata in self._cat_paths(experiment_metadata_paths) - ] - experiments.sort(key=lambda e: e.created_at) - except FileNotFoundError: - return [] - - return experiments + return self._load_metadata_files(experiment_metadata_root, domain.Experiment) # ------- Archiving -------- @@ -330,7 +353,10 @@ def _archive( return zip_archive_filename def _experiments_from_archive( - self, project_name, remote_rubicon_root: str, latest_only: Optional[bool] = False + self, + project_name, + remote_rubicon_root: str, + latest_only: Optional[bool] = False, ): """Retrieve archived experiments into this project's experiments folder. @@ -484,19 +510,12 @@ def get_artifacts_metadata(self, project_name, experiment_id=None): list of rubicon.domain.Artifact The artifacts logged to the specified object. """ - artifact_metadata_root = self._get_artifact_metadata_root(project_name, experiment_id) - try: - artifact_metadata_paths = self._ls_directories_only(artifact_metadata_root) - artifacts = [ - domain.Artifact(**json.loads(metadata)) - for metadata in self._cat_paths(artifact_metadata_paths) - ] - artifacts.sort(key=lambda a: a.created_at) + artifact_metadata_root = self._get_artifact_metadata_root(project_name, experiment_id) except FileNotFoundError: return [] - return artifacts + return self._load_metadata_files(artifact_metadata_root, domain.Artifact) def get_artifact_data(self, project_name, artifact_id, experiment_id=None): """Retrieve an artifact's raw data. @@ -672,7 +691,9 @@ def get_dataframe_metadata(self, project_name, dataframe_id, experiment_id=None) return domain.Dataframe(**dataframe) - def get_dataframes_metadata(self, project_name, experiment_id=None): + def get_dataframes_metadata( + self, project_name: str, experiment_id: Optional[str] = None + ) -> List[domain.Dataframe]: """Retrieve all dataframes' metadata from the configured filesystem that belong to the specified object. @@ -691,19 +712,12 @@ def get_dataframes_metadata(self, project_name, experiment_id=None): list of rubicon.domain.Dataframe The dataframes logged to the specified object. 
""" - dataframe_metadata_root = self._get_dataframe_metadata_root(project_name, experiment_id) - try: - dataframe_metadata_paths = self._ls_directories_only(dataframe_metadata_root) - dataframes = [ - domain.Dataframe(**json.loads(metadata)) - for metadata in self._cat_paths(dataframe_metadata_paths) - ] - dataframes.sort(key=lambda d: d.created_at) + dataframe_metadata_root = self._get_dataframe_metadata_root(project_name, experiment_id) except FileNotFoundError: return [] - return dataframes + return self._load_metadata_files(dataframe_metadata_root, domain.Dataframe) def get_dataframe_data(self, project_name, dataframe_id, experiment_id=None, df_type="pandas"): """Retrieve a dataframe's raw data. @@ -803,7 +817,9 @@ def create_feature(self, feature, project_name, experiment_id): self._persist_domain(feature, feature_metadata_path) - def get_feature(self, project_name, experiment_id, feature_name): + def get_feature( + self, project_name: str, experiment_id: str, feature_name: str + ) -> domain.Feature: """Retrieve a feature from the configured filesystem. Parameters @@ -832,7 +848,7 @@ def get_feature(self, project_name, experiment_id, feature_name): return domain.Feature(**feature) - def get_features(self, project_name, experiment_id): + def get_features(self, project_name: str, experiment_id: str) -> List[domain.Feature]: """Retrieve all features from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -851,19 +867,11 @@ def get_features(self, project_name, experiment_id): The features logged to the experiment with ID `experiment_id`. """ - feature_metadata_root = self._get_feature_metadata_root(project_name, experiment_id) - try: - feature_metadata_paths = self._ls_directories_only(feature_metadata_root) - features = [ - domain.Feature(**json.loads(metadata)) - for metadata in self._cat_paths(feature_metadata_paths) - ] - features.sort(key=lambda f: f.created_at) + feature_metadata_root = self._get_feature_metadata_root(project_name, experiment_id) except FileNotFoundError: return [] - - return features + return self._load_metadata_files(feature_metadata_root, domain.Feature) # -------- Metrics --------- @@ -905,7 +913,7 @@ def create_metric(self, metric, project_name, experiment_id): self._persist_domain(metric, metric_metadata_path) - def get_metric(self, project_name, experiment_id, metric_name): + def get_metric(self, project_name: str, experiment_id: str, metric_name: str) -> domain.Metric: """Retrieve a metric from the configured filesystem. Parameters @@ -934,7 +942,7 @@ def get_metric(self, project_name, experiment_id, metric_name): return domain.Metric(**metric) - def get_metrics(self, project_name, experiment_id): + def get_metrics(self, project_name: str, experiment_id: str) -> List[domain.Metric]: """Retrieve all metrics from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -953,19 +961,11 @@ def get_metrics(self, project_name, experiment_id): The metrics logged to the experiment with ID `experiment_id`. 
""" - metric_metadata_root = self._get_metric_metadata_root(project_name, experiment_id) - try: - metric_metadata_paths = self._ls_directories_only(metric_metadata_root) - metrics = [ - domain.Metric(**json.loads(metadata)) - for metadata in self._cat_paths(metric_metadata_paths) - ] - metrics.sort(key=lambda m: m.created_at) + metric_metadata_root = self._get_metric_metadata_root(project_name, experiment_id) except FileNotFoundError: return [] - - return metrics + return self._load_metadata_files(metric_metadata_root, domain.Metric) # ------- Parameters ------- @@ -1035,7 +1035,7 @@ def get_parameter(self, project_name, experiment_id, parameter_name): return domain.Parameter(**parameter) - def get_parameters(self, project_name, experiment_id): + def get_parameters(self, project_name: str, experiment_id: str) -> List[domain.Parameter]: """Retrieve all parameters from the configured filesystem that belong to the experiment with ID `experiment_id`. @@ -1054,19 +1054,12 @@ def get_parameters(self, project_name, experiment_id): The parameters logged to the experiment with ID `experiment_id`. """ - parameter_metadata_root = self._get_parameter_metadata_root(project_name, experiment_id) - try: - parameter_metadata_paths = self._ls_directories_only(parameter_metadata_root) - parameters = [ - domain.Parameter(**json.loads(metadata)) - for metadata in self._cat_paths(parameter_metadata_paths) - ] - parameters.sort(key=lambda p: p.created_at) + parameter_metadata_root = self._get_parameter_metadata_root(project_name, experiment_id) except FileNotFoundError: return [] - return parameters + return self._load_metadata_files(parameter_metadata_root, domain.Parameter) # ---------- Tags ---------- @@ -1101,7 +1094,12 @@ def _get_tag_metadata_root( return f"{entity_metadata_root}/{entity_identifier}" def add_tags( - self, project_name, tags, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + tags, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Persist tags to the configured filesystem. @@ -1130,7 +1128,12 @@ def add_tags( self._persist_domain({"added_tags": tags}, tag_metadata_path) def remove_tags( - self, project_name, tags, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + tags, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Delete tags from the configured filesystem. @@ -1225,7 +1228,12 @@ def _get_comment_metadata_root( ) def add_comments( - self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + comments, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Persist comments to the configured filesystem. @@ -1254,7 +1262,12 @@ def add_comments( self._persist_domain({"added_comments": comments}, comment_metadata_path) def remove_comments( - self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None + self, + project_name, + comments, + experiment_id=None, + entity_identifier=None, + entity_type=None, ): """Delete comments from the configured filesystem. 
diff --git a/tests/unit/client/test_project_client.py b/tests/unit/client/test_project_client.py
index 2b33cdc8..f7ad81ab 100644
--- a/tests/unit/client/test_project_client.py
+++ b/tests/unit/client/test_project_client.py
@@ -237,7 +237,15 @@ def test_to_dask_df(rubicon_and_project_client_with_experiments):
     assert len(df) == 10
 
     # check the cols within the df
-    exp_details = ["id", "name", "description", "model_name", "commit_hash", "tags", "created_at"]
+    exp_details = [
+        "id",
+        "name",
+        "description",
+        "model_name",
+        "commit_hash",
+        "tags",
+        "created_at",
+    ]
     for detail in exp_details:
         assert detail in df.columns
 
@@ -250,7 +258,15 @@ def test_to_pandas_df(rubicon_and_project_client_with_experiments):
     assert len(df) == 10
 
     # check the cols within the df
-    exp_details = ["id", "name", "description", "model_name", "commit_hash", "tags", "created_at"]
+    exp_details = [
+        "id",
+        "name",
+        "description",
+        "model_name",
+        "commit_hash",
+        "tags",
+        "created_at",
+    ]
     for detail in exp_details:
         assert detail in df.columns
 
@@ -280,7 +296,9 @@ def test_to_dask_df_grouped_by_commit_hash(rubicon_and_project_client_with_exper
         assert detail in df.columns
 
 
-def test_to_pandas_df_grouped_by_commit_hash(rubicon_and_project_client_with_experiments):
+def test_to_pandas_df_grouped_by_commit_hash(
+    rubicon_and_project_client_with_experiments,
+):
     project = rubicon_and_project_client_with_experiments[1]
 
     dfs = project.to_df(df_type="pandas", group_by="commit_hash")
@@ -509,3 +527,23 @@ def test_archive_remote_rubicon_s3(mock_open):
     mock_open.assert_called_once_with(zip_archive_filename, "wb")
 
     rubicon_a.repository.filesystem.rm(rubicon_a.config.root_dir, recursive=True)
+
+
+def test_wrong_json_schema_experiment(rubicon_local_filesystem_client_with_project):
+    """Test that the new error handling skips experiment metadata that is not valid JSON."""
+    rubicon, project = rubicon_local_filesystem_client_with_project
+    experiment_location = rubicon.repository._get_experiment_metadata_root(project.name)
+    os.mkdir(experiment_location)
+    with open(os.path.join(experiment_location, "bad_experiment.json"), "w") as f:
+        f.write("bad json")
+
+    assert project.experiments() == []
+
+
+def test_no_json_schema_experiment(rubicon_local_filesystem_client_with_project):
+    """Test that the new error handling tolerates a metadata root with no metadata files."""
+    rubicon, project = rubicon_local_filesystem_client_with_project
+    experiment_location = rubicon.repository._get_experiment_metadata_root(project.name)
+    os.mkdir(experiment_location)
+
+    assert project.experiments() == []
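
The two new tests pin down the change's user-visible behavior: a corrupt or empty metadata tree now yields an empty result plus a warning rather than an exception. Below is a hedged, standalone sketch of the same scenario using rubicon_ml's documented filesystem persistence (the `demo` project name and `bad-id` directory are illustrative; `_get_experiment_metadata_root` is the same private helper the tests above rely on):

```python
import os
import tempfile
import warnings

from rubicon_ml import Rubicon

rubicon = Rubicon(persistence="filesystem", root_dir=tempfile.mkdtemp())
project = rubicon.get_or_create_project("demo")  # "demo" is an illustrative name

# plant an experiment directory whose metadata.json is not valid JSON
experiment_root = rubicon.repository._get_experiment_metadata_root(project.name)
os.makedirs(os.path.join(experiment_root, "bad-id"), exist_ok=True)
with open(os.path.join(experiment_root, "bad-id", "metadata.json"), "w") as f:
    f.write("bad json")

# with this change the malformed record is skipped with a warning, instead of
# a JSONDecodeError propagating out of project.experiments()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert project.experiments() == []
assert any("Failed to load" in str(w.message) for w in caught)
```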