diff --git a/neuralmonkey/runners/base_runner.py b/neuralmonkey/runners/base_runner.py index bc60ea1a2..8075563c2 100644 --- a/neuralmonkey/runners/base_runner.py +++ b/neuralmonkey/runners/base_runner.py @@ -4,7 +4,6 @@ import numpy as np import tensorflow as tf -from neuralmonkey.logging import notice from neuralmonkey.model.model_part import GenericModelPart from neuralmonkey.model.feedable import Feedable from neuralmonkey.model.parameterized import Parameterized @@ -25,7 +24,7 @@ class ExecutionResult(NamedTuple( ("scalar_summaries", tf.Summary), ("histogram_summaries", tf.Summary), ("image_summaries", tf.Summary)])): - """A data structure that represents a result of a graph execution. + """A data structure that represents the result of a graph execution. The goal of each runner is to populate this structure and set it as its ``self._result``. @@ -40,8 +39,41 @@ class ExecutionResult(NamedTuple( class GraphExecutor(GenericModelPart): + """The abstract parent class of all graph executors. + + In Neural Monkey, a graph executor is an object that retrieves tensors + from the computational graph. The two major groups of graph executors are + trainers and runners. + + Each graph executor is an instance of `GenericModelPart` class, which means + it has parameterized and feedable dependencies which reference the model + part objects needed to be created in order to compute the tensors of + interest (called "fetches"). + + Every graph executor has a method called `get_executable`, which returns + a `GraphExecutor.Executable` instance, which specifies what tensors to + execute and collects results from the session execution. + """ class Executable(Generic[Executor]): + """Abstract base class for executables. + + Executables are objects associated with the graph executors. Each + executable has two main functions: `next_to_execute` and + `collect_results`. These functions are called in a loop, until + the executable's result has been set. 
+ + To make use of Mypy's type checking, the executables are generic and + are parameterized by the type of their graph executor. Since a nested + class in Python has no implicit reference to an instance of its outer + class, each executable receives the graph executor via its constructor. + + When subclassing `GraphExecutor`, it is also necessary to subclass + the `Executable` class and name it `Executable`, so it overrides the + definition of this class. Following this guideline, the default + implementation of the `get_executable` function on the graph executor + will work without the need to override it. + """ def __init__(self, executor: Executor, @@ -110,6 +142,12 @@ def parameterizeds(self) -> Set[Parameterized]: class BaseRunner(GraphExecutor, Generic[MP]): + """Base class for runners. + + Runners are graph executors that retrieve tensors from the model without + changing the model parameters. Each runner has a top-level model part it + relates to. + """ # pylint: disable=too-few-public-methods # Pylint issue here: https://github.com/PyCQA/pylint/issues/2607 @@ -130,12 +168,9 @@ def __init__(self, decoder: MP) -> None: GraphExecutor.__init__(self, {decoder}) self.output_series = output_series + # TODO(tf-data) rename decoder to something more general self.decoder = decoder - if not hasattr(decoder, "data_id"): - notice("Top-level decoder {} does not have the 'data_id' attribute" - .format(decoder)) - @property def decoder_data_id(self) -> Optional[str]: return getattr(self.decoder, "data_id", None)