diff --git a/docs/source/concepts/model_specification/index.rst b/docs/source/concepts/model_specification/index.rst index 83bf439a2..c93540e91 100644 --- a/docs/source/concepts/model_specification/index.rst +++ b/docs/source/concepts/model_specification/index.rst @@ -76,9 +76,11 @@ specification item and the fully instantiated object is the domain of the The :class:`ComponentConfigurationParser ` -is responsible for taking a list or hierarchical :class:`LayeredConfigTree -` of components derived from a model -specification file and turning it into a list of instantiated component objects. +is responsible for taking a list or hierarchical +:class:`LayeredConfigTree ` of +components derived from a model specification file and turning it into a list of +instantiated component objects. + The :meth:`get_components ` method of the parser is used anytime a simulation is initialized from a diff --git a/docs/source/conf.py b/docs/source/conf.py index 03b43b79f..8f317ef88 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -50,6 +50,7 @@ # ones. extensions = [ "sphinx.ext.autodoc", + "sphinx_autodoc_typehints", "sphinx.ext.intersphinx", "sphinx.ext.doctest", "sphinx.ext.todo", diff --git a/setup.py b/setup.py index 896eeadd4..115d6257e 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ "IPython", "matplotlib", "sphinxcontrib-video", + "sphinx-autodoc-typehints", ] setup( diff --git a/src/vivarium/component.py b/src/vivarium/component.py index df700af85..b8eab99f6 100644 --- a/src/vivarium/component.py +++ b/src/vivarium/component.py @@ -5,6 +5,7 @@ A base Component class to be used to create components for use in ``vivarium`` simulations. + """ import re @@ -31,8 +32,7 @@ class Component(ABC): - """ - The base class for all components used in a Vivarium simulation. + """The base class for all components used in a Vivarium simulation. A `Component` in a Vivarium simulation represents a distinct feature or aspect of the model. It encapsulates the logic and data needed for that @@ -76,17 +76,16 @@ class Component(ABC): - `on_time_step_cleanup` - `on_collect_metrics` - `on_simulation_end` + """ CONFIGURATION_DEFAULTS: Dict[str, Any] = {} - """ - A dictionary containing the defaults for any configurations managed by this + """A dictionary containing the defaults for any configurations managed by this component. An empty dictionary indicates no managed configurations. """ def __repr__(self): - """ - Returns a string representation of the __init__ call made to create this + """Returns a string representation of the __init__ call made to create this object. The representation is built by retrieving the initialization parameters @@ -100,7 +99,6 @@ def __repr__(self): Returns ------- - str A string representation of the __init__ call made to create this object. """ @@ -123,9 +121,10 @@ def __str__(self): @property def name(self) -> str: - """ - Returns the name of the component. By convention, these are in snake - case with arguments of the `__init__` appended and separated by `.`. + """Returns the name of the component. + + By convention, these are in snake case with arguments of the `__init__` + appended and separated by `.`. Names must be unique within a simulation. @@ -143,7 +142,7 @@ def name(self) -> str: Returns ------- str - The unique name of the component. + The name of the component. """ if not self._name: base_name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", type(self).__name__) @@ -159,20 +158,18 @@ def name(self) -> str: @property def sub_components(self) -> List["Component"]: - """ - Provide components managed by this component. + """Provide components managed by this component. Returns ------- List[Component] - A list of components that are managed by this component. + The sub-components that are managed by this component. """ return self._sub_components @property def configuration_defaults(self) -> Dict[str, Any]: - """ - Provides a dictionary containing the defaults for any configurations + """Provides a dictionary containing the defaults for any configurations managed by this component. These default values will be stored at the `component_configs` layer of the @@ -180,7 +177,6 @@ def configuration_defaults(self) -> Dict[str, Any]: Returns ------- - Dict[str, Any] A dictionary containing the defaults for any configurations managed by this component. """ @@ -188,12 +184,10 @@ def configuration_defaults(self) -> Dict[str, Any]: @property def columns_created(self) -> List[str]: - """ - Provides names of columns created by the component. + """Provides names of columns created by the component. Returns ------- - List[str] Names of the columns created by this component, or an empty list if none. """ @@ -201,12 +195,10 @@ def columns_created(self) -> List[str]: @property def columns_required(self) -> Optional[List[str]]: - """ - Provides names of columns required by the component. + """Provides names of columns required by the component. Returns ------- - Optional[List[str]] Names of required columns not created by this component. An empty list means all available columns are needed. `None` means no additional columns are necessary. @@ -215,13 +207,11 @@ def columns_required(self) -> Optional[List[str]]: @property def initialization_requirements(self) -> Dict[str, List[str]]: - """ - Provides the names of all values required by this component during + """Provides the names of all values required by this component during simulant initialization. Returns ------- - Dict[str, List[str]] A dictionary containing the additional requirements of this component during simulant initialization. An omitted key or an empty list for a key implies no requirements for that key during @@ -235,12 +225,10 @@ def initialization_requirements(self) -> Dict[str, List[str]]: @property def population_view_query(self) -> Optional[str]: - """ - Provides a query to use when filtering the component's `PopulationView`. + """Provides a query to use when filtering the component's `PopulationView`. Returns ------- - Optional[str] A pandas query string for filtering the component's `PopulationView`. Returns `None` if no filtering is required. """ @@ -248,12 +236,10 @@ def population_view_query(self) -> Optional[str]: @property def post_setup_priority(self) -> int: - """ - Provides the priority of this component's post_setup listener. + """Provides the priority of this component's post_setup listener. Returns ------- - int The priority of this component's post_setup listener. This value can range from 0 to 9, inclusive. """ @@ -261,12 +247,10 @@ def post_setup_priority(self) -> int: @property def time_step_prepare_priority(self) -> int: - """ - Provides the priority of this component's time_step__prepare listener. + """Provides the priority of this component's time_step__prepare listener. Returns ------- - int The priority of this component's time_step__prepare listener. This value can range from 0 to 9, inclusive. """ @@ -274,12 +258,10 @@ def time_step_prepare_priority(self) -> int: @property def time_step_priority(self) -> int: - """ - Provides the priority of this component's time_step listener. + """Provides the priority of this component's time_step listener. Returns ------- - int The priority of this component's time_step listener. This value can range from 0 to 9, inclusive. """ @@ -287,12 +269,10 @@ def time_step_priority(self) -> int: @property def time_step_cleanup_priority(self) -> int: - """ - Provides the priority of this component's time_step__cleanup listener. + """Provides the priority of this component's time_step__cleanup listener. Returns ------- - int The priority of this component's time_step__cleanup listener. This value can range from 0 to 9, inclusive. """ @@ -300,12 +280,10 @@ def time_step_cleanup_priority(self) -> int: @property def collect_metrics_priority(self) -> int: - """ - Provides the priority of this component's collect_metrics listener. + """Provides the priority of this component's collect_metrics listener. Returns ------- - int The priority of this component's collect_metrics listener. This value can range from 0 to 9, inclusive. """ @@ -313,12 +291,10 @@ def collect_metrics_priority(self) -> int: @property def simulation_end_priority(self) -> int: - """ - Provides the priority of this component's simulation_end listener. + """Provides the priority of this component's simulation_end listener. Returns ------- - int The priority of this component's simulation_end listener. This value can range from 0 to 9, inclusive. """ @@ -329,8 +305,7 @@ def simulation_end_priority(self) -> int: ##################### def __init__(self) -> None: - """ - Initializes a new instance of the Component class. + """Initializes a new instance of the Component class. This method is the initializer for the Component class. It initializes logger of type Logger and population_view of type PopulationView to None. @@ -349,8 +324,7 @@ def __init__(self) -> None: self.lookup_tables: Dict[str, LookupTable] = {} def setup_component(self, builder: "Builder") -> None: - """ - Sets up the component for a Vivarium simulation. + """Sets up the component for a Vivarium simulation. This method is run by Vivarium during the setup phase. It performs a series of operations to prepare the component for the simulation. @@ -362,12 +336,8 @@ def setup_component(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - - Returns - ------- - None """ self.logger = builder.logging.get_logger(self.name) self.get_value_columns = builder.data.value_columns() @@ -388,8 +358,7 @@ def setup_component(self, builder: "Builder") -> None: ####################### def setup(self, builder: "Builder") -> None: - """ - Defines custom actions this component needs to run during the setup + """Defines custom actions this component needs to run during the setup lifecycle phase. This method is intended to be overridden by subclasses to perform any @@ -398,32 +367,25 @@ def setup(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - - Returns - ------- - None """ pass def on_post_setup(self, event: Event) -> None: - """ - Method that vivarium will run during the post_setup event. + """Method that vivarium will run during the post_setup event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the post_setup event. - NOTE: This method is not commonly used functionality. + Notes + ----- + This method is not commonly used functionality. Parameters ---------- - event : Event + event The event object associated with the post_setup event. - - Returns - ------- - None """ pass @@ -447,8 +409,7 @@ def on_initialize_simulants(self, pop_data: "SimulantData") -> None: pass def on_time_step_prepare(self, event: Event) -> None: - """ - Method that vivarium will run during the time_step__prepare event. + """Method that vivarium will run during the time_step__prepare event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the @@ -456,36 +417,26 @@ def on_time_step_prepare(self, event: Event) -> None: Parameters ---------- - event : Event + event The event object associated with the time_step__prepare event. - - Returns - ------- - None """ pass def on_time_step(self, event: Event) -> None: - """ - Method that vivarium will run during the time_step event. + """Method that vivarium will run during the time_step event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the time_step event. Parameters ---------- - event : Event + event The event object associated with the time_step event. - - Returns - ------- - None """ pass def on_time_step_cleanup(self, event: Event) -> None: - """ - Method that vivarium will run during the time_step__cleanup event. + """Method that vivarium will run during the time_step__cleanup event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the @@ -493,18 +444,13 @@ def on_time_step_cleanup(self, event: Event) -> None: Parameters ---------- - event : Event + event The event object associated with the time_step__cleanup event. - - Returns - ------- - None """ pass def on_collect_metrics(self, event: Event) -> None: - """ - Method that vivarium will run during the collect_metrics event. + """Method that vivarium will run during the collect_metrics event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the collect_metrics @@ -512,18 +458,13 @@ def on_collect_metrics(self, event: Event) -> None: Parameters ---------- - event : Event + event The event object associated with the collect_metrics event. - - Returns - ------- - None """ pass def on_simulation_end(self, event: Event) -> None: - """ - Method that vivarium will run during the simulation_end event. + """Method that vivarium will run during the simulation_end event. This method is intended to be overridden by subclasses if there are operations they need to perform specifically during the simulation_end @@ -531,12 +472,8 @@ def on_simulation_end(self, event: Event) -> None: Parameters ---------- - event : Event + event The event object associated with the simulation_end event. - - Returns - ------- - None """ pass @@ -545,19 +482,18 @@ def on_simulation_end(self, event: Event) -> None: ################## def get_initialization_parameters(self) -> Dict[str, Any]: - """ - Retrieves the values of all parameters specified in the `__init__` that + """Retrieves the values of all parameters specified in the `__init__` that have an attribute with the same name. - Note: this retrieves the value of the attribute at the time of calling, + Notes + ----- + This retrieves the value of the attribute at the time of calling, which is not guaranteed to be the same as the original value. Returns ------- - dict A dictionary where the keys are the names of the parameters used in the `__init__` method and the values are their current values. - """ return { parameter_name: getattr(self, parameter_name) @@ -566,8 +502,7 @@ def get_initialization_parameters(self) -> Dict[str, Any]: } def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]: - """ - Retrieves the configuration for this component from the builder. + """Retrieves the configuration for this component from the builder. This method retrieves the configuration for this component from the simulation's overall configuration. The configuration is retrieved using @@ -575,12 +510,11 @@ def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]: Parameters ---------- - builder : Builder + builder The simulation's builder object. Returns ------- - Optional[layered_config_tree.main.LayeredConfigTree] The configuration for this component, or `None` if the component has no configuration. """ @@ -590,8 +524,7 @@ def get_configuration(self, builder: "Builder") -> Optional[LayeredConfigTree]: return None def build_all_lookup_tables(self, builder: "Builder") -> None: - """ - Builds all lookup tables for this component. + """Builds all lookup tables for this component. This method builds lookup tables for this component based on the data sources specified in the configuration. If no data sources are specified, @@ -602,12 +535,8 @@ def build_all_lookup_tables(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - - Returns - ------- - None """ if self.configuration and "data_sources" in self.configuration: for table_name in self.configuration.data_sources.keys(): @@ -627,8 +556,7 @@ def build_lookup_table( data_source: Union[str, float, int, list, pd.DataFrame], value_columns: Optional[Iterable[str]] = None, ) -> LookupTable: - """ - Builds a LookupTable from a data source. + """Builds a LookupTable from a data source. Uses `get_data` to parse the data source and retrieve the lookup table data. The LookupTable is built from the data source, with the value @@ -638,16 +566,15 @@ def build_lookup_table( Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - data_source : Union[str, float, pandas.core.generic.PandasObject] + data_source The data source to build the LookupTable from. - value_columns : Optional[Iterable[str]] + value_columns The columns to include in the LookupTable. Returns ------- - LookupTable The LookupTable built from the data source. Raises @@ -695,8 +622,7 @@ def get_data( builder: "Builder", data_source: Union[str, float, pd.DataFrame], ) -> Union[float, pd.DataFrame]: - """ - Retrieves data from a data source. + """Retrieves data from a data source. If the data source is a float or a DataFrame, it is treated as the data itself. If the data source is a string, containing the substring '::', @@ -708,14 +634,13 @@ def get_data( Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - data_source : Union[str, float, pandas.core.generic.PandasObject] + data_source The data source to retrieve data from. Returns ------- - Union[float, pandas.core.generic.PandasObject] The data retrieved from the data source. Raises @@ -751,8 +676,7 @@ def get_data( raise ConfigurationError(f"Failed to find key '{data_source}' in artifact.") def _set_population_view(self, builder: "Builder") -> None: - """ - Creates the PopulationView for this component if it needs access to + """Creates the PopulationView for this component if it needs access to the state table. The method determines the necessary columns for the PopulationView @@ -761,12 +685,8 @@ def _set_population_view(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder object used to set up the component. - - Returns - ------- - None """ if self.columns_required: # Get all columns created and required @@ -787,8 +707,7 @@ def _set_population_view(self, builder: "Builder") -> None: ) def _register_post_setup_listener(self, builder: "Builder") -> None: - """ - Registers a post_setup listener if this component has defined one. + """Registers a post_setup listener if this component has defined one. This method allows the component to respond to "post_setup" events if it has its own `on_post_setup` method. The listener will be registered with @@ -797,12 +716,8 @@ def _register_post_setup_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_post_setup != Component.on_post_setup: builder.event.register_listener( @@ -812,8 +727,7 @@ def _register_post_setup_listener(self, builder: "Builder") -> None: ) def _register_simulant_initializer(self, builder: "Builder") -> None: - """ - Registers a simulant initializer if this component has defined one. + """Registers a simulant initializer if this component has defined one. This method allows the component to initialize simulants if it has its own `on_initialize_simulants` method. It registers this method with the @@ -822,12 +736,8 @@ def _register_simulant_initializer(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the initializer. - - Returns - ------- - None """ if type(self).on_initialize_simulants != Component.on_initialize_simulants: builder.population.initializes_simulants( @@ -837,8 +747,7 @@ def _register_simulant_initializer(self, builder: "Builder") -> None: ) def _register_time_step_prepare_listener(self, builder: "Builder") -> None: - """ - Registers a time_step_prepare listener if this component has defined one. + """Registers a time_step_prepare listener if this component has defined one. This method allows the component to respond to "time_step_prepare" events if it has its own `on_time_step_prepare` method. The listener will be @@ -846,12 +755,8 @@ def _register_time_step_prepare_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_time_step_prepare != Component.on_time_step_prepare: builder.event.register_listener( @@ -861,8 +766,7 @@ def _register_time_step_prepare_listener(self, builder: "Builder") -> None: ) def _register_time_step_listener(self, builder: "Builder") -> None: - """ - Registers a time_step listener if this component has defined one. + """Registers a time_step listener if this component has defined one. This method allows the component to respond to "time_step" events if it has its own `on_time_step` method. The listener will be @@ -870,12 +774,8 @@ def _register_time_step_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_time_step != Component.on_time_step: builder.event.register_listener( @@ -885,8 +785,7 @@ def _register_time_step_listener(self, builder: "Builder") -> None: ) def _register_time_step_cleanup_listener(self, builder: "Builder") -> None: - """ - Registers a time_step_cleanup listener if this component has defined one. + """Registers a time_step_cleanup listener if this component has defined one. This method allows the component to respond to "time_step_cleanup" events if it has its own `on_time_step_cleanup` method. The listener will be @@ -894,12 +793,8 @@ def _register_time_step_cleanup_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_time_step_cleanup != Component.on_time_step_cleanup: builder.event.register_listener( @@ -909,8 +804,7 @@ def _register_time_step_cleanup_listener(self, builder: "Builder") -> None: ) def _register_collect_metrics_listener(self, builder: "Builder") -> None: - """ - Registers a collect_metrics listener if this component has defined one. + """Registers a collect_metrics listener if this component has defined one. This method allows the component to respond to "collect_metrics" events if it has its own `on_collect_metrics` method. The listener will be @@ -918,12 +812,8 @@ def _register_collect_metrics_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_collect_metrics != Component.on_collect_metrics: builder.event.register_listener( @@ -933,8 +823,7 @@ def _register_collect_metrics_listener(self, builder: "Builder") -> None: ) def _register_simulation_end_listener(self, builder: "Builder") -> None: - """ - Registers a simulation_end listener if this component has defined one. + """Registers a simulation_end listener if this component has defined one. This method allows the component to respond to "simulation_end" events if it has its own `on_simulation_end` method. The listener will be @@ -942,12 +831,8 @@ def _register_simulation_end_listener(self, builder: "Builder") -> None: Parameters ---------- - builder : Builder + builder The builder with which to register the listener. - - Returns - ------- - None """ if type(self).on_simulation_end != Component.on_simulation_end: builder.event.register_listener( diff --git a/src/vivarium/framework/artifact/artifact.py b/src/vivarium/framework/artifact/artifact.py index e98f96981..4cbbb8b4f 100644 --- a/src/vivarium/framework/artifact/artifact.py +++ b/src/vivarium/framework/artifact/artifact.py @@ -15,7 +15,7 @@ import warnings from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from vivarium.framework.artifact import hdf @@ -29,7 +29,7 @@ class ArtifactException(Exception): class Artifact: """An interface for interacting with :mod:`vivarium` artifacts.""" - def __init__(self, path: Union[str, Path], filter_terms: List[str] = None): + def __init__(self, path: Union[str, Path], filter_terms: Optional[List[str]] = None): """ Parameters ---------- @@ -38,7 +38,6 @@ def __init__(self, path: Union[str, Path], filter_terms: List[str] = None): filter_terms A set of terms suitable for usage with the ``where`` kwarg for :func:`pandas.read_hdf`. - """ self._path = Path(path) self._filter_terms = filter_terms @@ -92,7 +91,6 @@ def load(self, entity_key: str) -> Any: Returns ------- - Any The expected data. Will either be a standard Python object or a :class:`pandas.DataFrame` or :class:`pandas.Series`. @@ -100,7 +98,6 @@ def load(self, entity_key: str) -> Any: ------ ArtifactException If the provided key is not in the artifact. - """ if entity_key not in self: raise ArtifactException(f"{entity_key} should be in {self.path}.") @@ -117,7 +114,7 @@ def load(self, entity_key: str) -> Any: return self._cache[entity_key] - def write(self, entity_key: str, data: Any): + def write(self, entity_key: str, data: Any) -> None: """Writes data into the artifact and binds it to the provided key. Parameters @@ -132,7 +129,6 @@ def write(self, entity_key: str, data: Any): ------ ArtifactException If the provided key already exists in the artifact. - """ if entity_key in self: raise ArtifactException(f"{entity_key} already in artifact.") @@ -142,7 +138,7 @@ def write(self, entity_key: str, data: Any): hdf.write(self._path, entity_key, data) self._keys.append(entity_key) - def remove(self, entity_key: str): + def remove(self, entity_key: str) -> None: """Removes data associated with the provided key from the artifact. Parameters @@ -154,7 +150,6 @@ def remove(self, entity_key: str): ------ ArtifactException If the key is not present in the artifact. - """ if entity_key not in self: raise ArtifactException( @@ -166,7 +161,7 @@ def remove(self, entity_key: str): self._cache.pop(entity_key) hdf.remove(self._path, entity_key) - def replace(self, entity_key: str, data: Any): + def replace(self, entity_key: str, data: Any) -> None: """Replaces the artifact data at the provided key with the new data. Parameters @@ -181,7 +176,6 @@ def replace(self, entity_key: str, data: Any): ------ ArtifactException If the provided key does not already exist in the artifact. - """ if entity_key not in self: raise ArtifactException( @@ -190,12 +184,11 @@ def replace(self, entity_key: str, data: Any): self.remove(entity_key) self.write(entity_key, data) - def clear_cache(self): + def clear_cache(self) -> None: """Clears the artifact's cache. The artifact will cache data in memory to improve performance for repeat access. - """ self._cache = {} @@ -234,8 +227,11 @@ def _to_tree(keys: List[str]) -> Dict[str, Dict[str, List[str]]]: class Keys: """A convenient wrapper around the keyspace which makes it easier for Artifact to maintain its keyspace when an entity key is added or removed. + With the artifact_path, Keys object is initialized when the Artifact is - initialized""" + initialized + + """ keyspace_node = "metadata.keyspace" @@ -243,7 +239,7 @@ def __init__(self, artifact_path: Path): self._path = artifact_path self._keys = [str(k) for k in hdf.load(self._path, "metadata.keyspace", None, None)] - def append(self, new_key: str): + def append(self, new_key: str) -> None: """Whenever the artifact gets a new key and new data, append is called to remove the old keyspace and to write the updated keyspace""" @@ -251,7 +247,7 @@ def append(self, new_key: str): hdf.remove(self._path, self.keyspace_node) hdf.write(self._path, self.keyspace_node, self._keys) - def remove(self, removing_key: str): + def remove(self, removing_key: str) -> None: """Whenever the artifact removes a key and data, remove is called to remove the key from keyspace and write the updated keyspace.""" @@ -268,10 +264,12 @@ def __contains__(self, item): return item in self._keys -def _parse_draw_filters(filter_terms): +def _parse_draw_filters(filter_terms) -> Optional[list[str]]: """Given a list of filter terms, parse out any related to draws and convert - to the list of column names. Also include 'value' column for compatibility - with data that is long on draws.""" + to the list of column names. + + Also include 'value' column for compatibility with data that is long on draws. + """ columns = None if filter_terms: diff --git a/src/vivarium/framework/artifact/hdf.py b/src/vivarium/framework/artifact/hdf.py index c57578c40..80474db29 100644 --- a/src/vivarium/framework/artifact/hdf.py +++ b/src/vivarium/framework/artifact/hdf.py @@ -55,7 +55,7 @@ #################### -def touch(path: Union[str, Path]): +def touch(path: Union[str, Path]) -> None: """Creates an HDF file, wiping an existing file if necessary. If the given path is proper to create a HDF file, it creates a new @@ -78,7 +78,7 @@ def touch(path: Union[str, Path]): pass -def write(path: Union[str, Path], entity_key: str, data: Any): +def write(path: Union[str, Path], entity_key: str, data: Any) -> None: """Writes data to the HDF file at the given path to the given key. Parameters @@ -134,16 +134,14 @@ def load( column_filters An optional list of columns to load from the data. - Raises - ------ - ValueError - If the path or entity_key are improperly formatted. - Returns ------- - Any The data stored at the the given key in the HDF file. + Raises + ------ + ValueError + If the path or entity_key are improperly formatted. """ path = _get_valid_hdf_path(path) entity_key = EntityKey(entity_key) @@ -174,7 +172,7 @@ def load( return data -def remove(path: Union[str, Path], entity_key: str): +def remove(path: Union[str, Path], entity_key: str) -> None: """Removes a piece of data from an HDF file. Parameters @@ -188,7 +186,6 @@ def remove(path: Union[str, Path], entity_key: str): ------ ValueError If the path or entity_key are improperly formatted. - """ path = _get_valid_hdf_path(path) entity_key = EntityKey(entity_key) @@ -202,12 +199,11 @@ def get_keys(path: Union[str, Path]) -> List[str]: Parameters ---------- - path : + path The path to the HDF file. Returns ------- - List[str] A list of key representations of the internal paths in the HDF. """ path = _get_valid_hdf_path(path) @@ -233,6 +229,10 @@ def __init__(self, key): The string representation of the entity key. Must be formatted as ``"type.name.measure"`` or ``"type.measure"``. + Raises + ------ + ValueError + If the key is improperly formatted. """ elements = [e for e in key.split(".") if e] if len(elements) not in [2, 3] or len(key.split(".")) != len(elements): @@ -291,9 +291,7 @@ def with_measure(self, measure: str) -> "EntityKey": Returns ------- - EntityKey A new EntityKey with the updated measure. - """ if self.name: return EntityKey(f"{self.type}.{self.name}.{measure}") @@ -330,12 +328,11 @@ def _get_valid_hdf_path(path: Union[str, Path]) -> Path: return path -def _write_pandas_data(path: Path, entity_key: EntityKey, data: Union[_PandasObj]): +def _write_pandas_data(path: Path, entity_key: EntityKey, data: _PandasObj): """Write data in a pandas format to an HDF file. This method currently supports :class:`pandas DataFrame` objects, with or with or without columns, and :class:`pandas.Series` objects. - """ if data.empty: # Our data is indexed, sometimes with no other columns. This leaves an @@ -356,7 +353,7 @@ def _write_pandas_data(path: Path, entity_key: EntityKey, data: Union[_PandasObj store.get_storer(entity_key.path).attrs.metadata = metadata -def _write_json_blob(path: Path, entity_key: EntityKey, data: Any): +def _write_json_blob(path: Path, entity_key: EntityKey, data: Any) -> None: """Writes a Python object as json to the HDF file at the given path.""" with tables.open_file(str(path), "a") as store: if entity_key.group_prefix not in store: @@ -397,7 +394,7 @@ def _get_node_name(node: tables.node.Node) -> str: return node_name -def _get_valid_filter_terms(filter_terms, colnames): +def _get_valid_filter_terms(filter_terms, colnames) -> Optional[list[str]]: """Removes any filter terms referencing non-existent columns Parameters @@ -413,7 +410,6 @@ def _get_valid_filter_terms(filter_terms, colnames): The list of valid filter terms (terms that do not reference any column not existing in the data). Returns none if the list is empty because the `where` argument doesn't like empty lists. - """ if not filter_terms: return None diff --git a/src/vivarium/framework/artifact/manager.py b/src/vivarium/framework/artifact/manager.py index 270339a25..59dd363ea 100644 --- a/src/vivarium/framework/artifact/manager.py +++ b/src/vivarium/framework/artifact/manager.py @@ -10,10 +10,10 @@ import re from pathlib import Path -from typing import Any, Callable, List, Sequence, Union +from typing import Any, Callable, List, Optional, Sequence, Union import pandas as pd -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from vivarium.framework.artifact.artifact import Artifact from vivarium.manager import Manager @@ -49,15 +49,17 @@ def setup(self, builder): self.artifact = self._load_artifact(builder.configuration) builder.lifecycle.add_constraint(self.load, allow_during=["setup"]) - def _load_artifact(self, configuration: LayeredConfigTree) -> Union[Artifact, None]: - """Looks up the path to the artifact hdf file, builds a default filter, + def _load_artifact(self, configuration: LayeredConfigTree) -> Optional[Artifact]: + """Loads artifact data. + + Looks up the path to the artifact hdf file, builds a default filter, and generates the data artifact. Stores any configuration specified filter terms separately to be applied on loading, because not all columns are available via artifact filter terms. Parameters ---------- - configuration : + configuration Configuration block of the model specification containing the input data parameters. Returns @@ -87,7 +89,6 @@ def load(self, entity_key: str, **column_filters: _Filter) -> Any: Returns ------- - Any The data associated with the given key, filtered down to the requested subset if the data is a dataframe. """ @@ -104,8 +105,7 @@ def load(self, entity_key: str, **column_filters: _Filter) -> Any: ) def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]: - """ - Returns a function that returns the value columns for the given input. + """Returns a function that returns the value columns for the given input. The function can be called with either a string or a pandas DataFrame. If a string is provided, it is interpreted as an artifact key, and the @@ -115,7 +115,6 @@ def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]: Returns ------- - Callable[[Union[str, pandas.core.generic.PandasObject]], List[str]] A function that returns the value columns for the given input. """ return lambda _: [self._default_value_column] @@ -130,7 +129,7 @@ class ArtifactInterface: def __init__(self, manager: ArtifactManager): self._manager = manager - def load(self, entity_key: str, **column_filters: Union[_Filter]) -> pd.DataFrame: + def load(self, entity_key: str, **column_filters: _Filter) -> pd.DataFrame: """Loads data associated with a formatted entity key. The provided entity key must be of the form @@ -158,14 +157,12 @@ def load(self, entity_key: str, **column_filters: Union[_Filter]) -> pd.DataFram Returns ------- - pandas.DataFrame The data associated with the given key filtered down to the requested subset. """ return self._manager.load(entity_key, **column_filters) def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]: - """ - Returns a function that returns the value columns for the given input. + """Returns a function that returns the value columns for the given input. The function can be called with either a string or a pandas DataFrame. If a string is provided, it is interpreted as an artifact key, and the @@ -173,7 +170,6 @@ def value_columns(self) -> Callable[[Union[str, pd.DataFrame]], List[str]]: Returns ------- - Callable[[Union[str, pandas.core.generic.PandasObject]], List[str]] A function that returns the value columns for the given input. """ return self._manager.value_columns() @@ -183,7 +179,7 @@ def __repr__(self): def filter_data( - data: pd.DataFrame, config_filter_term: str = None, **column_filters: _Filter + data: pd.DataFrame, config_filter_term: Optional[str] = None, **column_filters: _Filter ) -> pd.DataFrame: """Uses the provided column filters and age_group conditions to subset the raw data.""" data = _config_filter(data, config_filter_term) @@ -264,7 +260,6 @@ def parse_artifact_path_config(config: LayeredConfigTree) -> str: Returns ------- - str The path to the data artifact. """ path = Path(config.input_data.artifact_path) diff --git a/src/vivarium/framework/components/manager.py b/src/vivarium/framework/components/manager.py index 3af646e82..ebb8bb9da 100644 --- a/src/vivarium/framework/components/manager.py +++ b/src/vivarium/framework/components/manager.py @@ -106,7 +106,6 @@ def __repr__(self): class ComponentManager(Manager): """Manages the initialization and setup of :mod:`vivarium` components. - Maintains references to all components and managers in a :mod:`vivarium` simulation, applies their default configuration and initiates their ``setup`` life-cycle stage. @@ -158,7 +157,6 @@ def add_managers(self, managers: Union[List[Manager], Tuple[Manager]]) -> None: ---------- managers Instantiated managers to register. - """ for m in self._flatten(list(managers)): self.apply_configuration_defaults(m) @@ -173,7 +171,6 @@ def add_components(self, components: Union[List[Component], Tuple[Component]]) - ---------- components Instantiated components to register. - """ for c in self._flatten(list(components)): self.apply_configuration_defaults(c) @@ -191,9 +188,7 @@ def get_components_by_type( Returns ------- - List[Any] A list of components of type ``component_type``. - """ # Convert component_type to a tuple for isinstance return [c for c in self._components if isinstance(c, tuple(component_type))] @@ -210,14 +205,12 @@ def get_component(self, name: str) -> Component: Returns ------- - Component A component that has name ``name``. Raises ------ ValueError No component exists in the component manager with ``name``. - """ for c in self._components: if c.name == name: @@ -229,7 +222,6 @@ def list_components(self) -> Dict[str, Component]: Returns ------- - Dict[str, Any] A mapping of component names to components. """ @@ -248,7 +240,6 @@ def setup_components(self, builder: "Builder") -> None: ---------- builder Interface to several simulation tools. - """ self._setup_components(builder, self._managers + self._components) @@ -322,11 +313,11 @@ def __repr__(self): class ComponentInterface: - """The builder interface for the component manager system. This class - defines component manager methods a ``vivarium`` component can access from - the builder. It provides methods for querying and adding components to the - :class:`ComponentManager`. + """The builder interface for the component manager system. + This class defines component manager methods a ``vivarium`` component can + access from the builder. It provides methods for querying and adding components + to the :class:`ComponentManager`. """ def __init__(self, manager: ComponentManager): @@ -340,10 +331,10 @@ def get_component(self, name: str) -> Component: ---------- name A component name. + Returns ------- A component that has name ``name``. - """ return self._manager.get_component(name) @@ -360,9 +351,7 @@ def get_components_by_type( Returns ------- - List[Any] A list of components of type ``component_type``. - """ return self._manager.get_components_by_type(component_type) @@ -371,8 +360,6 @@ def list_components(self) -> Dict[str, Component]: Returns ------- - Dict[str, Any] A dictionary mapping component names to components. - """ return self._manager.list_components() diff --git a/src/vivarium/framework/components/parser.py b/src/vivarium/framework/components/parser.py index a6270a8e9..23d143bb2 100644 --- a/src/vivarium/framework/components/parser.py +++ b/src/vivarium/framework/components/parser.py @@ -23,7 +23,7 @@ from typing import Dict, List, Tuple, Union -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from vivarium.framework.utilities import import_by_path @@ -38,8 +38,7 @@ class ParsingError(ComponentConfigError): class ComponentConfigurationParser: - """ - Parses component configuration from model specification and initializes + """Parses component configuration from model specification and initializes components. To define your own set of parsing rules, you should write a parser class @@ -58,6 +57,7 @@ class ComponentConfigurationParser: All classes that are initialized from the ``yaml`` configuration must either take no arguments or take arguments specified as strings. + """ def get_components( @@ -84,9 +84,7 @@ def get_components( Returns ------- - List A list of initialized components. - """ if isinstance(component_config, LayeredConfigTree): component_list = self.parse_component_config(component_config) @@ -98,9 +96,9 @@ def get_components( def parse_component_config(self, component_config: LayeredConfigTree) -> List[Component]: """ - Helper function for parsing a LayeredConfigTree into a flat list of Components. + Helper function for parsing a ``LayeredConfigTree`` into a flat list of Components. - This function converts the LayeredConfigTree into a dictionary and passes it + This function converts the ``LayeredConfigTree`` into a dictionary and passes it along with an empty prefix list to :meth:`process_level `. The result is a flat list of components. @@ -108,11 +106,10 @@ def parse_component_config(self, component_config: LayeredConfigTree) -> List[Co Parameters ---------- component_config - A LayeredConfigTree representing a hierarchical component specification blob. + A ``LayeredConfigTree`` representing a hierarchical component specification blob. Returns ------- - List[Component] A flat list of Components """ if not component_config: @@ -162,7 +159,6 @@ def process_level( Returns ------- - List[Component] A flat list of Components """ if not level: @@ -192,8 +188,7 @@ def process_level( return component_list def create_component_from_string(self, component_string: str) -> Component: - """ - Helper function for creating a component from a string. + """Helper function for creating a component from a string. This function takes a string representing a component and turns it into an instantiated component object. @@ -206,7 +201,6 @@ def create_component_from_string(self, component_string: str) -> Component: Returns ------- - Component An instantiated component object. """ component_path, args = self.prep_component(component_string) @@ -224,7 +218,6 @@ def prep_component(self, component_string: str) -> Tuple[str, Tuple]: Returns ------- - Tuple[str, Tuple] Component/argument tuple. """ path, args_plus = component_string.split("(") @@ -233,8 +226,7 @@ def prep_component(self, component_string: str) -> Tuple[str, Tuple]: @staticmethod def _clean_args(args: List, path: str) -> Tuple: - """ - Transform component arguments into a tuple, validating that each argument + """Transform component arguments into a tuple, validating that each argument is a string. Parameters @@ -246,7 +238,6 @@ def _clean_args(args: List, path: str) -> Tuple: Returns ------- - Tuple A tuple of arguments, each of which is guaranteed to be a string. """ out = [] @@ -266,8 +257,7 @@ def _clean_args(args: List, path: str) -> Tuple: @staticmethod def import_and_instantiate_component(component_path: str, args: Tuple[str]) -> Component: - """ - Transform a tuple representing a Component into an actual instantiated + """Transform a tuple representing a Component into an actual instantiated component object. Parameters @@ -280,8 +270,6 @@ def import_and_instantiate_component(component_path: str, args: Tuple[str]) -> C Returns ------- - Component An instantiated component object. - """ return import_by_path(component_path)(*args) diff --git a/src/vivarium/framework/engine.py b/src/vivarium/framework/engine.py index cf998caec..761f3e82a 100644 --- a/src/vivarium/framework/engine.py +++ b/src/vivarium/framework/engine.py @@ -22,13 +22,13 @@ from pathlib import Path from pprint import pformat from time import time -from typing import Any, Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Union import dill import numpy as np import pandas as pd -import yaml -from layered_config_tree import ConfigurationKeyError, LayeredConfigTree +from layered_config_tree.exceptions import ConfigurationKeyError +from layered_config_tree.main import LayeredConfigTree from vivarium import Component from vivarium.exceptions import VivariumError @@ -63,16 +63,14 @@ def _get_context_name(sim_name: Union[str, None]) -> str: Returns ------- - str A unique name for the simulation context. - Note - ---- + Notes + ----- This method mutates process global state (the class attribute ``_created_simulation_contexts``) in order to keep track contexts that have been generated. This functionality makes generating simulation contexts in parallel a non-threadsafe operation. - """ if sim_name is None: sim_number = len(SimulationContext._created_simulation_contexts) + 1 @@ -92,8 +90,9 @@ def _get_context_name(sim_name: Union[str, None]) -> str: def _clear_context_cache(): """Clear the cache of simulation context names. + Notes + ----- This is primarily useful for testing purposes. - """ SimulationContext._created_simulation_contexts = set() @@ -358,33 +357,39 @@ class Builder: Attributes ---------- - logging: LoggingInterface + configuration : ``LayeredConfigTree`` + Provides access to the :ref:`configuration` + logging : LoggingInterface Provides access to the :ref:`logging` system. - lookup: LookupTableInterface + lookup : LookupTableInterface Provides access to simulant-specific data via the :ref:`lookup table` abstraction. - value: ValuesInterface + value : ValuesInterface Provides access to computed simulant attribute values via the :ref:`value pipeline` system. - event: EventInterface + event : EventInterface Provides access to event listeners utilized in the :ref:`event` system. - population: PopulationInterface + population : PopulationInterface Provides access to simulant state table via the :ref:`population` system. - resources: ResourceInterface + resources : ResourceInterface Provides access to the :ref:`resource` system, which manages dependencies between components. - time: TimeInterface + results : ResultsInterface + Provides access to the :ref:`results` system. + randomness : RandomnessInterface + Provides access to the :ref:`randomness` system. + time : TimeInterface Provides access to the simulation's :ref:`clock`. - components: ComponentInterface + components : ComponentInterface Provides access to the :ref:`component management` system, which maintains a reference to all managers and components in the simulation. - lifecycle: LifeCycleInterface + lifecycle : LifeCycleInterface Provides access to the :ref:`life-cycle` system, which manages the simulation's execution life-cycle. - data: ArtifactInterface + data : ArtifactInterface Provides access to the simulation's input data housed in the :ref:`data artifact`. diff --git a/src/vivarium/framework/event.py b/src/vivarium/framework/event.py index 550dbe8b1..08b31e700 100644 --- a/src/vivarium/framework/event.py +++ b/src/vivarium/framework/event.py @@ -71,9 +71,7 @@ def split(self, new_index: pd.Index) -> "Event": Returns ------- - Event The new event. - """ return Event(new_index, self.user_data, self.time, self.step_size) @@ -94,7 +92,7 @@ def __init__(self, manager, name): self.manager = manager self.listeners = [[] for _ in range(10)] - def emit(self, index: pd.Index, user_data: Dict = None) -> Event: + def emit(self, index: pd.Index, user_data: Optional[Dict] = None) -> Event: """Notifies all listeners to this channel that an event has occurred. Events are emitted to listeners in order of priority (with order 0 being @@ -108,7 +106,6 @@ def emit(self, index: pd.Index, user_data: Dict = None) -> Event: affected by this event. user_data Any additional data provided by the user about the event. - """ if not user_data: user_data = {} @@ -159,7 +156,6 @@ def setup(self, builder): ---------- builder Object giving access to core framework functionality. - """ self.clock = builder.time.clock() self.step_size = builder.time.step_size() @@ -190,7 +186,6 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: A function that accepts an index and optional user data. This function creates and timestamps an Event and distributes it to all interested listeners - """ channel = self.get_channel(name) try: @@ -201,7 +196,7 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: pass return channel.emit - def register_listener(self, name: str, listener: Callable, priority: int = 5): + def register_listener(self, name: str, listener: Callable, priority: int = 5) -> None: """Registers a new listener to the named event. Parameters @@ -217,7 +212,7 @@ def register_listener(self, name: str, listener: Callable, priority: int = 5): self.get_channel(name).listeners[priority].append(listener) def get_listeners(self, name: str) -> Dict[int, List[Callable]]: - """Get all listeners registered for the named event. + """Get all listeners registered for the named event. Parameters ---------- @@ -241,14 +236,12 @@ def list_events(self) -> List[Event]: Returns ------- - List[Event] - A list of all known event names. + A list of all known events. Notes ----- This value can change after setup if components dynamically create new event labels. - """ return list(self._event_types.keys()) @@ -281,7 +274,6 @@ def get_emitter(self, name: str) -> Callable[[pd.Index, Optional[Dict]], Event]: An emitter for the named event. The emitter should be called by the requesting component at the appropriate point in the simulation lifecycle. - """ return self._manager.get_emitter(name) @@ -322,6 +314,5 @@ def register_listener( state table (the state of the simulation at the beginning of the next time step should only depend on the current state of the system). - """ self._manager.register_listener(name, listener, priority) diff --git a/src/vivarium/framework/lifecycle.py b/src/vivarium/framework/lifecycle.py index 2fb0c6baf..cf2501e37 100644 --- a/src/vivarium/framework/lifecycle.py +++ b/src/vivarium/framework/lifecycle.py @@ -37,8 +37,6 @@ from collections import defaultdict from typing import Callable, Dict, List, Optional, Tuple -import numpy as np - from vivarium.exceptions import VivariumError from vivarium.manager import Manager @@ -81,7 +79,7 @@ def entrance_count(self) -> int: """The number of times this state has been entered.""" return self._entrance_count - def add_next(self, next_state: "LifeCycleState", loop: bool = False): + def add_next(self, next_state: "LifeCycleState", loop: bool = False) -> None: """Link this state to the next state in the simulation life cycle. States are linked together and used to ensure that the simulation @@ -98,7 +96,6 @@ def add_next(self, next_state: "LifeCycleState", loop: bool = False): loop Whether the provided state is the linear next state or a loop back to a previous state in the life cycle. - """ if loop: self._loop_next = next_state @@ -115,19 +112,17 @@ def valid_next_state(self, state: Optional["LifeCycleState"]) -> bool: Returns ------- - bool Whether the state is valid for a transition. - """ return (state is None and state is self._next) or ( state is not None and (state is self._next or state is self._loop_next) ) - def enter(self): + def enter(self) -> None: """Marks an entrance into this state.""" self._entrance_count += 1 - def add_handlers(self, handlers: List[Callable]): + def add_handlers(self, handlers: List[Callable]) -> None: """Registers a set of functions that will be executed during the state. The primary use case here is for introspection and reporting. @@ -137,7 +132,6 @@ def add_handlers(self, handlers: List[Callable]): ---------- handlers The set of functions that will be executed during this state. - """ for h in handlers: name = h.__name__ @@ -214,7 +208,7 @@ def __init__(self): self._phases = [] self.add_phase("initialization", ["initialization"], loop=False) - def add_phase(self, phase_name: str, states: List[str], loop): + def add_phase(self, phase_name: str, states: List[str], loop) -> None: """Add a new phase to the lifecycle. Phases must be added in order. @@ -234,7 +228,6 @@ def add_phase(self, phase_name: str, states: List[str], loop): ------ LifeCycleError If the phase or state names are non-unique. - """ self._validate(phase_name, states) @@ -256,14 +249,12 @@ def get_state(self, state_name: str) -> LifeCycleState: Returns ------- - LifeCycleState The requested state. Raises ------ LifeCycleError If the requested state does not exist. - """ if state_name not in self: raise LifeCycleError(f"Attempting to look up non-existent state {state_name}.") @@ -280,14 +271,12 @@ def get_state_names(self, phase_name: str) -> List[str]: Return ------ - List[str] The state names in the provided phase. Raises ------ LifeCycleError If the phase does not exist in the life cycle. - """ if phase_name not in self._phase_names: raise LifeCycleError( @@ -296,7 +285,7 @@ def get_state_names(self, phase_name: str) -> List[str]: phase = [p for p in self._phases if p.name == phase_name].pop() return [s.name for s in phase.states] - def _validate(self, phase_name: str, states: List[str]): + def _validate(self, phase_name: str, states: List[str]) -> None: """Validates that a phase and set of states are unique.""" if phase_name in self._phase_names: raise LifeCycleError( @@ -332,7 +321,7 @@ def __init__(self, lifecycle_manager): self.lifecycle_manager = lifecycle_manager self.constraints = set() - def check_valid_state(self, method: Callable, permitted_states: List[str]): + def check_valid_state(self, method: Callable, permitted_states: List[str]) -> None: """Ensures a component method is being called during an allowed state. Parameters @@ -346,7 +335,6 @@ def check_valid_state(self, method: Callable, permitted_states: List[str]): ------ ConstraintError If the method is being called outside the permitted states. - """ current_state = self.lifecycle_manager.current_state if current_state not in permitted_states: @@ -372,9 +360,7 @@ def constrain_normal_method( Returns ------- - Callable The constrained method. - """ @functools.wraps(method) @@ -401,6 +387,14 @@ def to_guid(method: Callable) -> str: collected, making :func:`id` unreliable for checking if a method has been constrained before. + Parameters + ---------- + method + The method to convert to a global id. + + Returns + ------- + The global id of the method. """ return f"{method.__self__.name}.{method.__name__}" @@ -428,7 +422,6 @@ def __call__(self, method: Callable, permitted_states: List[str]) -> Callable: ValueError If the provided method is a python "special" method (i.e. a method surrounded by double underscores). - """ if not hasattr(method, "__self__"): raise TypeError( @@ -473,7 +466,7 @@ def current_state(self) -> str: def timings(self) -> Dict[str, List[float]]: return self._timings - def add_phase(self, phase_name: str, states: List[str], loop: bool = False): + def add_phase(self, phase_name: str, states: List[str], loop: bool = False) -> None: """Add a new phase to the lifecycle. Phases must be added in order. @@ -493,11 +486,10 @@ def add_phase(self, phase_name: str, states: List[str], loop: bool = False): ------ LifeCycleError If the phase or state names are non-unique. - """ self.lifecycle.add_phase(phase_name, states, loop) - def set_state(self, state: str): + def set_state(self, state: str) -> None: """Sets the current life cycle state to the provided state. Parameters @@ -512,7 +504,6 @@ def set_state(self, state: str): InvalidTransitionError If setting the provided state represents an invalid life cycle transition. - """ new_state = self.lifecycle.get_state(state) if self._current_state.valid_next_state(new_state): @@ -538,13 +529,11 @@ def get_state_names(self, phase: str) -> List[str]: Returns ------- - List[str] A list of state names in order of execution. - """ return self.lifecycle.get_state_names(phase) - def add_handlers(self, state_name: str, handlers: List[Callable]): + def add_handlers(self, state_name: str, handlers: List[Callable]) -> None: """Registers a set of functions to be called during a life cycle state. This method does not apply any constraints, rather it is used @@ -556,13 +545,12 @@ def add_handlers(self, state_name: str, handlers: List[Callable]): The name of the state to register the handlers for. handlers A list of functions that will execute during the state. - """ s = self.lifecycle.get_state(state_name) s.add_handlers(handlers) def add_constraint( - self, method: Callable, allow_during: List[str] = (), restrict_during: List[str] = () + self, method: Callable, allow_during: List[str] = [], restrict_during: List[str] = [] ): """Constrains a function to be executable only during certain states. @@ -587,7 +575,6 @@ def add_constraint( ConstraintError If a lifecycle constraint has already been applied to the provided method. - """ if allow_during and restrict_during or not (allow_during or restrict_during): raise ValueError( @@ -627,7 +614,7 @@ class LifeCycleInterface: def __init__(self, manager: LifeCycleManager): self._manager = manager - def add_handlers(self, state: str, handlers: List[Callable]): + def add_handlers(self, state: str, handlers: List[Callable]) -> None: """Registers a set of functions to be called during a life cycle state. This method does not apply any constraints, rather it is used @@ -639,13 +626,12 @@ def add_handlers(self, state: str, handlers: List[Callable]): The name of the state to register the handlers for. handlers A list of functions that will execute during the state. - """ self._manager.add_handlers(state, handlers) def add_constraint( - self, method: Callable, allow_during: List[str] = (), restrict_during: List[str] = () - ): + self, method: Callable, allow_during: List[str] = [], restrict_during: List[str] = [] + ) -> None: """Constrains a function to be executable only during certain states. Parameters @@ -669,7 +655,6 @@ def add_constraint( ConstraintError If a life cycle constraint has already been applied to the provided method. - """ self._manager.add_constraint(method, allow_during, restrict_during) @@ -678,8 +663,6 @@ def current_state(self) -> Callable[[], str]: Returns ------- - Callable[[], str] A callable that returns the current simulation lifecycle state. - """ return lambda: self._manager.current_state diff --git a/src/vivarium/framework/logging/utilities.py b/src/vivarium/framework/logging/utilities.py index 4aaf708d6..f5507252f 100644 --- a/src/vivarium/framework/logging/utilities.py +++ b/src/vivarium/framework/logging/utilities.py @@ -25,7 +25,6 @@ def configure_logging_to_terminal(verbosity: int, long_format: bool = True) -> N long_format Whether to use the long format for logging messages, which includes explicit information about the simulation context and component in the log messages. - """ _clear_default_configuration() _add_logging_sink( @@ -44,7 +43,6 @@ def configure_logging_to_file(output_directory: Path) -> None: ---------- output_directory The directory to write the log file to. - """ log_file = output_directory / "simulation.log" _add_logging_sink( @@ -88,7 +86,6 @@ def _add_logging_sink( serialize Whether to serialize log messages. This is useful when logging to a file or a database. - """ log_formatter = _LogFormatter(long_format) logging_level = _get_log_level(verbosity) diff --git a/src/vivarium/framework/lookup/interpolation.py b/src/vivarium/framework/lookup/interpolation.py index 2898aae23..7bd6f469c 100644 --- a/src/vivarium/framework/lookup/interpolation.py +++ b/src/vivarium/framework/lookup/interpolation.py @@ -33,6 +33,7 @@ class Interpolation: for left bin edge, column name for right bin edge). order : Order of interpolation. + """ def __init__( @@ -99,7 +100,6 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Returns ------- - pandas.DataFrame A table with the interpolated values for the given interpolants. """ @@ -189,8 +189,10 @@ def validate_call_data(data, categorical_parameters, continuous_parameters): ) -def check_data_complete(data, continuous_parameters): - """For any parameters specified with edges, make sure edges +def check_data_complete(data, continuous_parameters) -> None: + """Check that data is complete for interpolation. + + For any parameters specified with edges, make sure edges don't overlap and don't have any gaps. Assumes that edges are specified with ends and starts overlapping (but one exclusive and the other inclusive) so can check that end of previous == start @@ -204,6 +206,15 @@ def check_data_complete(data, continuous_parameters): should cover a continuous range of that parameter with no overlaps or gaps and the range covered should be the same for all combinations of other parameter values. + + Raises + ------ + ValueError + If there are missing values for every combinations of continuous parameters. + ValueError + If the parameter data contains overlaps. + NotImplementedError + If a parameter contains non-continuous bins. """ param_edges = [ @@ -255,11 +266,18 @@ class Order0Interp: Attributes ---------- - data : - The data from which to build the interpolation. Contains - categorical_parameters and continuous_parameters. - continuous_parameters : - Column names to be used as parameters in Interpolation. + data + The data from which to build the interpolation. + value_columns + Columns to be interpolated. + extrapolate + Whether or not to extrapolate beyond the edge of supplied bins. + parameter_bins + A dictionary where they keys are a tuple of the form + (column name used in call, column name for left bin edge, column name for right bin edge) + and the values are dictionaries of the form {"bins": [ordered left edges of bins], + "max": max right edge (used when extrapolation not allowed)}. + """ def __init__( @@ -271,19 +289,21 @@ def __init__( validate: bool, ): """ - Parameters ---------- - data : + data Data frame used to build interpolation. - continuous_parameters : + continuous_parameters Parameter columns. Should be of form (column name used in call, column name for left bin edge, column name for right bin edge) or column name. Assumes left bin edges are inclusive and right exclusive. - extrapolate : + value_columns + Columns to be interpolated. + extrapolate Whether or not to extrapolate beyond the edge of supplied bins. - + validate + Whether or not to validate the data. """ if validate: check_data_complete(data, continuous_parameters) @@ -311,14 +331,12 @@ def __call__(self, interpolants: pd.DataFrame) -> pd.DataFrame: Parameters ---------- - interpolants: + interpolants Data frame containing the parameters to interpolate.. Returns ------- - pandas.DataFrame A table with the interpolated values for the given interpolants. - """ # build a dataframe where we have the start of each parameter bin for each interpolant interpolant_bins = pd.DataFrame(index=interpolants.index) diff --git a/src/vivarium/framework/lookup/table.py b/src/vivarium/framework/lookup/table.py index b5149276e..ee00a0dda 100644 --- a/src/vivarium/framework/lookup/table.py +++ b/src/vivarium/framework/lookup/table.py @@ -167,7 +167,6 @@ def call(self, index: pd.Index) -> pd.DataFrame: Returns ------- - pandas.DataFrame A table with the interpolated values for the population requested. """ @@ -235,7 +234,6 @@ def call(self, index: pd.Index) -> pd.DataFrame: Returns ------- - pandas.DataFrame A table with the mapped values for the population requested. """ pop = self.population_view.get(index) @@ -284,7 +282,6 @@ def call(self, index: pd.Index) -> pd.DataFrame: Returns ------- - pandas.DataFrame A table with a column for each of the scalar values for the population requested. diff --git a/src/vivarium/framework/plugins.py b/src/vivarium/framework/plugins.py index 2279b7817..08d8a10d5 100644 --- a/src/vivarium/framework/plugins.py +++ b/src/vivarium/framework/plugins.py @@ -9,7 +9,7 @@ """ -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from vivarium.exceptions import VivariumError from vivarium.framework.utilities import import_by_path diff --git a/src/vivarium/framework/population/manager.py b/src/vivarium/framework/population/manager.py index 4cabd2d72..d0ea654d0 100644 --- a/src/vivarium/framework/population/manager.py +++ b/src/vivarium/framework/population/manager.py @@ -47,7 +47,7 @@ def __init__(self): self._components = {} self._columns_produced = {} - def add(self, initializer: Callable, columns_produced: List[str]): + def add(self, initializer: Callable, columns_produced: List[str]) -> None: """Adds an initializer and columns to the set, enforcing uniqueness. Parameters @@ -67,7 +67,6 @@ def add(self, initializer: Callable, columns_produced: List[str]): If the component bound to the method already has an initializer registered or if the columns produced are duplicates of columns another initializer produces. - """ if not isinstance(initializer, MethodType): raise TypeError( @@ -188,7 +187,7 @@ def __repr__(self): ########################### def get_view( - self, columns: Union[List[str], Tuple[str]], query: str = None + self, columns: Union[List[str], Tuple[str]], query: Optional[str] = None ) -> PopulationView: """Get a time-varying view of the population state table. @@ -216,7 +215,6 @@ def get_view( Returns ------- - PopulationView A filtered view of the requested columns of the population state table. @@ -237,7 +235,9 @@ def get_view( ) return view - def _get_view(self, columns: Union[List[str], Tuple[str]], query: str = None): + def _get_view( + self, columns: Union[List[str], Tuple[str]], query: Optional[str] = None + ) -> PopulationView: if columns and "tracked" not in columns: if query is None: query = "tracked == True" @@ -253,7 +253,7 @@ def register_simulant_initializer( requires_columns: List[str] = (), requires_values: List[str] = (), requires_streams: List[str] = (), - ): + ) -> None: """Marks a source of initial state information for new simulants. Parameters @@ -274,7 +274,6 @@ def register_simulant_initializer( requires_streams A list of the randomness streams necessary to initialize the simulant attributes. - """ self._initializer_components.add(initializer, creates_columns) dependencies = ( @@ -293,19 +292,18 @@ def register_simulant_initializer( def get_simulant_creator(self) -> Callable[[int, Optional[Dict[str, Any]]], pd.Index]: """Gets a function that can generate new simulants. + The creator function takes the number of simulants to be created as it's + first argument and a dict population configuration that will be available + to simulant initializers as it's second argument. It generates the new rows + in the population state table and then calls each initializer + registered with the population system with a data + object containing the state table index of the new simulants, the + configuration info passed to the creator, the current simulation + time, and the size of the next time step. + Returns ------- - Callable - The simulant creator function. The creator function takes the - number of simulants to be created as it's first argument and a dict - population configuration that will be available to simulant - initializers as it's second argument. It generates the new rows in - the population state table and then calls each initializer - registered with the population system with a data - object containing the state table index of the new simulants, the - configuration info passed to the creator, the current simulation - time, and the size of the next time step. - + The simulant creator function. """ return self._create_simulants @@ -347,9 +345,7 @@ def get_population(self, untracked: bool) -> pd.DataFrame: Returns ------- - pandas.DataFrame A copy of the population table. - """ pop = self._population.copy() if self._population is not None else pd.DataFrame() if not untracked and "tracked" in pop.columns: @@ -384,7 +380,7 @@ def __init__(self, manager: PopulationManager): self._manager = manager def get_view( - self, columns: Union[List[str], Tuple[str]], query: str = None + self, columns: Union[List[str], Tuple[str]], query: Optional[str] = None ) -> PopulationView: """Get a time-varying view of the population state table. @@ -412,28 +408,26 @@ def get_view( Returns ------- - PopulationView A filtered view of the requested columns of the population state table. - """ return self._manager.get_view(columns, query) def get_simulant_creator(self) -> Callable[[int, Optional[Dict[str, Any]]], pd.Index]: """Gets a function that can generate new simulants. + The creator function takes the number of simulants to be created as it's + first argument and a dict population configuration that will be available + to simulant initializers as it's second argument. It generates the new rows + in the population state table and then calls each initializer + registered with the population system with a data + object containing the state table index of the new simulants, the + configuration info passed to the creator, the current simulation + time, and the size of the next time step. + Returns ------- - The simulant creator function. The creator function takes the - number of simulants to be created as it's first argument and a dict - population configuration that will be available to simulant - initializers as it's second argument. It generates the new rows in - the population state table and then calls each initializer - registered with the population system with a data - object containing the state table index of the new simulants, the - configuration info passed to the creator, the current simulation - time, and the size of the next time step. - + The simulant creator function. """ return self._manager.get_simulant_creator() @@ -444,7 +438,7 @@ def initializes_simulants( requires_columns: List[str] = (), requires_values: List[str] = (), requires_streams: List[str] = (), - ): + ) -> None: """Marks a source of initial state information for new simulants. Parameters @@ -465,7 +459,6 @@ def initializes_simulants( requires_streams A list of the randomness streams necessary to initialize the simulant attributes. - """ self._manager.register_simulant_initializer( initializer, creates_columns, requires_columns, requires_values, requires_streams diff --git a/src/vivarium/framework/population/population_view.py b/src/vivarium/framework/population/population_view.py index cbcdfcdf6..3f74a8fe9 100644 --- a/src/vivarium/framework/population/population_view.py +++ b/src/vivarium/framework/population/population_view.py @@ -12,7 +12,7 @@ """ -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import pandas as pd @@ -25,22 +25,12 @@ class PopulationView: """A read/write manager for the simulation state table. + It can be used to both read and update the state of the population. A PopulationView can only read and write columns for which it is configured. Attempts to update non-existent columns are ignored except during simulant creation when new columns are allowed to be created. - Parameters - ---------- - manager - The population manager for the simulation. - columns - The set of columns this view should have access too. If empty, this - view will have access to the entire state table. - query - A :mod:`pandas`-style filter that will be applied any time this - view is read from. - Notes ----- By default, this view will filter out ``untracked`` simulants unless @@ -53,8 +43,23 @@ def __init__( manager: "PopulationManager", view_id: int, columns: Union[List[str], Tuple[str]] = (), - query: str = None, + query: Optional[str] = None, ): + """ + + Parameters + ---------- + manager + The population manager for the simulation. + view_id + The unique identifier for this view. + columns + The set of columns this view should have access too. If empty, this + view will have access to the entire state table. + query + A :mod:`pandas`-style filter that will be applied any time this + view is read from. + """ self._manager = manager self._id = view_id self._columns = list(columns) @@ -72,19 +77,17 @@ def columns(self) -> List[str]: the view will have access to the full table by default. That case should be only be used in situations where the full state table is actually needed, like for some metrics collection applications. - """ if not self._columns: return list(self._manager.get_population(True).columns) return list(self._columns) @property - def query(self) -> Union[str, None]: + def query(self) -> Optional[str]: """A :mod:`pandas` style query to filter the rows of this view. This query will be applied any time the view is read. This query may reference columns not in the view's columns. - """ return self._query @@ -99,7 +102,6 @@ def subview(self, columns: Union[List[str], Tuple[str]]) -> "PopulationView": Returns ------- - PopulationView A new view with access to the requested columns. Raises @@ -116,7 +118,6 @@ def subview(self, columns: Union[List[str], Tuple[str]]) -> "PopulationView": requesting a subview, a component can read the sections it needs without running the risk of trying to access uncreated columns because the component itself has not created them. - """ if not columns or set(columns) - set(self.columns): @@ -147,7 +148,6 @@ def get(self, index: pd.Index, query: str = "") -> pd.DataFrame: Returns ------- - pandas.DataFrame A table with the subset of the population requested. Raises @@ -160,7 +160,6 @@ def get(self, index: pd.Index, query: str = "") -> pd.DataFrame: See Also -------- :meth:`subview ` - """ pop = self._manager.get_population(True).loc[index] @@ -204,7 +203,6 @@ def update(self, population_update: Union[pd.DataFrame, pd.Series]) -> None: If the provided data name or columns do not match columns that this view manages or if the view is being updated with a data type inconsistent with the original population data. - """ state_table = self._manager.get_population(True) population_update = self._format_update_and_check_preconditions( @@ -289,7 +287,6 @@ def _format_update_and_check_preconditions( Returns ------- - pandas.DataFrame The input data formatted as a DataFrame. Raises @@ -360,7 +357,6 @@ def _coerce_to_dataframe( Returns ------- - pandas.DataFrame The input data formatted as a DataFrame. Raises @@ -372,7 +368,6 @@ def _coerce_to_dataframe( If the input data is a :class:`pandas.Series` and this :class:`PopulationView` manages multiple columns or if the population update contains columns not managed by this view. - """ if not isinstance(population_update, (pd.Series, pd.DataFrame)): raise TypeError( @@ -434,7 +429,6 @@ def _ensure_coherent_initialization( PopulationError If the population update contains no new information or if it contains information in conflict with the existing state table. - """ missing_pops = len(state_table.index.difference(population_update.index)) if missing_pops: @@ -475,9 +469,7 @@ def _update_column_and_ensure_dtype( Returns ------- - pandas.Series The column with the provided update applied - """ # FIXME: This code does not work as described. I'm leaving it here because writing # real dtype checking code is a pain and we never seem to hit the actual edge cases. diff --git a/src/vivarium/framework/randomness/index_map.py b/src/vivarium/framework/randomness/index_map.py index e7645b1cd..85920b52f 100644 --- a/src/vivarium/framework/randomness/index_map.py +++ b/src/vivarium/framework/randomness/index_map.py @@ -42,7 +42,6 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: clock_time The simulation clock time. Used as the salt during hashing to minimize inter-simulation collisions. - """ if new_keys.empty or not self._use_crn: return # Nothing to do @@ -65,6 +64,12 @@ def update(self, new_keys: pd.DataFrame, clock_time: pd.Timestamp) -> None: def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.MultiIndex]: """Parses raw new keys into the mapping index. + This returns a tuple of the new and final mapping indices. Both are pandas + indices with a level for the index assigned by the population system and + additional levels for the key columns associated with the simulant index. The + new mapping index contains only the values for the new keys and the final mapping + combines the existing mapping and the new mapping index. + Parameters ---------- new_keys @@ -73,13 +78,7 @@ def _parse_new_keys(self, new_keys: pd.DataFrame) -> Tuple[pd.MultiIndex, pd.Mul Returns ------- - Tuple[pandas.MultiIndex, pandas.MultiIndex] - A tuple of the new mapping index and the final mapping index. Both are pandas - indices with a level for the index assigned by the population system and - additional levels for the key columns associated with the simulant index. The - new mapping index contains only the values for the new keys and the final mapping - combines the existing mapping and the new mapping index. - + A tuple of the new mapping index and the final mapping index. """ keys = new_keys.copy() keys.index.name = self.SIM_INDEX_COLUMN @@ -108,10 +107,8 @@ def _build_final_mapping( Returns ------- - pandas.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. - """ new_key_index = new_mapping_index.droplevel(self.SIM_INDEX_COLUMN) mapping_update = self._hash(new_key_index, salt=clock_time) @@ -140,10 +137,8 @@ def _resolve_collisions( Returns ------- - pandas.Series The new mapping incorporating the updates from the new mapping index and resolving collisions. - """ current_mapping = current_mapping.drop_duplicates() collisions = new_key_index.difference(current_mapping.index) @@ -168,11 +163,9 @@ def _hash(self, keys: pd.Index, salt: int = 0) -> pd.Series: Returns ------- - pandas.Series A pandas series indexed by the given keys and whose values take on integers in the range [0, len(self)]. Duplicates may appear and should be dealt with by the calling code. - """ key_frame = keys.to_frame() new_map = pd.Series(0, index=keys) @@ -204,7 +197,6 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: Returns ------- - pandas.Series A series of ten digit integers based on the input data. Raises @@ -212,7 +204,6 @@ def _convert_to_ten_digit_int(self, column: pd.Series) -> pd.Series: RandomnessError If the column contains data that is neither a datetime-like nor numeric. - """ if isinstance(column.iloc[0], datetime.datetime): column = self._clip_to_seconds(column.astype(np.int64)) diff --git a/src/vivarium/framework/randomness/manager.py b/src/vivarium/framework/randomness/manager.py index 23eb073e8..2f75722c9 100644 --- a/src/vivarium/framework/randomness/manager.py +++ b/src/vivarium/framework/randomness/manager.py @@ -2,6 +2,7 @@ ========================= Randomness System Manager ========================= + """ import pandas as pd @@ -80,12 +81,17 @@ def get_randomness_stream( copied and should only be used to generate the state table columns specified in ``builder.configuration.randomness.key_columns``. + Returns + ------- + An entry point into the Common Random Number generation framework. + The stream provides vectorized access to random numbers and a few + other utilities. + Raises ------ RandomnessError If another location in the simulation has already created a randomness stream with the same identifier. - """ stream = self._get_randomness_stream(decision_point, initializes_crn_attributes) if not initializes_crn_attributes: @@ -142,14 +148,12 @@ def get_seed(self, decision_point: str) -> int: Returns ------- - int A seed for a random number generation that is linked to Vivarium's common random number framework. - """ return get_hash("_".join([decision_point, str(self._clock()), str(self._seed)])) - def register_simulants(self, simulants: pd.DataFrame): + def register_simulants(self, simulants: pd.DataFrame) -> None: """Adds new simulants to the randomness mapping. Parameters @@ -163,7 +167,6 @@ def register_simulants(self, simulants: pd.DataFrame): RandomnessError If the provided table does not contain all key columns specified in the configuration. - """ if not all(k in simulants.columns for k in self._key_columns): raise RandomnessError( @@ -208,11 +211,9 @@ def get_stream( Returns ------- - RandomnessStream An entry point into the Common Random Number generation framework. The stream provides vectorized access to random numbers and a few other utilities. - """ return self._manager.get_randomness_stream(decision_point, initializes_crn_attributes) @@ -228,14 +229,12 @@ def get_seed(self, decision_point: str) -> int: Returns ------- - int A seed for a random number generation that is linked to Vivarium's common random number framework. - """ return self._manager.get_seed(decision_point) - def register_simulants(self, simulants: pd.DataFrame): + def register_simulants(self, simulants: pd.DataFrame) -> None: """Registers simulants with the Common Random Number Framework. Parameters @@ -245,6 +244,5 @@ def register_simulants(self, simulants: pd.DataFrame): columns specified in ``builder.configuration.randomness.key_columns``. This function should be called as soon as the key columns are generated. - """ self._manager.register_simulants(simulants) diff --git a/src/vivarium/framework/randomness/stream.py b/src/vivarium/framework/randomness/stream.py index 742bdf003..99266c093 100644 --- a/src/vivarium/framework/randomness/stream.py +++ b/src/vivarium/framework/randomness/stream.py @@ -16,14 +16,16 @@ [0.2, 0.2, RESIDUAL_CHOICE] => [0.2, 0.2, 0.6] - Note - ---- - Currently this object is only used in the `choice` function of this - module. + +Notes +----- +Currently this object is only used in the `choice` function of this +module. + """ import hashlib -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import pandas as pd @@ -46,9 +48,7 @@ def get_hash(key: str) -> int: Returns ------- - int A hash of the provided key. - """ max_allowable_numpy_seed = 4294967295 # 2**32 - 1 return int(hashlib.sha1(key.encode("utf8")).hexdigest(), 16) % max_allowable_numpy_seed @@ -70,6 +70,10 @@ class RandomnessStream: A way to get the current simulation time. seed An extra number used to seed the random number generation. + index_map + A key-index mapping with a fectorized hash and vectorized lookups. + initializes_crn_attributes + A boolean indicating whether the stram is used to initialize CRN attributes. Notes ----- @@ -113,9 +117,7 @@ def _key(self, additional_key: Any = None) -> str: Returns ------- - str A key to seed random number generation. - """ return "_".join([self.key, str(self.clock()), str(additional_key), str(self.seed)]) @@ -132,11 +134,10 @@ def get_draw(self, index: pd.Index, additional_key: Any = None) -> pd.Series: Returns ------- - pandas.Series A series of random numbers indexed by the provided `pandas.Index`. - Note - ---- + Notes + ----- This is the core of the CRN implementation, allowing for consistent use of random numbers across simulations with multiple scenarios. @@ -145,7 +146,6 @@ def get_draw(self, index: pd.Index, additional_key: Any = None) -> pd.Series: https://en.wikipedia.org/wiki/Variance_reduction and "Untangling Uncertainty with Common Random Numbers: A Simulation Study; A.Flaxman, et. al., Summersim 2017" - """ # Return a structured null value if an empty index is passed if index.empty: @@ -212,10 +212,8 @@ def filter_for_rate( Returns ------- - pandas.core.generic.PandasObject The subpopulation of the simulants for whom the event occurred. - The return type will be the same as type(population) - + The return type will be the same as type(population). """ return self.filter_for_probability( population, rate_to_probability(rate), additional_key @@ -249,10 +247,8 @@ def filter_for_probability( Returns ------- - pandas.core.generic.PandasObject The subpopulation of the simulants for whom the event occurred. - The return type will be the same as type(population) - + The return type will be the same as type(population). """ if population.empty: return population @@ -266,8 +262,8 @@ def choice( self, index: pd.Index, choices: Union[List, Tuple, np.ndarray, pd.Series], - p: Union[List, Tuple, np.ndarray, pd.Series] = None, - additional_key: Any = None, + p: Optional[Union[List, Tuple, np.ndarray, pd.Series]] = None, + additional_key: Optional[Any] = None, ) -> pd.Series: """Decides between a weighted or unweighted set of choices. @@ -294,7 +290,6 @@ def choice( Returns ------- - pandas.Series An indexed set of decisions from among the available `choices`. Raises @@ -303,7 +298,6 @@ def choice( If any row in `p` contains `RESIDUAL_CHOICE` and the remaining weights in the row are not normalized or any row of `p contains more than one reference to `RESIDUAL_CHOICE`. - """ draws = self.get_draw(index, additional_key) return _choice(draws, choices, p) @@ -311,13 +305,12 @@ def choice( def sample_from_distribution( self, index: pd.Index, - distribution: stats.rv_continuous = None, - ppf: Callable[[pd.Series, ...], pd.Series] = None, + distribution: Optional[stats.rv_continuous] = None, + ppf: Optional[Callable[[pd.Series, dict[str, Any]], pd.Series]] = None, additional_key: Any = None, - **distribution_kwargs: Any, + **distribution_kwargs: dict[str, Any], ) -> pd.Series: - """ - Given a distribution, returns an indexed set of samples from that + """Given a distribution, returns an indexed set of samples from that distribution. Parameters @@ -333,6 +326,10 @@ def sample_from_distribution( Any additional information used to seed random number generation. distribution_kwargs Additional keyword arguments to pass to the distribution's ppf function. + + Returns + ------- + An indexed set of samples from the provided distribution. """ if ppf is None and distribution is None: raise ValueError("Either distribution or ppf must be provided") @@ -355,7 +352,7 @@ def __repr__(self) -> str: def _choice( draws: pd.Series, choices: Union[List, Tuple, np.ndarray, pd.Series], - p: Union[List, Tuple, np.ndarray, pd.Series] = None, + p: Optional[Union[List, Tuple, np.ndarray, pd.Series]] = None, ) -> pd.Series: """Decides between a weighted or unweighted set of choices. @@ -381,7 +378,6 @@ def _choice( Returns ------- - pandas.Series An indexed set of decisions from among the available `choices`. Raises @@ -390,7 +386,6 @@ def _choice( If any row in `p` contains `RESIDUAL_CHOICE` and the remaining weights in the row are not normalized or any row of `p` contains more than one reference to `RESIDUAL_CHOICE`. - """ # Convert p to normalized probabilities broadcasted over index. p = ( @@ -430,9 +425,14 @@ def _set_residual_probability(p: np.ndarray) -> np.ndarray: Returns ------- - numpy.ndarray Array where each row is a set of normalized probability weights. + Raises + ------ + RandomnessError + If more than one residual choice is supplied for a single set of weights. + RandomnessError + If residual choice is supplied with weights that sum to more than 1. """ residual_mask = p == RESIDUAL_CHOICE if residual_mask.any(): # I.E. if we have any placeholders. diff --git a/src/vivarium/framework/resource.py b/src/vivarium/framework/resource.py index 47d013e9b..57c9b9d2d 100644 --- a/src/vivarium/framework/resource.py +++ b/src/vivarium/framework/resource.py @@ -72,7 +72,6 @@ def type(self) -> str: """The type of resource produced by this resource group's producer. Must be one of `RESOURCE_TYPES`. - """ return self._resource_type @@ -141,7 +140,6 @@ def sorted_nodes(self): ----- Topological sorts are not stable. Be wary of depending on order where you shouldn't. - """ if self._sorted_nodes is None: try: @@ -184,7 +182,6 @@ def add_resources( If either the resource type is invalid, a component has multiple resource producers for the ``column`` resource type, or there are multiple producers of the same resource. - """ if resource_type not in RESOURCE_TYPES: raise ResourceError( @@ -217,7 +214,6 @@ def _get_resource_group( See Also -------- :class:`ResourceGroup` - """ if not resource_names: # We have a "producer" that doesn't produce anything, but @@ -243,7 +239,6 @@ def _to_graph(self) -> nx.DiGraph: between post setup time when the :class:`values manager ` finalizes pipeline dependencies and population creation time. - """ resource_graph = nx.DiGraph() # networkx ignores duplicates @@ -270,7 +265,6 @@ def __iter__(self) -> Iterable[MethodType]: We exclude all non-initializer dependencies. They were necessary in graph construction, but we only need the column producers at population creation time. - """ return iter( [ @@ -315,7 +309,7 @@ def add_resources( resource_names: List[str], producer: Any, dependencies: List[str], - ): + ) -> None: """Adds managed resources to the resource pool. Parameters @@ -337,7 +331,6 @@ def add_resources( If either the resource type is invalid, a component has multiple resource producers for the ``column`` resource type, or there are multiple producers of the same resource. - """ self._manager.add_resources(resource_type, resource_names, producer, dependencies) @@ -347,6 +340,5 @@ def __iter__(self): We exclude all non-initializer dependencies. They were necessary in graph construction, but we only need the column producers at population creation time. - """ return iter(self._manager) diff --git a/src/vivarium/framework/results/context.py b/src/vivarium/framework/results/context.py index b4bdeb034..773aa2d42 100644 --- a/src/vivarium/framework/results/context.py +++ b/src/vivarium/framework/results/context.py @@ -2,6 +2,7 @@ =============== Results Context =============== + """ from collections import defaultdict @@ -249,7 +250,9 @@ def gather_results( None, ]: """Generate and yield current results for all observations at this lifecycle - phase and event. Each set of results are stratified and grouped by + phase and event. + + Each set of results are stratified and grouped by all registered stratifications as well as filtered by their respective observation's pop_filter. @@ -264,9 +267,9 @@ def gather_results( Yields ------ - A tuple containing each observation's newly observed results, the name of - the observation, and the observations results updater function. Note that - it yields (None, None, None) if the filtered population is empty. + A tuple containing each observation's newly observed results, the name of + the observation, and the observations results updater function. Note that + it yields (None, None, None) if the filtered population is empty. Raises ------ diff --git a/src/vivarium/framework/results/interface.py b/src/vivarium/framework/results/interface.py index eb669bef6..6988ff2ca 100644 --- a/src/vivarium/framework/results/interface.py +++ b/src/vivarium/framework/results/interface.py @@ -6,6 +6,7 @@ This module provides a :class:`ResultsInterface ` class with methods to register stratifications and results producers (referred to as "observations") to a simulation. + """ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union @@ -49,6 +50,7 @@ class ResultsInterface: The purpose of this interface is to provide controlled access to a results backend by means of the builder object; it exposes methods to register both stratifications and results producers (referred to as "observations"). + """ def __init__(self, manager: "ResultsManager") -> None: @@ -324,9 +326,14 @@ def register_adding_observation( aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]] = len, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Registers an adding observation to the results system; that is, - one that adds/sums new results to existing result values. Note that an adding - observation is a specific type of stratified observation. + """Registers an adding observation to the results system. + + An "adding" observation is one that adds/sums new results to existing + result values. + + Notes + ----- + An adding observation is a specific type of stratified observation. Parameters ---------- @@ -386,9 +393,14 @@ def register_concatenating_observation( ] = lambda measure, results: results, to_observe: Callable[[Event], bool] = lambda event: True, ) -> None: - """Registers a concatenating observation to the results system; that is, - one that concatenates new results to existing results. Note that a - concatenating observation is a specific type of unstratified observation. + """Registers a concatenating observation to the results system. + + A "concatenating" observation is one that concatenates new results to + existing results. + + Notes + ----- + A concatenating observation is a specific type of unstratified observation. Parameters ---------- diff --git a/src/vivarium/framework/results/manager.py b/src/vivarium/framework/results/manager.py index f67a547d4..5868f625d 100644 --- a/src/vivarium/framework/results/manager.py +++ b/src/vivarium/framework/results/manager.py @@ -2,6 +2,7 @@ ====================== Results System Manager ====================== + """ from collections import defaultdict @@ -31,6 +32,7 @@ class ResultsManager(Manager): This class contains the public methods used by the :class:`ResultsInterface ` to register stratifications and observations as well as the :meth:`get_results ` method used to retrieve formatted results by the :class:`ResultsContext `. + """ CONFIGURATION_DEFAULTS = { @@ -64,8 +66,8 @@ def get_results(self) -> Dict[str, pd.DataFrame]: Returns ------- - Dict[str, pandas.DataFrame] - A dictionary of formatted results for each measure. + A dictionary of measure-specific formatted results. The keys are the + measure names and the values are the respective results. """ formatted = {} for observation_details in self._results_context.observations.values(): @@ -195,8 +197,10 @@ def register_stratification( requires_columns: List[str] = [], requires_values: List[str] = [], ) -> None: - """Manager-level stratification registration. Adds a stratification - to the :class:`ResultsContext ` + """Manager-level stratification registration. + + Adds a stratification to the + :class:`ResultsContext ` as well as the stratification's required resources to this manager. Parameters @@ -305,7 +309,9 @@ def register_observation( requires_values: List[str], **kwargs, ) -> None: - """Manager-level observation registration. Adds an observation to the + """Manager-level observation registration. + + Adds an observation to the :class:`ResultsContext ` as well as the observation's required resources to this manager. diff --git a/src/vivarium/framework/results/observation.py b/src/vivarium/framework/results/observation.py index a9092fd1d..c2be68299 100644 --- a/src/vivarium/framework/results/observation.py +++ b/src/vivarium/framework/results/observation.py @@ -15,6 +15,7 @@ :class:`UnstratifiedObservation` or a :class:`StratifiedObservation`. More specialized implementations of these classes involve defining the various methods provided as attributes to the parent class. + """ import itertools @@ -38,6 +39,7 @@ class BaseObservation(ABC): This class includes an :meth:`observe ` method that determines whether to observe results for a given event. + """ name: str @@ -75,7 +77,21 @@ def observe( df: Union[pd.DataFrame, DataFrameGroupBy], stratifications: Optional[tuple[str, ...]], ) -> Optional[pd.DataFrame]: - # """Determine whether to observe the given event and, if so, gather the results.""" + """Determine whether to observe the given event, and if so, gather the results. + + Parameters + ---------- + event + The event to observe. + df + The population or population grouped by the stratifications. + stratifications + The stratifications to use for the observation. + + Returns + ------- + The results of the observation. + """ if not self.to_observe(event): return None else: @@ -110,6 +126,7 @@ class UnstratifiedObservation(BaseObservation): Method or function that formats the raw observation results. to_observe Method or function that determines whether to perform an observation on this Event. + """ def __init__( @@ -139,7 +156,19 @@ def create_empty_df( requested_stratification_names: set[str], registered_stratifications: list[Stratification], ) -> pd.DataFrame: - """Initialize an empty dataframe.""" + """Initialize an empty dataframe. + + Parameters + ---------- + requested_stratification_names + The names of the stratifications requested for this observation. + registered_stratifications + The list of all registered stratifications. + + Returns + ------- + An empty DataFrame. + """ return pd.DataFrame() @@ -174,6 +203,7 @@ class StratifiedObservation(BaseObservation): Method or function that computes the quantity for this observation. to_observe Method or function that determines whether to perform an observation on this Event. + """ def __init__( @@ -207,7 +237,24 @@ def create_expanded_df( requested_stratification_names: set[str], registered_stratifications: list[Stratification], ) -> pd.DataFrame: - """Initialize a dataframe of 0s with complete set of stratifications as the index.""" + """Initialize a dataframe of 0s with complete set of stratifications as the index. + + Parameters + ---------- + requested_stratification_names + The names of the stratifications requested for this observation. + registered_stratifications + The list of all registered stratifications. + + Returns + ------- + An empty DataFrame with the complete set of stratifications as the index. + + Notes + ----- + If no stratifications are requested, then we are aggregating over the + entire population and a single-row index named 'stratification' is created. + """ # Set up the complete index of all used stratifications requested_and_registered_stratifications = [ @@ -254,7 +301,6 @@ def get_complete_stratified_results( Returns ------- - pandas.DataFrame The results of the observation. """ df = self._aggregate(pop_groups, self.aggregator_sources, self.aggregator) @@ -328,6 +374,7 @@ class AddingObservation(StratifiedObservation): Method or function that computes the quantity for this observation. to_observe Method or function that determines whether to perform an observation on this Event. + """ def __init__( @@ -359,6 +406,17 @@ def add_results( ) -> pd.DataFrame: """Add newly-observed results to the existing results. + Parameters + ---------- + existing_results + The existing results DataFrame. + new_observations + The new observations DataFrame. + + Returns + ------- + The new results added to the existing results. + Notes ----- If the new observations contain columns not present in the existing results, @@ -399,6 +457,7 @@ class ConcatenatingObservation(UnstratifiedObservation): Method or function that formats the raw observation results. to_observe Method or function that determines whether to perform an observation on this Event. + """ def __init__( @@ -429,7 +488,19 @@ def get_results_of_interest(self, pop: pd.DataFrame) -> pd.DataFrame: def concatenate_results( existing_results: pd.DataFrame, new_observations: pd.DataFrame ) -> pd.DataFrame: - """Concatenate the existing results with the new observations.""" + """Concatenate the existing results with the new observations. + + Parameters + ---------- + existing_results + The existing results. + new_observations + The new observations. + + Returns + ------- + The new results concatenated to the existing results. + """ if existing_results.empty: return new_observations return pd.concat([existing_results, new_observations], axis=0).reset_index(drop=True) diff --git a/src/vivarium/framework/results/observer.py b/src/vivarium/framework/results/observer.py index 5106b087b..f48e8d02a 100644 --- a/src/vivarium/framework/results/observer.py +++ b/src/vivarium/framework/results/observer.py @@ -10,6 +10,7 @@ The provided :class:`Observer` class is an abstract base class that should be subclassed by concrete observers. Each concrete observer is required to implement a `register_observations` method that registers all required observations. + """ from abc import ABC, abstractmethod @@ -25,6 +26,7 @@ class Observer(Component, ABC): Notes ----- A `register_observation` method must be defined in the subclass. + """ def __init__(self) -> None: diff --git a/src/vivarium/framework/results/stratification.py b/src/vivarium/framework/results/stratification.py index 455d5b070..c560649b1 100644 --- a/src/vivarium/framework/results/stratification.py +++ b/src/vivarium/framework/results/stratification.py @@ -2,6 +2,7 @@ =============== Stratifications =============== + """ from dataclasses import dataclass @@ -24,6 +25,7 @@ class Stratification: This class includes a :meth:`stratify ` method that produces an output column by calling the mapper on the source columns. + """ name: str @@ -85,9 +87,11 @@ def __post_init__(self) -> None: def stratify(self, population: pd.DataFrame) -> pd.Series: """Apply the `mapper` to the population `sources` columns to create a new - Series to be added to the population. Any `excluded_categories` - (which have already been removed from `categories`) will be converted - to NaNs in the new column and dropped later at the observation level. + Series to be added to the population. + + Any `excluded_categories` (which have already been removed from `categories`) + will be converted to NaNs in the new column and dropped later at the + observation level. Parameters ---------- @@ -96,7 +100,6 @@ def stratify(self, population: pd.DataFrame) -> pd.Series: Returns ------- - pandas.Series A Series containing the mapped values to be used for stratifying. Raises @@ -132,12 +135,11 @@ def _default_mapper(pop: pd.DataFrame) -> pd.Series: Parameters ---------- pop - A DataFrame containing the data to be stratified. + The data to be stratified. Returns ------- - pandas.Series - A Series containing the data to be stratified. + The squeezed data to be stratified. Notes ----- diff --git a/src/vivarium/framework/state_machine.py b/src/vivarium/framework/state_machine.py index e377d6ebd..c74252dd4 100644 --- a/src/vivarium/framework/state_machine.py +++ b/src/vivarium/framework/state_machine.py @@ -39,7 +39,6 @@ def _next_state( A set of potential transitions available to the simulants. population_view A view of the internal state of the simulation. - """ if len(transition_set) == 0 or index.empty: return @@ -79,11 +78,9 @@ def _groupby_new_state( Returns ------- - List[Tuple[str, pandas.Index] The first item in each tuple is the name of an output state and the second item is a `pandas.Index` representing the simulants to transition into that state. - """ groups = pd.Series(index).groupby( pd.Categorical(decisions.values, categories=outputs), observed=False @@ -111,7 +108,7 @@ def _process_trigger(trigger): class Transition(Component): """A process by which an entity might change into a particular state. - Parameters + Attributes ---------- input_state The start state of the entity that undergoes the transition. @@ -120,7 +117,8 @@ class Transition(Component): probability_func A method or function that describing the probability of this transition occurring. - + triggered + A flag indicating whether this transition is triggered by some event. """ ##################### @@ -228,7 +226,6 @@ def next_state( When this transition is occurring. population_view A view of the internal state of the simulation. - """ return _next_state(index, event_time, self.transition_set, population_view) @@ -245,7 +242,6 @@ def transition_effect( The time at which this transition occurs. population_view A view of the internal state of the simulation. - """ population_view.update(pd.Series(self.state_id, index=index)) self.transition_side_effect(index, event_time) @@ -260,7 +256,6 @@ def add_transition(self, transition: Transition) -> None: ---------- transition The transition to add - """ self.transition_set.append(transition) @@ -288,15 +283,17 @@ class TransientState(State, Transient): class TransitionSet(Component): """A container for state machine transitions. - Parameters + Attributes ---------- state_id The unique name of the state that instantiated this TransitionSet. Typically a string but any object implementing __str__ will do. - iterable - Any iterable whose elements are `Transition` objects. allow_null_transition Specified whether it is possible not to transition on a given time-step + transitions + A list of transitions that can be taken from this state. + random + The randomness stream. """ @@ -331,7 +328,6 @@ def setup(self, builder: "Builder") -> None: builder Interface to several simulation tools including access to common random number generation, in particular. - """ self.random = builder.randomness.get_stream(self.name) @@ -349,12 +345,9 @@ def choose_new_state(self, index: pd.Index) -> Tuple[List, pd.Series]: Returns ------- - List - The possible end states of this set of transitions. - pandas.Series - A series containing the name of the next state for each simulant + A tuple of the possible end states of this set of transitions and a + series containing the name of the next state for each simulant in the index. - """ outputs, probabilities = zip( *[ @@ -397,12 +390,10 @@ def _normalize_probabilities(self, outputs, probabilities): Returns ------- - List - The original output list expanded to include a null transition (a - transition back to the starting state) if requested. - numpy.ndarray - The original probabilities rescaled to sum to 1 and potentially - expanded to include a null transition weight. + A tuple of the original output list expanded to include a null transition + (a transition back to the starting state) if requested and the original + probabilities rescaled to sum to 1 and potentially expanded to include + a null transition weight. """ outputs = list(outputs) @@ -501,7 +492,6 @@ def transition(self, index: pd.Index, event_time: "Time") -> None: An iterable of integer labels for the simulants. event_time The time at which this transition occurs. - """ for state, affected in self._get_state_pops(index): if not affected.empty: @@ -531,7 +521,13 @@ def get_initialization_parameters(self) -> Dict[str, Any]: """ Gets the values of the state column specified in the __init__`. - Note: this retrieves the value of the attribute at the time of calling + Returns + ------- + The value of the state column. + + Notes + ----- + This retrieves the value of the attribute at the time of calling which is not guaranteed to be the same as the original value. """ diff --git a/src/vivarium/framework/time.py b/src/vivarium/framework/time.py index 15da5e8cd..5a9df6639 100644 --- a/src/vivarium/framework/time.py +++ b/src/vivarium/framework/time.py @@ -184,8 +184,10 @@ def move_simulants_to_end(self, index: pd.Index) -> None: self._simulants_to_snooze = self._simulants_to_snooze.union(index) def step_size_post_processor(self, values: List[NumberLike], _) -> pd.Series: - """Computes the largest feasible step size for each simulant. This is the smallest component-modified - step size (rounded down to increments of the minimum step size), or the global step size, whichever is larger. + """Computes the largest feasible step size for each simulant. + + This is the smallest component-modified step size (rounded down to increments + of the minimum step size), or the global step size, whichever is larger. If no components modify the step size, we default to the global step size. Parameters @@ -195,10 +197,7 @@ def step_size_post_processor(self, values: List[NumberLike], _) -> pd.Series: Returns ------- - pandas.Series The largest feasible step size for each simulant - - """ min_modified = pd.DataFrame(values).min(axis=0).fillna(self.standard_step_size) diff --git a/src/vivarium/framework/utilities.py b/src/vivarium/framework/utilities.py index 6e87ae5bf..9ad9137f9 100644 --- a/src/vivarium/framework/utilities.py +++ b/src/vivarium/framework/utilities.py @@ -54,8 +54,12 @@ def import_by_path(path: str) -> Callable: Parameters ---------- - path: - Path to object to import + path + Path to object to import + + Returns + ------- + The imported class or function """ module_path, _, class_name = path.rpartition(".") diff --git a/src/vivarium/framework/values.py b/src/vivarium/framework/values.py index 652a2a25f..354296d06 100644 --- a/src/vivarium/framework/values.py +++ b/src/vivarium/framework/values.py @@ -14,7 +14,7 @@ """ from collections import defaultdict -from typing import Any, Callable, Iterable, List, Tuple +from typing import Any, Callable, Iterable, List, Optional, Tuple import pandas as pd @@ -49,9 +49,7 @@ def replace_combiner(value: Any, mutator: Callable, *args: Any, **kwargs: Any) - Returns ------- - Any A modified version of the input value. - """ args = list(args) + [value] return mutator(*args, **kwargs) @@ -78,7 +76,6 @@ def list_combiner(value: List, mutator: Callable, *args: Any, **kwargs: Any) -> ------- The input list with new mutator portion of the pipeline value appended to it. - """ value.append(mutator(*args, **kwargs)) return value @@ -103,9 +100,7 @@ def rescale_post_processor(value: NumberLike, manager: "ValuesManager") -> Numbe Returns ------- - Union[numpy.ndarray, pandas.Series, pandas.DataFrame, numbers.Number] The annual rates rescaled to the size of the current time step size. - """ if hasattr(value, "index"): return value.mul( @@ -148,10 +143,8 @@ def union_post_processor(values: List[NumberLike], _) -> NumberLike: Returns ------- - Union[numpy.ndarray, pandas.Series, pandas.DataFrame, numbers.Number] The probability over the union of the sample spaces represented by the original probabilities. - """ # if there is only one value, return the value if len(values) == 1: @@ -229,7 +222,6 @@ def __call__(self, *args, skip_post_processor=False, **kwargs): ------ DynamicValueError If the pipeline is invoked without a source set. - """ return self._call(*args, skip_post_processor=skip_post_processor, **kwargs) @@ -317,7 +309,6 @@ def register_value_producer( See Also -------- :meth:`ValuesInterface.register_value_producer` - """ pipeline = self._register_value_producer( value_name, source, preferred_combiner, preferred_post_processor @@ -366,7 +357,7 @@ def register_value_modifier( requires_columns: List[str] = (), requires_values: List[str] = (), requires_streams: List[str] = (), - ): + ) -> None: """Marks a ``Callable`` as the modifier of a named value. Parameters @@ -391,7 +382,6 @@ def register_value_modifier( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline modifier is called. - """ modifier_name = self._get_modifier_name(modifier) @@ -405,7 +395,7 @@ def register_value_modifier( ) self.resources.add_resources("value_modifier", [name], modifier, dependencies) - def get_value(self, name): + def get_value(self, name) -> Pipeline: """Retrieve the pipeline representing the named value. Parameters @@ -419,7 +409,6 @@ def get_value(self, name): should be identical to the arguments to the pipeline source (frequently just a :class:`pandas.Index` representing the simulants). - """ return self._pipelines[name] # May create a pipeline. @@ -501,7 +490,7 @@ def register_value_producer( requires_values: List[str] = (), requires_streams: List[str] = (), preferred_combiner: Callable = replace_combiner, - preferred_post_processor: Callable = None, + preferred_post_processor: Optional[Callable] = None, ) -> Pipeline: """Marks a ``Callable`` as the producer of a named value. @@ -536,9 +525,7 @@ def register_value_producer( Returns ------- - Pipeline A callable reference to the named dynamic value pipeline. - """ return self._manager.register_value_producer( value_name, @@ -586,9 +573,7 @@ def register_rate_producer( Returns ------- - Pipeline A callable reference to the named dynamic rate pipeline. - """ return self.register_value_producer( rate_name, @@ -606,7 +591,7 @@ def register_value_modifier( requires_columns: List[str] = (), requires_values: List[str] = (), requires_streams: List[str] = (), - ): + ) -> None: """Marks a ``Callable`` as the modifier of a named value. Parameters @@ -631,7 +616,6 @@ def register_value_modifier( requires_streams A list of the randomness streams that need to be properly sourced before the pipeline modifier is called. - """ self._manager.register_value_modifier( value_name, modifier, requires_columns, requires_values, requires_streams diff --git a/src/vivarium/interface/cli.py b/src/vivarium/interface/cli.py index bf928710b..84c3c42a9 100644 --- a/src/vivarium/interface/cli.py +++ b/src/vivarium/interface/cli.py @@ -113,7 +113,6 @@ def run( is provided, a subdirectory will be created with the same name as the MODEL_SPECIFICATION if one does not exist. Results will be written to a further subdirectory named after the start time of the simulation run. - """ if verbose and quiet: raise click.UsageError("Cannot be both verbose and quiet.") diff --git a/src/vivarium/interface/interactive.py b/src/vivarium/interface/interactive.py index c1462a2b2..80dd376b6 100644 --- a/src/vivarium/interface/interactive.py +++ b/src/vivarium/interface/interactive.py @@ -14,7 +14,7 @@ """ from math import ceil -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional import pandas as pd @@ -42,7 +42,7 @@ def setup(self): super().setup() self.initialize_simulants() - def step(self, step_size: Timedelta = None): + def step(self, step_size: Optional[Timedelta] = None) -> None: """Advance the simulation one step. Parameters @@ -75,9 +75,7 @@ def run(self, with_logging: bool = True) -> int: Returns ------- - int The number of steps the simulation took. - """ return self.run_until(self._clock.stop_time, with_logging=with_logging) @@ -96,9 +94,7 @@ def run_for(self, duration: Timedelta, with_logging: bool = True) -> int: Returns ------- - int The number of steps the simulation took. - """ return self.run_until(self._clock.time + duration, with_logging=with_logging) @@ -118,9 +114,7 @@ def run_until(self, end_time: Time, with_logging: bool = True) -> int: Returns ------- - int The number of steps the simulation took. - """ if not ( isinstance(end_time, type(self._clock.time)) @@ -136,7 +130,10 @@ def run_until(self, end_time: Time, with_logging: bool = True) -> int: return iterations def take_steps( - self, number_of_steps: int = 1, step_size: Timedelta = None, with_logging: bool = True + self, + number_of_steps: int = 1, + step_size: Optional[Timedelta] = None, + with_logging: bool = True, ): """Run the simulation for the given number of steps. @@ -150,7 +147,6 @@ def take_steps( with_logging Whether or not to log the simulation steps. Only works in an ipython environment. - """ if not isinstance(number_of_steps, int): raise ValueError("Number of steps must be an integer.") @@ -171,6 +167,9 @@ def get_population(self, untracked: bool = False) -> pd.DataFrame: Whether or not to return simulants who are no longer being tracked by the simulation. + Returns + ------- + The population state table. """ return self._population.get_population(untracked) @@ -197,6 +196,10 @@ def get_listeners(self, event_type: str) -> Dict[int, List[Callable]]: event_type The type of event to grab the listeners for. + Returns + ------- + A dictionary that maps each priority level of the named event's + listeners to a list of listeners at that level. """ if event_type not in self._events: raise ValueError(f"No event {event_type} in system.") @@ -213,6 +216,9 @@ def get_emitter(self, event_type: str) -> Callable: event_type The type of event to grab the listeners for. + Returns + ------- + The callable that emits the named event. """ if event_type not in self._events: raise ValueError(f"No event {event_type} in system.") @@ -223,9 +229,7 @@ def list_components(self) -> Dict[str, Any]: Returns ------- - Dict[str, Any] A dictionary mapping component names to components. - """ return self._component_manager.list_components() @@ -237,10 +241,10 @@ def get_component(self, name: str) -> Any: ---------- name A component name. + Returns ------- A component that has the name ``name`` else None. - """ return self._component_manager.get_component(name) diff --git a/src/vivarium/interface/utilities.py b/src/vivarium/interface/utilities.py index fd806b618..afcab8a11 100644 --- a/src/vivarium/interface/utilities.py +++ b/src/vivarium/interface/utilities.py @@ -121,9 +121,7 @@ def get_output_model_name_string( Returns ------- - str A model name string for use in output labeling. - """ if artifact_path: model_name = Path(artifact_path).stem @@ -141,7 +139,22 @@ def get_output_root( results_directory: Union[str, Path], model_specification_file: Union[str, Path], artifact_path: Union[str, Path], -): +) -> Path: + """Create a root directory for output files. + + Parameters + ---------- + results_directory + Directory to store the results in. + model_specification_file + Path to the model specification file. + artifact_path + Path to the artifact file. + + Returns + ------- + The date-stamped output root directory. + """ launch_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") model_name = get_output_model_name_string(artifact_path, model_specification_file) output_root = Path(results_directory + f"/{model_name}/{launch_time}") diff --git a/src/vivarium/manager.py b/src/vivarium/manager.py index cb1e8d065..c6cf24c41 100644 --- a/src/vivarium/manager.py +++ b/src/vivarium/manager.py @@ -5,6 +5,7 @@ A base Manager class to be used to create manager for use in ``vivarium`` simulations. + """ from typing import TYPE_CHECKING, Any, Dict @@ -15,9 +16,9 @@ class Manager: CONFIGURATION_DEFAULTS: Dict[str, Any] = {} - """ - A dictionary containing the defaults for any configurations managed by this + """A dictionary containing the defaults for any configurations managed by this manager. An empty dictionary indicates no managed configurations. + """ ############## @@ -26,8 +27,7 @@ class Manager: @property def configuration_defaults(self) -> Dict[str, Any]: - """ - Provides a dictionary containing the defaults for any configurations + """Provides a dictionary containing the defaults for any configurations managed by this manager. These default values will be stored at the `component_configs` layer of the @@ -35,7 +35,6 @@ def configuration_defaults(self) -> Dict[str, Any]: Returns ------- - Dict[str, Any] A dictionary containing the defaults for any configurations managed by this manager. """ diff --git a/src/vivarium/testing_utilities.py b/src/vivarium/testing_utilities.py index fcf8ede26..7745369f0 100644 --- a/src/vivarium/testing_utilities.py +++ b/src/vivarium/testing_utilities.py @@ -159,6 +159,7 @@ def build_table( value_columns: List = ["value"], ) -> pd.DataFrame: """ + Parameters ---------- value @@ -173,6 +174,7 @@ def build_table( A list of value columns that will appear in the returned lookup table Returns + ------- A pandas dataframe that has the cartesian product of the range of all parameter columns and the values of the key columns. """ diff --git a/tests/framework/components/test_parser.py b/tests/framework/components/test_parser.py index b8326582e..db85baceb 100644 --- a/tests/framework/components/test_parser.py +++ b/tests/framework/components/test_parser.py @@ -2,7 +2,7 @@ import pytest import yaml -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from tests.helpers import MockComponentA, MockComponentB from vivarium.framework.components.parser import ComponentConfigurationParser, ParsingError diff --git a/tests/framework/results/test_context.py b/tests/framework/results/test_context.py index e006f464a..412273020 100644 --- a/tests/framework/results/test_context.py +++ b/tests/framework/results/test_context.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd import pytest -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from loguru import logger from pandas.core.groupby.generic import DataFrameGroupBy diff --git a/tests/framework/results/test_interface.py b/tests/framework/results/test_interface.py index 6261b455d..db3557257 100644 --- a/tests/framework/results/test_interface.py +++ b/tests/framework/results/test_interface.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from loguru import logger from tests.framework.results.helpers import BASE_POPULATION, FAMILIARS diff --git a/tests/framework/results/test_manager.py b/tests/framework/results/test_manager.py index f80fa8d80..8e998a732 100644 --- a/tests/framework/results/test_manager.py +++ b/tests/framework/results/test_manager.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd import pytest -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from loguru import logger from pandas.api.types import CategoricalDtype diff --git a/tests/framework/results/test_observer.py b/tests/framework/results/test_observer.py index b0929f7ea..6b34c71f8 100644 --- a/tests/framework/results/test_observer.py +++ b/tests/framework/results/test_observer.py @@ -1,5 +1,5 @@ import pytest -from layered_config_tree import LayeredConfigTree +from layered_config_tree.main import LayeredConfigTree from vivarium.framework.results.observer import Observer