diff --git a/tensorflow/python/distribute/group_embedding_collective_strategy.py b/tensorflow/python/distribute/group_embedding_collective_strategy.py
index 8fa94ea9f7e..a989aaa889c 100644
--- a/tensorflow/python/distribute/group_embedding_collective_strategy.py
+++ b/tensorflow/python/distribute/group_embedding_collective_strategy.py
@@ -101,7 +101,7 @@ def estimator(self, model_fn, **kwargs):
       from tensorflow.python.distribute.hvd_strategy import wraps_estimator
       _estimator = wraps_estimator(_estimator_lib.Estimator)
     elif self._hb:
-      _estimator = hb.estimator.Estimator
+      _estimator = self._hb.estimator.Estimator
     return _estimator(model_fn, **kwargs)
 
 
diff --git a/tensorflow/python/distribute/hvd_strategy.py b/tensorflow/python/distribute/hvd_strategy.py
index 8a3ae9c3f43..14d221e02e2 100644
--- a/tensorflow/python/distribute/hvd_strategy.py
+++ b/tensorflow/python/distribute/hvd_strategy.py
@@ -388,20 +388,24 @@ def wraps_optimizer(cls):
     HvdOptimizer
   '''
   class HvdOptimizer(cls, optimizer.Optimizer):
-    def __init__(self, *args, **kwargs):
-      kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
-        HvdContext.get().world_size
-      super(HvdOptimizer, self).__init__(*args, **kwargs)
+    def __init__(self, learning_rate=0.001, *args, **kwargs):
+      learning_rate = learning_rate * HvdContext.get().world_size
+      super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)
 
-    def compute_gradients(self, loss, **kwargs):
-      loss = hvd.allreduce(loss, op=hvd.Sum)
-      return super().compute_gradients(loss, **kwargs)
-
   if isinstance(cls, HvdOptimizer):
     return cls
   else:
     def horovod_optimizer(*args, **kwargs):
-      return HvdOptimizer(*args, **kwargs)
+      from horovod.tensorflow import DistributedOptimizer
+      horovod_args = DistributedOptimizer.__code__.co_varnames
+      horovod_real_kargs = {}
+      candidate_keys = list(kwargs.keys())
+      for kwarg in candidate_keys:
+        if kwarg in horovod_args:
+          value = kwargs[kwarg]
+          del kwargs[kwarg]
+          horovod_real_kargs[kwarg] = value
+      return DistributedOptimizer(HvdOptimizer(*args, **kwargs), **horovod_real_kargs)
     return horovod_optimizer
 
 
@@ -478,16 +482,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs):  # pylint: disable=invalid
   kwargs['config'] = wraps_session_config(kwargs.pop('config', None))
   kwargs['is_chief'] = True
   args = list(args)
-  if args:
-    master = args[0]
-    if not master:
-      master = ''
-    args[0] = master
-  else:
-    master = kwargs.pop('master', None)
-    if not master:
-      master = ''
-    kwargs['master'] = master
   prev_monitored_session = _monitored_session.MonitoredSession
   sess = fn(*args, **kwargs)
 
@@ -1074,10 +1068,14 @@ def __init__(self, model_fn, **kwargs):
     self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', True)
     self._predict_drop_remainder = kwargs.pop(
       'predict_drop_remainder', True)
+    config = kwargs.get('config', None)
+    if config is None:
+      config = run_config_lib.RunConfig()
+    else:
+      kwargs.pop('config')
 
     super().__init__(
-      wraps_model_fn(model_fn, model_dir, kwargs['config']),
-      **kwargs)
+      wraps_model_fn(model_fn, model_dir, config), **kwargs)
 
   def _assert_members_are_not_overridden(self):
     r'''disable the overridden check here.
@@ -1449,4 +1447,4 @@ def export(export_dir_base,
       as_text=as_text,
       clear_devices=clear_devices,
       strip_default_attrs=strip_default_attrs,
-      modes=[mode])
\ No newline at end of file
+      modes=[mode])
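
The new horovod_optimizer factory above routes each keyword argument either to Horovod's DistributedOptimizer wrapper or to the underlying TensorFlow optimizer, by testing the key against DistributedOptimizer.__code__.co_varnames. Below is a minimal, self-contained sketch of that dispatch pattern; split_kwargs and make_wrapper are illustrative stand-ins, not Horovod's API:

# Sketch of the kwarg-splitting used by horovod_optimizer; runs stand-alone.
def split_kwargs(fn, kwargs):
  # co_varnames lists fn's parameters first, followed by its locals.
  accepted = fn.__code__.co_varnames
  taken = {}
  for key in list(kwargs.keys()):  # copy the keys: kwargs shrinks as we pop
    if key in accepted:
      taken[key] = kwargs.pop(key)
  return taken

def make_wrapper(base, device_sparse=''):
  # Hypothetical stand-in for horovod.tensorflow.DistributedOptimizer.
  return (base, device_sparse)

kwargs = {'learning_rate': 0.1, 'device_sparse': '/cpu:0'}
wrapper_kwargs = split_kwargs(make_wrapper, kwargs)
assert kwargs == {'learning_rate': 0.1}               # stays with the optimizer
assert wrapper_kwargs == {'device_sparse': '/cpu:0'}  # goes to the wrapper

One caveat: co_varnames covers a function's local variables as well as its parameters, so a keyword that happens to share a name with a local inside DistributedOptimizer would be mis-routed to the wrapper; slicing with co_varnames[:__code__.co_argcount] would restrict the check to actual parameters.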