From 8340b2d78f2b40bc365862b24477a0190ad2e2c2 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Fri, 17 Jul 2020 10:56:38 -0700
Subject: [PATCH] Expose FairseqOptimizer.param_groups property

---
 fairseq/optim/fairseq_optimizer.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py
index 3242a92a35..b1b9c76edb 100644
--- a/fairseq/optim/fairseq_optimizer.py
+++ b/fairseq/optim/fairseq_optimizer.py
@@ -41,20 +41,24 @@ def optimizer_config(self):
     @property
     def params(self):
         """Return an iterable of the parameters held by the optimizer."""
-        for param_group in self.optimizer.param_groups:
+        for param_group in self.param_groups:
             for p in param_group['params']:
                 yield p
 
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
     def __getstate__(self):
         return self._optimizer.__getstate__()
 
     def get_lr(self):
         """Return the current learning rate."""
-        return self.optimizer.param_groups[0]['lr']
+        return self.param_groups[0]['lr']
 
     def set_lr(self, lr):
         """Set the learning rate."""
-        for param_group in self.optimizer.param_groups:
+        for param_group in self.param_groups:
             param_group['lr'] = lr
 
     def state_dict(self):
@@ -73,7 +77,7 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
 
         if optimizer_overrides is not None and len(optimizer_overrides) > 0:
             # override learning rate, momentum, etc. with latest values
-            for group in self.optimizer.param_groups:
+            for group in self.param_groups:
                 group.update(optimizer_overrides)
 
     def backward(self, loss):
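
Usage sketch (not part of the patch): after this change, callers can read and adjust optimizer hyperparameters through the FairseqOptimizer wrapper itself instead of reaching into the wrapped torch optimizer via optimizer.optimizer.param_groups. The SketchAdam subclass and the FairseqOptimizer(args) constructor signature below are assumptions for illustration, not fairseq code.

# Minimal sketch, assuming a FairseqOptimizer subclass that assigns a torch
# optimizer to self._optimizer, as concrete fairseq optimizers do.
import torch

from fairseq.optim.fairseq_optimizer import FairseqOptimizer


class SketchAdam(FairseqOptimizer):
    """Hypothetical wrapper that backs FairseqOptimizer with torch Adam."""

    def __init__(self, args, params):
        # Assumed 2020-era constructor signature: FairseqOptimizer(args).
        super().__init__(args)
        self._optimizer = torch.optim.Adam(params, lr=1e-3)


model = torch.nn.Linear(4, 2)
opt = SketchAdam(args=None, params=list(model.parameters()))

# Go through the wrapper rather than opt.optimizer.param_groups.
print(opt.get_lr())              # reads self.param_groups[0]['lr']
opt.set_lr(5e-4)                 # iterates self.param_groups
for group in opt.param_groups:   # the property added by this patch
    group['weight_decay'] = 0.01

Forwarding param_groups keeps the wrapper as the single public surface: get_lr, set_lr, params, and the load_state_dict override path in the patch all go through the same accessor.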