Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Adding automatic mixed precision support #1637

Merged
merged 6 commits into from
Jul 30, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion tensor2tensor/utils/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
import numpy as np
import os

from tensor2tensor.layers import common_layers
from tensor2tensor.utils import adafactor as adafactor_lib
Expand All @@ -40,7 +41,7 @@ def _mixed_precision_is_enabled(hparams):
return activation_dtype == tf.float16 and weight_dtype == tf.float32


def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None, gpu_auto_mixed_precision=False):
"""Minimize loss."""
loss = weight_decay_and_noise(loss, hparams, learning_rate)
loss = tf.identity(loss, name="total_loss")
Expand All @@ -65,6 +66,16 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
if use_tpu:
opt = tf.contrib.tpu.CrossShardOptimizer(opt)
if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1' or gpu_auto_mixed_precision:
if use_tpu:
raise(RuntimeError("GPU auto mixed precision cannot be used with TPU"))
elif _mixed_precision_is_enabled(hparams):
raise(RuntimeError("GPU auto mixed precision cannot be used with manual mixed precision"))
else:
setattr(opt, '_use_locking', 'True')
setattr(opt, '_name', 'ConditionalOptimizer')
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

opt_summaries = []
if common_layers.should_generate_summaries():
tf.summary.scalar("learning_rate", learning_rate)
Expand Down