Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge of PR #1645
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 260743903
  • Loading branch information
dong-s authored and copybara-github committed Jul 30, 2019
1 parent bba231f commit 4e0daf5
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tensor2tensor/layers/common_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4):
memory inputs to attention.
The use of relative position is possible because sin(a+b) and cos(a+b) can be
expressed in terms of b, sin(a) and cos(a).
expressed in terms of b, sin(a) and cos(a).
x is a Tensor with n "positional" dimensions, e.g. one dimension for a
sequence or two dimensions for an image
Expand Down
13 changes: 1 addition & 12 deletions tensor2tensor/utils/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function
import numpy as np
import os

from tensor2tensor.layers import common_layers
from tensor2tensor.utils import adafactor as adafactor_lib
Expand All @@ -41,7 +40,7 @@ def _mixed_precision_is_enabled(hparams):
return activation_dtype == tf.float16 and weight_dtype == tf.float32


def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None, gpu_auto_mixed_precision=False):
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
"""Minimize loss."""
loss = weight_decay_and_noise(loss, hparams, learning_rate)
loss = tf.identity(loss, name="total_loss")
Expand All @@ -66,16 +65,6 @@ def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None, gpu_au
opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
if use_tpu:
opt = tf.contrib.tpu.CrossShardOptimizer(opt)
if os.environ.get('TF_ENABLE_AUTO_MIXED_PRECISION', default='0') == '1' or gpu_auto_mixed_precision:
if use_tpu:
raise(RuntimeError("GPU auto mixed precision cannot be used with TPU"))
elif _mixed_precision_is_enabled(hparams):
raise(RuntimeError("GPU auto mixed precision cannot be used with manual mixed precision"))
else:
setattr(opt, '_use_locking', 'True')
setattr(opt, '_name', 'ConditionalOptimizer')
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

opt_summaries = []
if common_layers.should_generate_summaries():
tf.summary.scalar("learning_rate", learning_rate)
Expand Down
6 changes: 3 additions & 3 deletions tensor2tensor/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def encode(self, input_str):
def decode(self, integers):
"""List of ints to str."""
integers = list(np.squeeze(integers))
return self.encoders['targets'].decode(integers)
return self.encoders["targets"].decode(integers)

def encode_list(self, integers):
"""List of ints to list of str."""
integers = list(np.squeeze(integers))
return self.encoders['inputs'].decode_list(integers)
return self.encoders["inputs"].decode_list(integers)

def decode_list(self, integers):
"""List of ints to list of str."""
integers = list(np.squeeze(integers))
return self.encoders['targets'].decode_list(integers)
return self.encoders["targets"].decode_list(integers)

def get_vis_data_from_string(self, sess, input_string):
"""Constructs the data needed for visualizing attentions.
Expand Down

0 comments on commit 4e0daf5

Please sign in to comment.