diff --git a/tensorflow_addons/optimizers/rectified_adam.py b/tensorflow_addons/optimizers/rectified_adam.py
index 881840a1c5..c24c5557e3 100644
--- a/tensorflow_addons/optimizers/rectified_adam.py
+++ b/tensorflow_addons/optimizers/rectified_adam.py
@@ -103,7 +103,7 @@ def __init__(
             beyond".
         sma_threshold. A float value.
             The threshold for simple mean average.
-        total_steps: An integer. Total number of training steps.
+        total_steps: An integer value. Total number of training steps.
             Enable warmup by setting a positive value.
         warmup_proportion: A floating point value.
             The proportion of increasing steps.
@@ -131,7 +131,7 @@ def __init__(
         self._set_hyper("decay", self._initial_decay)
         self._set_hyper("weight_decay", weight_decay)
         self._set_hyper("sma_threshold", sma_threshold)
-        self._set_hyper("total_steps", int(total_steps))
+        self._set_hyper("total_steps", float(total_steps))
         self._set_hyper("warmup_proportion", warmup_proportion)
         self._set_hyper("min_lr", min_lr)
         self.epsilon = epsilon or tf.keras.backend.epsilon()
@@ -316,7 +316,7 @@ def get_config(self):
                 "sma_threshold": self._serialize_hyperparameter("sma_threshold"),
                 "epsilon": self.epsilon,
                 "amsgrad": self.amsgrad,
-                "total_steps": self._serialize_hyperparameter("total_steps"),
+                "total_steps": int(self._serialize_hyperparameter("total_steps")),
                 "warmup_proportion": self._serialize_hyperparameter(
                     "warmup_proportion"
                 ),
diff --git a/tensorflow_addons/optimizers/tests/rectified_adam_test.py b/tensorflow_addons/optimizers/tests/rectified_adam_test.py
index 040e5e638e..ae003ddb22 100644
--- a/tensorflow_addons/optimizers/tests/rectified_adam_test.py
+++ b/tensorflow_addons/optimizers/tests/rectified_adam_test.py
@@ -209,3 +209,26 @@ def test_scheduler_serialization():
         "class_name": "InverseTimeDecay",
         "config": wd_scheduler.get_config(),
     }
+
+
+def test_checkpoint_serialization(tmpdir):
+    optimizer = RectifiedAdam()
+    optimizer2 = RectifiedAdam()
+
+    var_0 = tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)
+    var_1 = tf.Variable([3.0, 4.0], dtype=tf.dtypes.float32)
+
+    grad_0 = tf.constant([0.1, 0.2], dtype=tf.dtypes.float32)
+    grad_1 = tf.constant([0.03, 0.04], dtype=tf.dtypes.float32)
+
+    grads_and_vars = list(zip([grad_0, grad_1], [var_0, var_1]))
+
+    optimizer.apply_gradients(grads_and_vars)
+
+    checkpoint = tf.train.Checkpoint(optimizer=optimizer)
+    checkpoint2 = tf.train.Checkpoint(optimizer=optimizer2)
+    model_path = str(tmpdir / "rectified_adam_chkpt")
+    checkpoint.write(model_path)
+    checkpoint2.read(model_path)
+
+    optimizer2.apply_gradients(grads_and_vars)
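
Note: the change above stores total_steps internally as a float hyperparameter (so it is tracked and checkpointed like the other hypers) while get_config() casts it back to an integer for the user-facing config. A minimal round-trip sketch of that behaviour, assuming only that tensorflow_addons with this patch is importable; the hyperparameter values are arbitrary illustration:

    from tensorflow_addons.optimizers import RectifiedAdam

    # Enable warmup by passing a positive total_steps.
    opt = RectifiedAdam(total_steps=10000, warmup_proportion=0.1, min_lr=1e-5)

    config = opt.get_config()
    assert isinstance(config["total_steps"], int)  # get_config still exposes an int

    # The config round-trips back into an equivalent optimizer.
    restored = RectifiedAdam.from_config(config)
    assert restored.get_config()["total_steps"] == 10000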