Skip to content

Commit

Permalink
Add method for loading registry from a Python module (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Oct 19, 2021
1 parent dce1b90 commit aaa71db
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 0 deletions.
39 changes: 39 additions & 0 deletions intake_esm/derived.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import importlib
import inspect
import typing

import pydantic
Expand Down Expand Up @@ -34,6 +36,43 @@ class DerivedVariableRegistry:
def __post_init_post_parse__(self):
self._registry = {}

@classmethod
def load(cls, name: str, package: str = None) -> 'DerivedVariableRegistry':
"""Load a DerivedVariableRegistry from a Python module/file
Parameters
----------
name : str
The name of the module to load the DerivedVariableRegistry from.
package : str, optional
The package to load the module from. This argument is
required when performing a relative import. It specifies the package
to use as the anchor point from which to resolve the relative import
to an absolute import.
Returns
-------
DerivedVariableRegistry
A DerivedVariableRegistry loaded from the Python module.
Notes
-----
If you have a folder: /home/foo/pythonfiles, and you want to load a registry
defined in registry.py, located in that directory, ensure to add your folder to the
$PYTHONPATH before calling this function.
>>> import sys
>>> sys.path.insert(0, "/home/foo/pythonfiles")
>>> from intake_esm.derived import DerivedVariableRegistry
>>> registsry = DerivedVariableRegistry.load("registry")
"""
modname = importlib.import_module(name, package=package)
candidates = inspect.getmembers(modname, lambda x: isinstance(x, DerivedVariableRegistry))
if candidates:
return candidates[0][1]
else:
raise ValueError(f'No DerivedVariableRegistry found in {name} module')

@tlz.curry
def register(
self, func: typing.Callable, *, variable: str, dependent_variables: typing.List[str]
Expand Down
8 changes: 8 additions & 0 deletions tests/my_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import intake_esm

registry = intake_esm.DerivedVariableRegistry()


@registry.register(variable='FOO', dependent_variables=['FLUT'])
def func(ds):
return ds
16 changes: 16 additions & 0 deletions tests/test_derived.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import sys

import pytest
import xarray as xr

from intake_esm.derived import DerivedVariable, DerivedVariableError, DerivedVariableRegistry

from .utils import here


def test_registry_init():
"""
Expand All @@ -14,6 +18,18 @@ def test_registry_init():
assert len(dvr.keys()) == 0


def test_registry_load():

sys.path.insert(0, f'{here}/')
dvr = DerivedVariableRegistry.load('my_registry')
assert len(dvr) > 0
assert 'FOO' in dvr

# Test for errors/ invalid inputs, wrong return type
with pytest.raises(ValueError):
DerivedVariableRegistry.load('utils')


def test_registry_register():
dvr = DerivedVariableRegistry()

Expand Down

0 comments on commit aaa71db

Please # to comment.