From f5b0c5cb17cdc2a57a7a6790038e50ee652b94b2 Mon Sep 17 00:00:00 2001 From: Afzal Ansari Date: Tue, 8 Jun 2021 22:30:12 +0530 Subject: [PATCH] Added Arguments *args, **kwargs to BaseLogger.attach method (#2034) * Added Arguments *args, **kwargs * Reformatted * Updated _test method and added default value of kwargs * Updated _test method * fix minor changes * fixed minor changes * Reformatted Co-authored-by: vfdev --- ignite/contrib/handlers/base_logger.py | 11 +++++++++-- tests/ignite/contrib/handlers/test_base_logger.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/ignite/contrib/handlers/base_logger.py b/ignite/contrib/handlers/base_logger.py index 60179db95c7..edbc2538fee 100644 --- a/ignite/contrib/handlers/base_logger.py +++ b/ignite/contrib/handlers/base_logger.py @@ -151,7 +151,12 @@ class BaseLogger(metaclass=ABCMeta): """ def attach( - self, engine: Engine, log_handler: Callable, event_name: Union[str, Events, CallableEventWithFilter, EventsList] + self, + engine: Engine, + log_handler: Callable, + event_name: Union[str, Events, CallableEventWithFilter, EventsList], + *args: Any, + **kwargs: Any, ) -> RemovableEventHandle: """Attach the logger to the engine and execute `log_handler` function at `event_name` events. @@ -161,6 +166,8 @@ def attach( event_name: event to attach the logging handler to. Valid events are from :class:`~ignite.engine.events.Events` or :class:`~ignite.engine.events.EventsList` or any `event_name` added by :meth:`~ignite.engine.engine.Engine.register_events`. + args: args forwarded to the `log_handler` method + kwargs: kwargs forwarded to the `log_handler` method Returns: :class:`~ignite.engine.events.RemovableEventHandle`, which can be used to remove the handler. @@ -178,7 +185,7 @@ def attach( if event_name not in State.event_to_attr: raise RuntimeError(f"Unknown event name '{event_name}'") - return engine.add_event_handler(event_name, log_handler, self, event_name) + return engine.add_event_handler(event_name, log_handler, self, event_name, *args, **kwargs) def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any) -> RemovableEventHandle: """Shortcut method to attach `OutputHandler` to the logger. diff --git a/tests/ignite/contrib/handlers/test_base_logger.py b/tests/ignite/contrib/handlers/test_base_logger.py index 02aeabcb2a8..ed5a5a37e47 100644 --- a/tests/ignite/contrib/handlers/test_base_logger.py +++ b/tests/ignite/contrib/handlers/test_base_logger.py @@ -103,7 +103,7 @@ def test_attach(): n_epochs = 5 data = list(range(50)) - def _test(event, n_calls): + def _test(event, n_calls, kwargs={}): losses = torch.rand(n_epochs * len(data)) losses_iter = iter(losses) @@ -117,7 +117,7 @@ def update_fn(engine, batch): mock_log_handler = MagicMock() - logger.attach(trainer, log_handler=mock_log_handler, event_name=event) + logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs) trainer.run(data, max_epochs=n_epochs) @@ -125,11 +125,16 @@ def update_fn(engine, batch): events = [e for e in event] else: events = [event] - calls = [call(trainer, logger, e) for e in events] + + if len(kwargs) > 0: + calls = [call(trainer, logger, e, **kwargs) for e in events] + else: + calls = [call(trainer, logger, e) for e in events] + mock_log_handler.assert_has_calls(calls) assert mock_log_handler.call_count == n_calls - _test(Events.ITERATION_STARTED, len(data) * n_epochs) + _test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0}) _test(Events.ITERATION_COMPLETED, len(data) * n_epochs) _test(Events.EPOCH_STARTED, n_epochs) _test(Events.EPOCH_COMPLETED, n_epochs)