Change import to have configurable prefixes via set_options (#460)
aulemahal authored Mar 24, 2022
1 parent b0b118c commit a5c6ef0
Showing 6 changed files with 99 additions and 14 deletions.
7 changes: 7 additions & 0 deletions docs/source/reference/api.md
@@ -50,3 +50,10 @@ For more details and examples, refer to the relevant chapters in the main part o
:noindex:
:special-members: __init__
```

## Options for dataset attributes

```{eval-rst}
.. autoclass:: intake_esm.utils.set_options
:noindex:
```
2 changes: 1 addition & 1 deletion intake_esm/__init__.py
@@ -8,7 +8,7 @@
from . import tutorial
from .core import esm_datastore
from .derived import DerivedVariableRegistry, default_registry
from .utils import show_versions
from .utils import set_options, show_versions

try:
__version__ = get_distribution(__name__).version
4 changes: 4 additions & 0 deletions intake_esm/core.py
@@ -475,6 +475,10 @@ def to_dataset_dict(
"""
Load catalog entries into a dictionary of xarray datasets.
Column values, dataset keys and requested variables are added as global
attributes on the returned datasets. The names of these attributes can be
customized with :py:class:`intake_esm.utils.set_options`.
Parameters
----------
xarray_open_kwargs : dict
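The new docstring above notes that column values, the dataset key and the requested variables end up as global attributes with configurable names. A minimal sketch of what that looks like on a returned dataset, assuming a hypothetical `catalog.json` datastore and a variable named `tas` (both placeholders, not part of this commit):

```python
import intake
import intake_esm

# Hypothetical catalog path; any ESM datastore behaves the same way.
cat = intake.open_esm_datastore('catalog.json')
subset = cat.search(variable=['tas'])

# With the default options, attribute names use the 'intake_esm_attrs' prefix.
key, ds = subset.to_dataset_dict().popitem()
print(ds.attrs['intake_esm_dataset_key'])   # the key of this dataset in the dict
print(ds.attrs['intake_esm_vars'])          # the requested variables, e.g. ['tas']
print({k: v for k, v in ds.attrs.items()
       if k.startswith('intake_esm_attrs/')})  # one entry per catalog column

# The same names can be customized, e.g. shortening the prefix to 'cat':
with intake_esm.set_options(attrs_prefix='cat', vars_key='cat_vars'):
    key, ds = subset.to_dataset_dict().popitem()
    assert 'cat_vars' in ds.attrs
```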
16 changes: 8 additions & 8 deletions intake_esm/source.py
@@ -8,7 +8,7 @@
from intake.source.base import DataSource, Schema

from .cat import Aggregation, DataFormat
from .utils import INTAKE_ESM_ATTRS_PREFIX, INTAKE_ESM_DATASET_KEY, INTAKE_ESM_VARS_KEY
from .utils import OPTIONS


class ESMDataSourceError(Exception):
@@ -74,9 +74,9 @@ def _open_dataset(
variable_intersection = set(requested_variables).intersection(set(varname))
variables = [variable for variable in variable_intersection if variable in ds.data_vars]
ds = ds[variables]
ds.attrs[INTAKE_ESM_VARS_KEY] = variables
ds.attrs[OPTIONS['vars_key']] = variables
else:
ds.attrs[INTAKE_ESM_VARS_KEY] = varname
ds.attrs[OPTIONS['vars_key']] = varname

ds = _expand_dims(expand_dims, ds)
ds = _update_attrs(additional_attrs, ds)
@@ -87,15 +87,15 @@ def _update_attrs(additional_attrs, ds):
additional_attrs = additional_attrs or {}
if additional_attrs:
additional_attrs = {
f'{INTAKE_ESM_ATTRS_PREFIX}/{key}': value for key, value in additional_attrs.items()
f"{OPTIONS['attrs_prefix']}/{key}": value for key, value in additional_attrs.items()
}
ds.attrs = {**ds.attrs, **additional_attrs}
return ds


def _expand_dims(expand_dims, ds):
if expand_dims:
for variable in ds.attrs[INTAKE_ESM_VARS_KEY]:
for variable in ds.attrs[OPTIONS['vars_key']]:
ds[variable] = ds[variable].expand_dims(**expand_dims)

return ds
@@ -230,22 +230,22 @@ def _open_dataset(self):
datasets = sorted(
datasets,
key=lambda ds: tuple(
f'{INTAKE_ESM_ATTRS_PREFIX}/{agg.attribute_name}'
f"{OPTIONS['attrs_prefix']}/{agg.attribute_name}"
for agg in self.aggregations
),
)
with dask.config.set(
{'scheduler': 'single-threaded', 'array.slicing.split_large_chunks': True}
): # Use single-threaded scheduler
datasets = [
ds.set_coords(set(ds.variables) - set(ds.attrs[INTAKE_ESM_VARS_KEY]))
ds.set_coords(set(ds.variables) - set(ds.attrs[OPTIONS['vars_key']]))
for ds in datasets
]
self._ds = xr.combine_by_coords(
datasets, **self.xarray_combine_by_coords_kwargs
)

self._ds.attrs[INTAKE_ESM_DATASET_KEY] = self.key
self._ds.attrs[OPTIONS['dataset_key']] = self.key

except Exception as exc:
raise ESMDataSourceError(
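The change in `source.py` routes all attribute naming through the module-level `OPTIONS` dict. A small sketch of the resulting naming scheme using the internal `_update_attrs` helper shown above (an internal function, used here only to illustrate how catalog columns become prefixed attributes):

```python
import xarray as xr
from intake_esm.source import _update_attrs
from intake_esm.utils import OPTIONS

ds = xr.Dataset()
# Each catalog column becomes an attribute named '<attrs_prefix>/<column>'.
ds = _update_attrs({'experiment': 'historical', 'source_id': 'CESM2'}, ds)
assert ds.attrs[f"{OPTIONS['attrs_prefix']}/experiment"] == 'historical'
```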
72 changes: 68 additions & 4 deletions intake_esm/utils.py
@@ -2,10 +2,6 @@
import importlib
import sys

INTAKE_ESM_ATTRS_PREFIX = 'intake_esm_attrs'
INTAKE_ESM_DATASET_KEY = 'intake_esm_dataset_key'
INTAKE_ESM_VARS_KEY = 'intake_esm_vars'


def _allnan_or_nonan(df, column: str) -> bool:
"""Check if all values in a column are NaN or not NaN
@@ -77,3 +73,71 @@ def show_versions(file=sys.stdout):  # pragma: no cover
print('', file=file)
for k, stat in sorted(deps_blob):
print(f'{k}: {stat}', file=file)


OPTIONS = {
'attrs_prefix': 'intake_esm_attrs',
'dataset_key': 'intake_esm_dataset_key',
'vars_key': 'intake_esm_vars',
}


class set_options:
"""Set options for intake_esm in a controlled context.
Currently-supported options:
- ``attrs_prefix``:
The prefix to use in the names of attributes constructed from the catalog's columns
when returning xarray Datasets.
Default: ``intake_esm_attrs``.
- ``dataset_key``:
Name of the global attribute where to store the dataset's key.
Default: ``intake_esm_dataset_key``.
- ``vars_key``:
Name of the global attribute where to store the list of requested variables when
opening a dataset. Default: ``intake_esm_vars``.
Examples
--------
You can use ``set_options`` either as a context manager:
>>> import intake
>>> import intake_esm
>>> cat = intake.open_esm_datastore('catalog.json')
>>> with intake_esm.set_options(attrs_prefix='cat'):
... out = cat.to_dataset_dict()
...
Or to set global options:
>>> intake_esm.set_options(attrs_prefix='cat', vars_key='cat_vars')
"""

def __init__(self, **kwargs):
self.old = {}
for k, v in kwargs.items():
if k not in OPTIONS:
raise ValueError(
f'argument name {k} is not in the set of valid options {set(OPTIONS)}'
)

if not isinstance(v, str):
raise ValueError(f'option {k} given an invalid value: {v}')

self.old[k] = OPTIONS[k]

self._update(kwargs)

def __enter__(self):
"""Context management."""
return

def _update(self, kwargs):
"""Update values."""
for k, v in kwargs.items():
OPTIONS[k] = v

def __exit__(self, type, value, traceback):
"""Context management."""
self._update(self.old)
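Because `set_options.__init__` saves the previous values in `self.old`, applies the new ones immediately via `_update`, and `__exit__` restores them, the class works both as a one-shot global setter and as a context manager. A sketch of the resulting behaviour, based on the code above:

```python
import intake_esm
from intake_esm.utils import OPTIONS

assert OPTIONS['attrs_prefix'] == 'intake_esm_attrs'

# Inside the context, the new value is active...
with intake_esm.set_options(attrs_prefix='cat'):
    assert OPTIONS['attrs_prefix'] == 'cat'

# ...and the previous value is restored on exit.
assert OPTIONS['attrs_prefix'] == 'intake_esm_attrs'

# Called outside a `with` block, the change persists globally;
# only the options passed in are affected.
intake_esm.set_options(vars_key='cat_vars')
assert OPTIONS['vars_key'] == 'cat_vars'
assert OPTIONS['attrs_prefix'] == 'intake_esm_attrs'
```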
12 changes: 11 additions & 1 deletion tests/test_core.py
@@ -461,5 +461,15 @@ class ChildCatalog(intake_esm.esm_datastore):
pass

cat = ChildCatalog(catalog_dict_records)
scat = cat.search(variable=['FOO', 'BAR'])
scat = cat.search(variable=['FLNS'])
assert type(scat) is ChildCatalog


def test_options():
cat = intake.open_esm_datastore(catalog_dict_records)
scat = cat.search(variable=['FLNS'])
with intake_esm.set_options(attrs_prefix='myprefix'):
_, ds = scat.to_dataset_dict(
xarray_open_kwargs={'backend_kwargs': {'storage_options': {'anon': True}}},
).popitem()
assert ds.attrs['myprefix/component'] == 'atm'
