Skip to content

Commit

Permalink
Added Arguments *args, **kwargs to BaseLogger.attach method (#2034)
Browse files Browse the repository at this point in the history
* Added Arguments *args, **kwargs

* Reformatted

* Updated _test method and added default value of kwargs

* Updated _test method

* fix minor changes

* fixed minor changes

* Reformatted

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
afzal442 and vfdev-5 authored Jun 8, 2021
1 parent 9181716 commit f5b0c5c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
11 changes: 9 additions & 2 deletions ignite/contrib/handlers/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,12 @@ class BaseLogger(metaclass=ABCMeta):
"""

def attach(
self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList]
self,
engine: Engine,
log_handler: Callable,
event_name: Union[str, Events, CallableEventWithFilter, EventsList],
*args: Any,
**kwargs: Any,
) -> RemovableEventHandle:
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.
Expand All @@ -161,6 +166,8 @@ def attach(
event_name: event to attach the logging handler to. Valid events are from
:class:`~ignite.engine.events.Events` or :class:`~ignite.engine.events.EventsList` or any `event_name`
added by :meth:`~ignite.engine.engine.Engine.register_events`.
args: args forwarded to the `log_handler` method
kwargs: kwargs forwarded to the `log_handler` method
Returns:
:class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler.
Expand All @@ -178,7 +185,7 @@ def attach(
if event_name not in State.event_to_attr:
raise RuntimeError(f"Unknown event name '{event_name}'")

return engine.add_event_handler(event_name, log_handler, self, event_name)
return engine.add_event_handler(event_name, log_handler, self, event_name, *args, **kwargs)

def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle:
"""Shortcut method to attach `OutputHandler` to the logger.
Expand Down
13 changes: 9 additions & 4 deletions tests/ignite/contrib/handlers/test_base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_attach():
n_epochs = 5
data = list(range(50))

def _test(event, n_calls):
def _test(event, n_calls, kwargs={}):

losses = torch.rand(n_epochs * len(data))
losses_iter = iter(losses)
Expand All @@ -117,19 +117,24 @@ def update_fn(engine, batch):

mock_log_handler = MagicMock()

logger.attach(trainer, log_handler=mock_log_handler, event_name=event)
logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)

trainer.run(data, max_epochs=n_epochs)

if isinstance(event, EventsList):
events = [e for e in event]
else:
events = [event]
calls = [call(trainer, logger, e) for e in events]

if len(kwargs) > 0:
calls = [call(trainer, logger, e, **kwargs) for e in events]
else:
calls = [call(trainer, logger, e) for e in events]

mock_log_handler.assert_has_calls(calls)
assert mock_log_handler.call_count == n_calls

_test(Events.ITERATION_STARTED, len(data) * n_epochs)
_test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
_test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
_test(Events.EPOCH_STARTED, n_epochs)
_test(Events.EPOCH_COMPLETED, n_epochs)
Expand Down

0 comments on commit f5b0c5c

Please # to comment.