From a5c6ef02e217bf9fecca7efbf0139fe9f726d14d Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Thu, 24 Mar 2022 16:11:01 -0400
Subject: [PATCH] Change import to have configurable prefixes via `set_options` (#460)

---
 docs/source/reference/api.md |  7 ++++
 intake_esm/__init__.py       |  2 +-
 intake_esm/core.py           |  4 ++
 intake_esm/source.py         | 16 ++++----
 intake_esm/utils.py          | 72 ++++++++++++++++++++++++++++++++++--
 tests/test_core.py           | 12 +++++-
 6 files changed, 99 insertions(+), 14 deletions(-)

diff --git a/docs/source/reference/api.md b/docs/source/reference/api.md
index 15d82cb3..958da5fa 100644
--- a/docs/source/reference/api.md
+++ b/docs/source/reference/api.md
@@ -50,3 +50,10 @@ For more details and examples, refer to the relevant chapters in the main part o
    :noindex:
    :special-members: __init__
 ```
+
+## Options for dataset attributes
+
+```{eval-rst}
+.. autoclass:: intake_esm.utils.set_options
+   :noindex:
+```
diff --git a/intake_esm/__init__.py b/intake_esm/__init__.py
index 621a9cd3..38c68f58 100644
--- a/intake_esm/__init__.py
+++ b/intake_esm/__init__.py
@@ -8,7 +8,7 @@
 from . import tutorial
 from .core import esm_datastore
 from .derived import DerivedVariableRegistry, default_registry
-from .utils import show_versions
+from .utils import set_options, show_versions
 
 try:
     __version__ = get_distribution(__name__).version
diff --git a/intake_esm/core.py b/intake_esm/core.py
index 9efe89a6..6cb55f89 100644
--- a/intake_esm/core.py
+++ b/intake_esm/core.py
@@ -475,6 +475,10 @@ def to_dataset_dict(
         """
         Load catalog entries into a dictionary of xarray datasets.
 
+        Column values, dataset keys and requested variables are added as global
+        attributes on the returned datasets. The names of these attributes can be
+        customized with :py:class:`intake_esm.utils.set_options`.
+
         Parameters
         ----------
         xarray_open_kwargs : dict
diff --git a/intake_esm/source.py b/intake_esm/source.py
index 27b739cd..f84b9557 100644
--- a/intake_esm/source.py
+++ b/intake_esm/source.py
@@ -8,7 +8,7 @@
 from intake.source.base import DataSource, Schema
 
 from .cat import Aggregation, DataFormat
-from .utils import INTAKE_ESM_ATTRS_PREFIX, INTAKE_ESM_DATASET_KEY, INTAKE_ESM_VARS_KEY
+from .utils import OPTIONS
 
 
 class ESMDataSourceError(Exception):
@@ -74,9 +74,9 @@ def _open_dataset(
         variable_intersection = set(requested_variables).intersection(set(varname))
         variables = [variable for variable in variable_intersection if variable in ds.data_vars]
         ds = ds[variables]
-        ds.attrs[INTAKE_ESM_VARS_KEY] = variables
+        ds.attrs[OPTIONS['vars_key']] = variables
     else:
-        ds.attrs[INTAKE_ESM_VARS_KEY] = varname
+        ds.attrs[OPTIONS['vars_key']] = varname
 
     ds = _expand_dims(expand_dims, ds)
     ds = _update_attrs(additional_attrs, ds)
@@ -87,7 +87,7 @@ def _update_attrs(additional_attrs, ds):
     additional_attrs = additional_attrs or {}
     if additional_attrs:
         additional_attrs = {
-            f'{INTAKE_ESM_ATTRS_PREFIX}/{key}': value for key, value in additional_attrs.items()
+            f"{OPTIONS['attrs_prefix']}/{key}": value for key, value in additional_attrs.items()
         }
     ds.attrs = {**ds.attrs, **additional_attrs}
     return ds
@@ -95,7 +95,7 @@ def _update_attrs(additional_attrs, ds):
 
 def _expand_dims(expand_dims, ds):
     if expand_dims:
-        for variable in ds.attrs[INTAKE_ESM_VARS_KEY]:
+        for variable in ds.attrs[OPTIONS['vars_key']]:
             ds[variable] = ds[variable].expand_dims(**expand_dims)
     return ds
 
@@ -230,7 +230,7 @@ def _open_dataset(self):
                 datasets = sorted(
                     datasets,
                     key=lambda ds: tuple(
-                        f'{INTAKE_ESM_ATTRS_PREFIX}/{agg.attribute_name}'
+                        f"{OPTIONS['attrs_prefix']}/{agg.attribute_name}"
                         for agg in self.aggregations
                     ),
                 )
@@ -238,14 +238,14 @@ def _open_dataset(self):
                     {'scheduler': 'single-threaded', 'array.slicing.split_large_chunks': True}
                 ):  # Use single-threaded scheduler
                     datasets = [
-                        ds.set_coords(set(ds.variables) - set(ds.attrs[INTAKE_ESM_VARS_KEY]))
+                        ds.set_coords(set(ds.variables) - set(ds.attrs[OPTIONS['vars_key']]))
                         for ds in datasets
                     ]
                     self._ds = xr.combine_by_coords(
                         datasets, **self.xarray_combine_by_coords_kwargs
                     )
 
-            self._ds.attrs[INTAKE_ESM_DATASET_KEY] = self.key
+            self._ds.attrs[OPTIONS['dataset_key']] = self.key
 
         except Exception as exc:
             raise ESMDataSourceError(
diff --git a/intake_esm/utils.py b/intake_esm/utils.py
index 092bd81e..a269ef54 100644
--- a/intake_esm/utils.py
+++ b/intake_esm/utils.py
@@ -2,10 +2,6 @@
 import importlib
 import sys
 
-INTAKE_ESM_ATTRS_PREFIX = 'intake_esm_attrs'
-INTAKE_ESM_DATASET_KEY = 'intake_esm_dataset_key'
-INTAKE_ESM_VARS_KEY = 'intake_esm_vars'
-
 
 def _allnan_or_nonan(df, column: str) -> bool:
     """Check if all values in a column are NaN or not NaN
@@ -77,3 +73,71 @@ def show_versions(file=sys.stdout):  # pragma: no cover
     print('', file=file)
     for k, stat in sorted(deps_blob):
         print(f'{k}: {stat}', file=file)
+
+
+OPTIONS = {
+    'attrs_prefix': 'intake_esm_attrs',
+    'dataset_key': 'intake_esm_dataset_key',
+    'vars_key': 'intake_esm_vars',
+}
+
+
+class set_options:
+    """Set options for intake_esm in a controlled context.
+
+    Currently-supported options:
+
+    - ``attrs_prefix``:
+      The prefix to use in the names of attributes constructed from the catalog's columns
+      when returning xarray Datasets.
+      Default: ``intake_esm_attrs``.
+    - ``dataset_key``:
+      Name of the global attribute in which to store the dataset's key.
+      Default: ``intake_esm_dataset_key``.
+    - ``vars_key``:
+      Name of the global attribute in which to store the list of requested variables when
+      opening a dataset. Default: ``intake_esm_vars``.
+
+    Examples
+    --------
+    You can use ``set_options`` either as a context manager:
+
+    >>> import intake
+    >>> import intake_esm
+    >>> cat = intake.open_esm_datastore('catalog.json')
+    >>> with intake_esm.set_options(attrs_prefix='cat'):
+    ...     out = cat.to_dataset_dict()
+    ...
+
+    Or to set global options:
+
+    >>> intake_esm.set_options(attrs_prefix='cat', vars_key='cat_vars')
+    """
+
+    def __init__(self, **kwargs):
+        self.old = {}
+        for k, v in kwargs.items():
+            if k not in OPTIONS:
+                raise ValueError(
+                    f'argument name {k} is not in the set of valid options {set(OPTIONS)}'
+                )
+
+            if not isinstance(v, str):
+                raise ValueError(f'option {k} given an invalid value: {v}')
+
+            self.old[k] = OPTIONS[k]
+
+        self._update(kwargs)
+
+    def __enter__(self):
+        """Context management."""
+        return
+
+    def _update(self, kwargs):
+        """Update values."""
+        for k, v in kwargs.items():
+            OPTIONS[k] = v
+
+    def __exit__(self, type, value, traceback):
+        """Context management."""
+        self._update(self.old)
diff --git a/tests/test_core.py b/tests/test_core.py
index d9b837e8..71d59682 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -461,5 +461,15 @@ class ChildCatalog(intake_esm.esm_datastore):
         pass
 
     cat = ChildCatalog(catalog_dict_records)
-    scat = cat.search(variable=['FOO', 'BAR'])
+    scat = cat.search(variable=['FLNS'])
     assert type(scat) is ChildCatalog
+
+
+def test_options():
+    cat = intake.open_esm_datastore(catalog_dict_records)
+    scat = cat.search(variable=['FLNS'])
+    with intake_esm.set_options(attrs_prefix='myprefix'):
+        _, ds = scat.to_dataset_dict(
+            xarray_open_kwargs={'backend_kwargs': {'storage_options': {'anon': True}}},
+        ).popitem()
+    assert ds.attrs['myprefix/component'] == 'atm'
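
A quick sanity check of the behaviour introduced by this patch, as a minimal sketch that assumes the patch above is applied: `set_options` swaps entries in the module-level `intake_esm.utils.OPTIONS` dict and, when used as a context manager, restores the previous values on exit. The sketch only exercises the option bookkeeping, so no catalog is needed.

```python
import intake_esm
from intake_esm.utils import OPTIONS

# Inside the context manager the renamed keys are active...
with intake_esm.set_options(attrs_prefix='cat', vars_key='cat_vars'):
    assert OPTIONS['attrs_prefix'] == 'cat'
    assert OPTIONS['vars_key'] == 'cat_vars'

# ...and the previous values are restored when the context exits.
assert OPTIONS['attrs_prefix'] == 'intake_esm_attrs'
assert OPTIONS['vars_key'] == 'intake_esm_vars'

# Unknown option names (and non-string values) raise ValueError.
try:
    intake_esm.set_options(not_an_option='x')
except ValueError:
    pass
```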