Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge of PR #1637
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 260754631
  • Loading branch information
vinhngx authored and copybara-github committed Jul 30, 2019
1 parent d13c331 commit 5bfe69a
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions tensor2tensor/utils/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

import os
import numpy as np
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import adafactor as adafactor_lib
from tensor2tensor.utils import misc_utils
Expand All @@ -40,7 +41,12 @@ def _mixed_precision_is_enabled(hparams):
return activation_dtype == tf.float16 and weight_dtype == tf.float32


def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
def optimize(loss,
learning_rate,
hparams,
use_tpu=False,
variables=None,
gpu_auto_mixed_precision=False):
"""Minimize loss."""
loss = weight_decay_and_noise(loss, hparams, learning_rate)
loss = tf.identity(loss, name="total_loss")
Expand All @@ -65,6 +71,18 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
if use_tpu:
opt = tf.contrib.tpu.CrossShardOptimizer(opt)
if gpu_auto_mixed_precision or os.environ.get(
"TF_ENABLE_AUTO_MIXED_PRECISION", "0") == "1":
if use_tpu:
raise RuntimeError("GPU auto mixed precision cannot be used with TPU")
elif _mixed_precision_is_enabled(hparams):
raise RuntimeError(
"GPU auto mixed precision cannot be used with manual mixed precision")
else:
setattr(opt, "_use_locking", "True")
setattr(opt, "_name", "ConditionalOptimizer")
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

opt_summaries = []
if common_layers.should_generate_summaries():
tf.summary.scalar("learning_rate", learning_rate)
Expand Down

0 comments on commit 5bfe69a

Please sign in to comment.