Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Derived Catalog: test for all needed variables and skip if existing #441

Merged
merged 2 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions intake_esm/derived.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DerivedVariable(pydantic.BaseModel):
func: typing.Callable
variable: pydantic.StrictStr
query: typing.Dict[pydantic.StrictStr, typing.Union[typing.Any, typing.List[typing.Any]]]
prefer_derived: bool

@pydantic.validator('query')
def validate_query(cls, values):
Expand Down Expand Up @@ -92,6 +93,7 @@ def register(
*,
variable: str,
query: typing.Dict[pydantic.StrictStr, typing.Union[typing.Any, typing.List[typing.Any]]],
prefer_derived: bool = False,
) -> typing.Callable:
"""Register a derived variable

Expand All @@ -103,13 +105,18 @@ def register(
The name of the variable to derive.
query : typing.Dict[str, typing.Union[typing.Any, typing.List[typing.Any]]]
The query to use to retrieve dependent variables required to derive `variable`.
prefer_derived: bool, optional (default=False)
Specify whether to compute this variable on datasets that already contain a variable
of the same name. Default (False) is to leave the existing variable.

Returns
-------
typing.Callable
The function that was registered.
"""
self._registry[variable] = DerivedVariable(func=func, variable=variable, query=query)
self._registry[variable] = DerivedVariable(
func=func, variable=variable, query=query, prefer_derived=prefer_derived
)
return func

def __contains__(self, item: str) -> bool:
Expand Down Expand Up @@ -182,8 +189,11 @@ def update_datasets(

for dset_key, dataset in datasets.items():
for _, derived_variable in self.items():
if set(dataset.variables).intersection(
if set(dataset.variables).issuperset(
derived_variable.dependent_variables(variable_key_name)
) and (
(derived_variable.variable not in dataset.variables)
or derived_variable.prefer_derived
):
try:
# Assumes all dependent variables are in the same dataset
Expand Down
39 changes: 38 additions & 1 deletion tests/test_derived.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def func(ds):
ds['FOO'] = ds.air // 100
return ds

@dvr.register(variable='lon', query={'variable': 'air'})
def func2(ds):
ds['lon'] = ds.air.lon // 100
return ds

dsets = dvr.update_datasets(datasets={'test': ds.copy()}, variable_key_name='variable')
assert 'test' in dsets
assert 'FOO' in dsets['test']
Expand Down Expand Up @@ -98,10 +103,42 @@ def funcb(ds):
dsets = dvr.update_datasets(
datasets={'test': ds.copy()}, variable_key_name='variable', skip_on_error=True
)

assert 'FOO' not in dsets['test']

with pytest.raises(DerivedVariableError):
dsets = dvr.update_datasets(
datasets={'test': ds.copy()}, variable_key_name='variable', skip_on_error=False
)

@dvr.register(variable='BAR', query={'variable': ['air', 'water']})
def funcc(ds):
ds['BAR'] = ds.air / ds.water
return ds

@dvr.register(variable='FOO', query={'variable': ['air']})
def funcd(ds):
ds['FOO'] = ds.air * 2
return ds

# No error, nothing is done.
dsets = dvr.update_datasets(
datasets={'test': ds.assign(FOO=ds.air).copy()},
variable_key_name='variable',
skip_on_error=False,
)
assert {'air', 'FOO'} == dsets['test'].data_vars.keys()
assert ds.air.equals(dsets['test'].FOO)

@dvr.register(variable='FOO', query={'variable': ['air']}, prefer_derived=True)
def funce(ds):
ds['FOO'] = ds.air * 2
return ds

# No error, FOO is recomputed
dsets = dvr.update_datasets(
datasets={'test': ds.assign(FOO=ds.air).copy()},
variable_key_name='variable',
skip_on_error=False,
)
assert {'air', 'FOO'} == dsets['test'].data_vars.keys()
assert ds.air.equals(dsets['test'].FOO / 2)