
Refine automatic mixed precision support via hyper param #1681

Merged: 7 commits into tensorflow:master on Aug 30, 2019

Conversation

@vinhngx (Contributor) commented Aug 28, 2019

In continuation of #1637 and in response to @afrozenator's comments in #1680.

In this PR, we reorganize automatic mixed precision training support to provide a cleaner implementation and an easier interface via hyperparameters.

In particular, GPU automatic mixed precision training can now be enabled for all tensor2tensor models by setting a flag (and a correspondingly named hyperparameter), gpu_automatic_mixed_precision. For example:

Transformer:

PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_big
DATA_DIR=/data/translate_ende_wmt32k
TRAIN_DIR=/tmp/$MODEL-$HPARAMS

t2t-trainer \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --train_steps=100000 \
  --eval_steps=1000 \
  --gpu_automatic_mixed_precision=True

ResNet:

PROBLEM=image_imagenet224
MODEL=resnet
HPARAMS=resnet_50
DATA_DIR=/data/ImageNet
TRAIN_DIR=/tmp/$HPARAMS

t2t-trainer \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --hparams='batch_size=256' \
  --worker_gpu=8 \
  --gpu_automatic_mixed_precision=True

This replaces the previous approaches: setting the environment variable TF_ENABLE_AUTO_MIXED_PRECISION, which is non-programmatic, or passing the flag gpu_auto_mixed_precision directly to the optimizer, which would require modifying individual models to call the optimizer with the mixed precision option.
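Because the switch is now an hparam, it can also be flipped programmatically when building hparams in Python. A minimal sketch, assuming the transformer_big hparams set used above:

# Minimal sketch: enable AMP via the new hparam instead of exporting
# TF_ENABLE_AUTO_MIXED_PRECISION in the environment.
from tensor2tensor.models import transformer

hparams = transformer.transformer_big()
hparams.gpu_automatic_mixed_precision = True  # hparam added in this PR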

@googlebot added the cla: yes (PR author has signed CLA) label on Aug 28, 2019
@@ -71,8 +69,7 @@ def optimize(loss,
   opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
   if use_tpu:
     opt = tf.contrib.tpu.CrossShardOptimizer(opt)
-  if gpu_auto_mixed_precision or os.environ.get(
-      "TF_ENABLE_AUTO_MIXED_PRECISION", "0") == "1":
+  if hparams.gpu_automatic_mixed_precision:
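For context, the branch guarded by the new hparam wraps the optimizer with TensorFlow's AMP graph rewrite, roughly as in the sketch below. _maybe_enable_amp is a hypothetical helper, the API shown is the TF 1.14 experimental one, and the merged code adds further checks:

import tensorflow as tf

def _maybe_enable_amp(opt, hparams, use_tpu=False):
  """Wraps `opt` with the AMP graph rewrite when the hparam is on."""
  if hparams.gpu_automatic_mixed_precision:
    if use_tpu:
      # The AMP graph rewrite targets GPUs and is not meant for TPU training.
      raise RuntimeError("GPU automatic mixed precision cannot be used with TPU")
    # The rewrite casts eligible ops to float16 and adds automatic loss scaling.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
  return opt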
Contributor:

get(hparams, "gpu_automatic_mixed_precision", False) is preferable -- since people may pass an hparams set that doesn't have this param -- for example in tests, etc.
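For illustration, the defensive access being suggested might be expressed like this (a sketch; the exact accessor in the merged code may differ):

# Falls back to False when the passed-in hparams set predates the new
# param, e.g. hand-constructed hparams in tests.
amp_enabled = getattr(hparams, "gpu_automatic_mixed_precision", False)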

Contributor Author:

Good one. I fixed this.

-      memory_height=1
+      memory_height=1,
+      # Whether to use GPU automatic mixed precision (via graph rewrite)
+      gpu_automatic_mixed_precision=False
Contributor:

This is good, but as in your earlier PR, can you set this to true based on a flag?

i.e. after we make the hparams in t2t_trainer, flip this on based on your flag

Contributor Author:

Just added the flag again to the trainer and turned the hparam on accordingly.
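For illustration, that flag-to-hparam wiring might look roughly like the sketch below. The flag definition and create_hparams helper here are assumptions, not the exact merged code:

import tensorflow as tf
from tensor2tensor.utils import trainer_lib

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_bool("gpu_automatic_mixed_precision", False,
                  "Whether to enable GPU automatic mixed precision training.")

def create_hparams():
  # Build hparams as usual from --hparams_set and --hparams overrides ...
  hparams = trainer_lib.create_hparams(FLAGS.hparams_set, FLAGS.hparams)
  # ... then flip the new hparam on when the command-line flag is set.
  if FLAGS.gpu_automatic_mixed_precision:
    hparams.gpu_automatic_mixed_precision = True
  return hparams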

Contributor Author:

done

@afrozenator (Contributor) left a comment:

Left a few comments -- thanks for the changes

@vinhngx (Contributor Author) commented Aug 30, 2019

Thanks for the feedback, @afrozenator. Let me know if the latest revision works.

@afrozenator (Contributor) commented:

Thanks a lot @vinhngx for contributing this in the first place and now making it better!

Will merge it in shortly.

@vinhngx (Contributor Author) commented Aug 30, 2019

Great, thanks. I'm closing #1680 then.

@afrozenator merged commit d973bc8 into tensorflow:master on Aug 30, 2019
tensorflow-copybara pushed a commit that referenced this pull request on Aug 30, 2019
PiperOrigin-RevId: 266390503