From 3e1788b7b96df719bfcc701fc4134c6ed3471442 Mon Sep 17 00:00:00 2001 From: dathudeptrai Date: Thu, 19 Nov 2020 14:00:41 +0700 Subject: [PATCH 1/4] =?UTF-8?q?=F0=9F=A4=98=20Support=20Multi-GPU=20gradie?= =?UTF-8?q?nt=20Accumulate=20for=20trainer.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tensorflow_tts/optimizers/__init__.py | 1 + .../optimizers/gradient_accumulate.py | 89 ++++++ tensorflow_tts/trainers/base_trainer.py | 276 +++++++++++++----- 3 files changed, 294 insertions(+), 72 deletions(-) create mode 100644 tensorflow_tts/optimizers/gradient_accumulate.py diff --git a/tensorflow_tts/optimizers/__init__.py b/tensorflow_tts/optimizers/__init__.py index ded1f4b0..a490ed67 100755 --- a/tensorflow_tts/optimizers/__init__.py +++ b/tensorflow_tts/optimizers/__init__.py @@ -1 +1,2 @@ from tensorflow_tts.optimizers.adamweightdecay import AdamWeightDecay, WarmUp +from tensorflow_tts.optimizers.gradient_accumulate import GradientAccumulator diff --git a/tensorflow_tts/optimizers/gradient_accumulate.py b/tensorflow_tts/optimizers/gradient_accumulate.py new file mode 100644 index 00000000..bbbd939b --- /dev/null +++ b/tensorflow_tts/optimizers/gradient_accumulate.py @@ -0,0 +1,89 @@ +"""Gradient Accummlate for training TF2 custom training loop. +Copy from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py. +""" + + +import re + +import tensorflow as tf + + +class GradientAccumulator(object): + """Gradient accumulation utility. + When used with a distribution strategy, the accumulator should be called in a + replica context. Gradients will be accumulated locally on each replica and + without synchronization. Users should then call ``.gradients``, scale the + gradients if required, and pass the result to ``apply_gradients``. + """ + + # We use the ON_READ synchronization policy so that no synchronization is + # performed on assignment. To get the value, we call .value() which returns the + # value on the current replica without synchronization. + + def __init__(self): + """Initializes the accumulator.""" + self._gradients = [] + self._accum_steps = None + + @property + def step(self): + """Number of accumulated steps.""" + if self._accum_steps is None: + self._accum_steps = tf.Variable( + tf.constant(0, dtype=tf.int64), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + + return self._accum_steps.value() + + @property + def gradients(self): + """The accumulated gradients on the current replica.""" + if not self._gradients: + raise ValueError( + "The accumulator should be called first to initialize the gradients" + ) + return list( + gradient.value() if gradient is not None else gradient + for gradient in self._gradients + ) + + def __call__(self, gradients): + """Accumulates :obj:`gradients` on the current replica.""" + if not self._gradients: + _ = self.step # Create the step variable. + self._gradients.extend( + [ + tf.Variable( + tf.zeros_like(gradient), + trainable=False, + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + if gradient is not None + else gradient + for gradient in gradients + ] + ) + if len(gradients) != len(self._gradients): + raise ValueError( + "Expected %s gradients, but got %d" + % (len(self._gradients), len(gradients)) + ) + + for accum_gradient, gradient in zip(self._gradients, gradients): + if accum_gradient is not None and gradient is not None: + accum_gradient.assign_add(gradient, read_value=False) + + self._accum_steps.assign_add(1) + + def reset(self): + """Resets the accumulated gradients on the current replica.""" + if not self._gradients: + return + self._accum_steps.assign(0) + for gradient in self._gradients: + if gradient is not None: + gradient.assign(tf.zeros_like(gradient), read_value=False) diff --git a/tensorflow_tts/trainers/base_trainer.py b/tensorflow_tts/trainers/base_trainer.py index 0996cd65..3920f810 100755 --- a/tensorflow_tts/trainers/base_trainer.py +++ b/tensorflow_tts/trainers/base_trainer.py @@ -21,6 +21,8 @@ import tensorflow as tf from tqdm import tqdm +from tensorflow_tts.optimizers import GradientAccumulator + class BasedTrainer(metaclass=abc.ABCMeta): """Customized trainer module for all models.""" @@ -206,6 +208,12 @@ def __init__( self._is_discriminator_mixed_precision = is_discriminator_mixed_precision self._strategy = strategy self._already_apply_input_signature = False + self._generator_gradient_accumulator = GradientAccumulator() + self._discriminator_gradient_accumulator = GradientAccumulator() + self._generator_gradient_accumulator.reset() + self._discriminator_gradient_accumulator.reset() + + def init_train_eval_metrics(self, list_metrics_name): with self._strategy.scope(): @@ -333,83 +341,164 @@ def compute_per_example_discriminator_losses(self, batch, gen_outputs): dict_metrics_losses = {} return per_example_losses, dict_metrics_losses - def _one_step_forward_per_replica(self, batch): - per_replica_gen_losses = 0.0 - per_replica_dis_losses = 0.0 - - # one step generator. - with tf.GradientTape() as g_tape: - outputs = self._generator(**batch, training=True) - ( - per_example_losses, - dict_metrics_losses, - ) = self.compute_per_example_generator_losses(batch, outputs) + def _calculate_generator_gradient_per_batch(self, batch): + outputs = self._generator(**batch, training=True) + ( + per_example_losses, + dict_metrics_losses, + ) = self.compute_per_example_generator_losses(batch, outputs) + per_replica_gen_losses = tf.nn.compute_average_loss( + per_example_losses, + global_batch_size=self.config["batch_size"] + * self.get_n_gpus() + * self.config["gradient_accumulation_steps"], + ) - per_replica_gen_losses = tf.nn.compute_average_loss( - per_example_losses, - global_batch_size=self.config["batch_size"] * self.get_n_gpus(), + if self._is_generator_mixed_precision: + scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss( + per_replica_gen_losses ) - if self._is_generator_mixed_precision: - scaled_per_replica_gen_losses = self._gen_optimizer.get_scaled_loss( - per_replica_gen_losses - ) - if self._is_generator_mixed_precision: - scaled_gradients = g_tape.gradient( + scaled_gradients = tf.gradients( scaled_per_replica_gen_losses, self._generator.trainable_variables ) gradients = self._gen_optimizer.get_unscaled_gradients(scaled_gradients) else: - gradients = g_tape.gradient( + gradients = tf.gradients( per_replica_gen_losses, self._generator.trainable_variables ) - self._gen_optimizer.apply_gradients( - zip(gradients, self._generator.trainable_variables) + # gradient accumulate for generator here + if self.config["gradient_accumulation_steps"] > 1: + self._generator_gradient_accumulator(gradients) + + # accumulate loss into metrics + self.update_train_metrics(dict_metrics_losses) + + if self.config["gradient_accumulation_steps"] == 1: + return gradients, per_replica_gen_losses + else: + return per_replica_gen_losses + + def _calculate_discriminator_gradient_per_batch(self, batch): + ( + per_example_losses, + dict_metrics_losses, + ) = self.compute_per_example_discriminator_losses( + batch, self._generator(**batch, training=True) + ) + + per_replica_dis_losses = tf.nn.compute_average_loss( + per_example_losses, + global_batch_size=self.config["batch_size"] + * self.get_n_gpus() + * self.config["gradient_accumulation_steps"], ) + if self._is_discriminator_mixed_precision: + scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss( + per_replica_dis_losses + ) + + if self._is_discriminator_mixed_precision: + scaled_gradients = tf.gradients( + scaled_per_replica_dis_losses, + self._discriminator.trainable_variables, + ) + gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients) + else: + gradients = tf.gradients( + per_replica_dis_losses, self._discriminator.trainable_variables + ) + # accumulate loss into metrics self.update_train_metrics(dict_metrics_losses) - # one step discriminator - # recompute y_hat after 1 step generator for discriminator training. - if self.steps >= self.config["discriminator_train_start_steps"]: - with tf.GradientTape() as d_tape: - ( - per_example_losses, - dict_metrics_losses, - ) = self.compute_per_example_discriminator_losses( - batch, self._generator(**batch) - ) + # gradient accumulate for discriminator here + if self.config["gradient_accumulation_steps"] > 1: + self._discriminator_gradient_accumulator(gradients) + + if self.config["gradient_accumulation_steps"] == 1: + return gradients, per_replica_dis_losses + else: + return per_replica_dis_losses + + + def _one_step_forward_per_replica(self, batch): + per_replica_gen_losses = 0.0 + per_replica_dis_losses = 0.0 - per_replica_dis_losses = tf.nn.compute_average_loss( - per_example_losses, - global_batch_size=self.config["batch_size"] * self.get_n_gpus(), + if self.config["gradient_accumulation_steps"] == 1: + ( + gradients, + per_replica_gen_losses, + ) = self._calculate_generator_gradient_per_batch(batch) + self._gen_optimizer.apply_gradients( + zip(gradients, self._generator.trainable_variables) + ) + else: + # gradient acummulation here. + for i in tf.range(self.config["gradient_accumulation_steps"]): + reduced_batch = { + k: v[ + i + * self.config["batch_size"] : (i + 1) + * self.config["batch_size"] + ] + for k, v in batch.items() + } + + # run 1 step accumulate + reduced_batch_losses = self._calculate_generator_gradient_per_batch( + reduced_batch ) - if self._is_discriminator_mixed_precision: - scaled_per_replica_dis_losses = self._dis_optimizer.get_scaled_loss( - per_replica_dis_losses - ) + # sum per_replica_losses + per_replica_gen_losses += reduced_batch_losses + + gradients = self._generator_gradient_accumulator.gradients + self._gen_optimizer.apply_gradients( + zip(gradients, self._generator.trainable_variables) + ) + self._generator_gradient_accumulator.reset() - if self._is_discriminator_mixed_precision: - scaled_gradients = d_tape.gradient( - scaled_per_replica_dis_losses, - self._discriminator.trainable_variables, + # one step discriminator + # recompute y_hat after 1 step generator for discriminator training. + if self.steps >= self.config["discriminator_train_start_steps"]: + if self.config["gradient_accumulation_steps"] == 1: + ( + gradients, + per_replica_dis_losses, + ) = self._calculate_discriminator_gradient_per_batch(batch) + self._dis_optimizer.apply_gradients( + zip(gradients, self._discriminator.trainable_variables) ) - gradients = self._dis_optimizer.get_unscaled_gradients(scaled_gradients) else: - gradients = d_tape.gradient( - per_replica_dis_losses, self._discriminator.trainable_variables - ) + # gradient acummulation here. + for i in tf.range(self.config["gradient_accumulation_steps"]): + reduced_batch = { + k: v[ + i + * self.config["batch_size"] : (i + 1) + * self.config["batch_size"] + ] + for k, v in batch.items() + } + + # run 1 step accumulate + reduced_batch_losses = ( + self._calculate_discriminator_gradient_per_batch(reduced_batch) + ) - self._dis_optimizer.apply_gradients( - zip(gradients, self._discriminator.trainable_variables) - ) + # sum per_replica_losses + per_replica_dis_losses += reduced_batch_losses - # accumulate loss into metrics - self.update_train_metrics(dict_metrics_losses) + gradients = self._discriminator_gradient_accumulator.gradients + self._dis_optimizer.apply_gradients( + zip(gradients, self._discriminator.trainable_variables) + ) + self._discriminator_gradient_accumulator.reset() return per_replica_gen_losses + per_replica_dis_losses @@ -613,6 +702,11 @@ def __init__( # check if we already apply input_signature for train_step. self._already_apply_input_signature = False + # create gradient accumulator + self._gradient_accumulator = GradientAccumulator() + self._gradient_accumulator.reset() + + def init_train_eval_metrics(self, list_metrics_name): with self._strategy.scope(): super().init_train_eval_metrics(list_metrics_name) @@ -698,39 +792,77 @@ def _one_step_forward(self, batch): tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None ) - def _one_step_forward_per_replica(self, batch): - with tf.GradientTape() as tape: - outputs = self._model(**batch, training=True) - per_example_losses, dict_metrics_losses = self.compute_per_example_losses( - batch, outputs - ) - per_replica_losses = tf.nn.compute_average_loss( - per_example_losses, - global_batch_size=self.config["batch_size"] * self.get_n_gpus(), - ) + def _calculate_gradient_per_batch(self, batch): + outputs = self._model(**batch, training=True) + per_example_losses, dict_metrics_losses = self.compute_per_example_losses( + batch, outputs + ) + per_replica_losses = tf.nn.compute_average_loss( + per_example_losses, + global_batch_size=self.config["batch_size"] + * self.get_n_gpus() + * self.config["gradient_accumulation_steps"], + ) - if self._is_mixed_precision: - scaled_per_replica_losses = self._optimizer.get_scaled_loss( - per_replica_losses - ) + if self._is_mixed_precision: + scaled_per_replica_losses = self._optimizer.get_scaled_loss( + per_replica_losses + ) if self._is_mixed_precision: - scaled_gradients = tape.gradient( + scaled_gradients = tf.gradients( scaled_per_replica_losses, self._trainable_variables ) gradients = self._optimizer.get_unscaled_gradients(scaled_gradients) else: - gradients = tape.gradient( - per_replica_losses, self._trainable_variables - ) + gradients = tf.gradients(per_replica_losses, self._trainable_variables) - self._optimizer.apply_gradients(zip(gradients, self._trainable_variables), 1.0) + # gradient accumulate here + if self.config["gradient_accumulation_steps"] > 1: + self._gradient_accumulator(gradients) # accumulate loss into metrics self.update_train_metrics(dict_metrics_losses) + if self.config["gradient_accumulation_steps"] == 1: + return gradients, per_replica_losses + else: + return per_replica_losses + + def _one_step_forward_per_replica(self, batch): + if self.config["gradient_accumulation_steps"] == 1: + gradients, per_replica_losses = self._calculate_gradient_per_batch(batch) + self._optimizer.apply_gradients( + zip(gradients, self._trainable_variables) + ) + else: + # gradient acummulation here. + per_replica_losses = 0.0 + for i in tf.range(self.config["gradient_accumulation_steps"]): + reduced_batch = { + k: v[ + i + * self.config["batch_size"] : (i + 1) + * self.config["batch_size"] + ] + for k, v in batch.items() + } + + # run 1 step accumulate + reduced_batch_losses = self._calculate_gradient_per_batch(reduced_batch) + + # sum per_replica_losses + per_replica_losses += reduced_batch_losses + + gradients = self._gradient_accumulator.gradients + self._optimizer.apply_gradients( + zip(gradients, self._trainable_variables) + ) + self._gradient_accumulator.reset() + return per_replica_losses + @abc.abstractmethod def compute_per_example_losses(self, batch, outputs): """Compute per example losses and return dict_metrics_losses From 1ca581f3ad5355a72a4d6056248c7619bb080e7d Mon Sep 17 00:00:00 2001 From: dathudeptrai Date: Thu, 19 Nov 2020 14:14:54 +0700 Subject: [PATCH 2/4] =?UTF-8?q?=F0=9F=9A=B2=20=20Added=20gradient=5Faccumu?= =?UTF-8?q?late=5Fsteps=20for=20all=20config,=20clearer=20note=20for=20bat?= =?UTF-8?q?ch=5Fsize=20parameter.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/fastspeech/conf/fastspeech.v1.yaml | 5 +++-- examples/fastspeech/conf/fastspeech.v3.yaml | 5 +++-- examples/fastspeech2/conf/fastspeech2.baker.v2.yaml | 5 +++-- examples/fastspeech2/conf/fastspeech2.kss.v1.yaml | 5 +++-- examples/fastspeech2/conf/fastspeech2.kss.v2.yaml | 5 +++-- examples/fastspeech2/conf/fastspeech2.v1.yaml | 5 +++-- examples/fastspeech2/conf/fastspeech2.v2.yaml | 5 +++-- examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml | 5 +++-- examples/melgan.stft/conf/melgan.stft.v1.yaml | 4 ++-- examples/melgan/conf/melgan.v1.yaml | 3 ++- .../multiband_melgan/conf/multiband_melgan.baker.v1.yaml | 3 ++- examples/multiband_melgan/conf/multiband_melgan.v1.yaml | 3 ++- examples/multiband_pwgan/conf/multiband_pwgan.v1.yaml | 4 ++-- examples/multiband_pwgan/conf/multiband_pwgan.v1ft.yaml | 3 ++- examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml | 5 +++-- examples/tacotron2/conf/tacotron2.baker.v1.yaml | 5 +++-- examples/tacotron2/conf/tacotron2.kss.v1.yaml | 5 +++-- examples/tacotron2/conf/tacotron2.v1.yaml | 5 +++-- 18 files changed, 48 insertions(+), 32 deletions(-) diff --git a/examples/fastspeech/conf/fastspeech.v1.yaml b/examples/fastspeech/conf/fastspeech.v1.yaml index 5e7e6455..653a8512 100755 --- a/examples/fastspeech/conf/fastspeech.v1.yaml +++ b/examples/fastspeech/conf/fastspeech.v1.yaml @@ -46,7 +46,7 @@ fastspeech_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with asuming that gradient_accumulation_steps is 1 remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -60,7 +60,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech/conf/fastspeech.v3.yaml b/examples/fastspeech/conf/fastspeech.v3.yaml index 9cc5f61e..34920aa4 100755 --- a/examples/fastspeech/conf/fastspeech.v3.yaml +++ b/examples/fastspeech/conf/fastspeech.v3.yaml @@ -46,7 +46,7 @@ fastspeech_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -60,7 +60,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2/conf/fastspeech2.baker.v2.yaml b/examples/fastspeech2/conf/fastspeech2.baker.v2.yaml index 67f23dbf..eb2d92a3 100644 --- a/examples/fastspeech2/conf/fastspeech2.baker.v2.yaml +++ b/examples/fastspeech2/conf/fastspeech2.baker.v2.yaml @@ -48,7 +48,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -62,7 +62,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2/conf/fastspeech2.kss.v1.yaml b/examples/fastspeech2/conf/fastspeech2.kss.v1.yaml index 7a9f763a..f04676de 100755 --- a/examples/fastspeech2/conf/fastspeech2.kss.v1.yaml +++ b/examples/fastspeech2/conf/fastspeech2.kss.v1.yaml @@ -47,7 +47,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -61,7 +61,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2/conf/fastspeech2.kss.v2.yaml b/examples/fastspeech2/conf/fastspeech2.kss.v2.yaml index 5ab6cd41..0391bebb 100755 --- a/examples/fastspeech2/conf/fastspeech2.kss.v2.yaml +++ b/examples/fastspeech2/conf/fastspeech2.kss.v2.yaml @@ -48,7 +48,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -62,7 +62,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2/conf/fastspeech2.v1.yaml b/examples/fastspeech2/conf/fastspeech2.v1.yaml index fa46f4e7..76bf57ca 100755 --- a/examples/fastspeech2/conf/fastspeech2.v1.yaml +++ b/examples/fastspeech2/conf/fastspeech2.v1.yaml @@ -46,7 +46,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -60,7 +60,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2/conf/fastspeech2.v2.yaml b/examples/fastspeech2/conf/fastspeech2.v2.yaml index 7123d20a..3e373c80 100755 --- a/examples/fastspeech2/conf/fastspeech2.v2.yaml +++ b/examples/fastspeech2/conf/fastspeech2.v2.yaml @@ -47,7 +47,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1 remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 @@ -61,7 +61,8 @@ optimizer_params: decay_steps: 150000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml b/examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml index c9f23bd3..347071d1 100755 --- a/examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml +++ b/examples/fastspeech2_libritts/conf/fastspeech2libritts.yaml @@ -46,7 +46,7 @@ fastspeech2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 32 # Batch size. +batch_size: 32 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 48 # remove all targets has mel_length <= 32 @@ -60,7 +60,8 @@ optimizer_params: decay_steps: 120000 # < train_max_steps is recommend. warmup_proportion: 0.02 weight_decay: 0.001 - + +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|encoder|decoder' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/melgan.stft/conf/melgan.stft.v1.yaml b/examples/melgan.stft/conf/melgan.stft.v1.yaml index 6706910a..4660e586 100755 --- a/examples/melgan.stft/conf/melgan.stft.v1.yaml +++ b/examples/melgan.stft/conf/melgan.stft.v1.yaml @@ -63,7 +63,7 @@ lambda_adv: 4.0 ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -86,7 +86,7 @@ discriminator_optimizer_params: boundaries: [0] # after resume and start training discriminator, global steps is 100k, but local discriminator step is 0 values: [0.0001, 0.0001] # learning rate each interval. - +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/melgan/conf/melgan.v1.yaml b/examples/melgan/conf/melgan.v1.yaml index ef192677..70294c73 100755 --- a/examples/melgan/conf/melgan.v1.yaml +++ b/examples/melgan/conf/melgan.v1.yaml @@ -53,7 +53,7 @@ lambda_feat_match: 10.0 ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 16 # Batch size. +batch_size: 16 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -73,6 +73,7 @@ discriminator_optimizer_params: beta_1: 0.5 beta_2: 0.9 +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/multiband_melgan/conf/multiband_melgan.baker.v1.yaml b/examples/multiband_melgan/conf/multiband_melgan.baker.v1.yaml index 4cbc9618..378cc72b 100644 --- a/examples/multiband_melgan/conf/multiband_melgan.baker.v1.yaml +++ b/examples/multiband_melgan/conf/multiband_melgan.baker.v1.yaml @@ -67,7 +67,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 64 # Batch size. +batch_size: 64 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. batch_max_steps: 9600 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 48000 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -91,6 +91,7 @@ discriminator_optimizer_params: values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] amsgrad: false +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/multiband_melgan/conf/multiband_melgan.v1.yaml b/examples/multiband_melgan/conf/multiband_melgan.v1.yaml index 62b656dd..628fc395 100755 --- a/examples/multiband_melgan/conf/multiband_melgan.v1.yaml +++ b/examples/multiband_melgan/conf/multiband_melgan.v1.yaml @@ -67,7 +67,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 64 # Batch size. +batch_size: 64 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 8192 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -91,6 +91,7 @@ discriminator_optimizer_params: values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001] amsgrad: false +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/multiband_pwgan/conf/multiband_pwgan.v1.yaml b/examples/multiband_pwgan/conf/multiband_pwgan.v1.yaml index 237a3c02..ff5f1378 100644 --- a/examples/multiband_pwgan/conf/multiband_pwgan.v1.yaml +++ b/examples/multiband_pwgan/conf/multiband_pwgan.v1.yaml @@ -61,7 +61,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 64 # Batch size. +batch_size: 64 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -85,7 +85,7 @@ discriminator_optimizer_params: decay_steps: 200000 decay_rate: 0.5 - +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/multiband_pwgan/conf/multiband_pwgan.v1ft.yaml b/examples/multiband_pwgan/conf/multiband_pwgan.v1ft.yaml index 79184265..39354e26 100644 --- a/examples/multiband_pwgan/conf/multiband_pwgan.v1ft.yaml +++ b/examples/multiband_pwgan/conf/multiband_pwgan.v1ft.yaml @@ -65,7 +65,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 64 # Batch size. +batch_size: 64 # Batch size for each GPU with asuming that gradient_accumulation_steps == 1. batch_max_steps: 8192 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -90,6 +90,7 @@ discriminator_optimizer_params: decay_steps: 70000 decay_rate: 0.5 +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml b/examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml index f9eef6a8..684086d0 100644 --- a/examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml +++ b/examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml @@ -65,8 +65,8 @@ lambda_adv: 4.0 # Loss balancing coefficient. ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 6 # Batch size. -batch_max_steps: 25600 # Length of each audio in batch for training. Make sure dividable by hop_size. +batch_size: 6 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. +batch_max_steps: 25600 # Length of each audio in batch for training. Make sure dividable by hop_size. batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. @@ -90,6 +90,7 @@ discriminator_optimizer_params: decay_steps: 200000 decay_rate: 0.5 +gradient_accumulation_steps: 1 ########################################################### # INTERVAL SETTING # ########################################################### diff --git a/examples/tacotron2/conf/tacotron2.baker.v1.yaml b/examples/tacotron2/conf/tacotron2.baker.v1.yaml index 1fe5f66f..6f069060 100644 --- a/examples/tacotron2/conf/tacotron2.baker.v1.yaml +++ b/examples/tacotron2/conf/tacotron2.baker.v1.yaml @@ -47,13 +47,13 @@ tacotron2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 32 # Batch size. +batch_size: 32 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 is_shuffle: true # shuffle dataset after each epoch. use_fixed_shapes: true # use_fixed_shapes for training (2x speed-up) - # refer (https://github.com/dathudeptrai/TensorflowTTS/issues/34#issuecomment-642309118) + # refer (https://github.com/tensorspeech/TensorflowTTS/issues/34#issuecomment-642309118) ########################################################### # OPTIMIZER & SCHEDULER SETTING # @@ -65,6 +65,7 @@ optimizer_params: warmup_proportion: 0.02 weight_decay: 0.001 +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|decoder_cell' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/tacotron2/conf/tacotron2.kss.v1.yaml b/examples/tacotron2/conf/tacotron2.kss.v1.yaml index 082fa332..dc553fbe 100755 --- a/examples/tacotron2/conf/tacotron2.kss.v1.yaml +++ b/examples/tacotron2/conf/tacotron2.kss.v1.yaml @@ -47,12 +47,12 @@ tacotron2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 32 # Batch size. +batch_size: 32 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 is_shuffle: true # shuffle dataset after each epoch. -use_fixed_shapes: false # use_fixed_shapes for training (2x speed-up) +use_fixed_shapes: true # use_fixed_shapes for training (2x speed-up) # refer (https://github.com/dathudeptrai/TensorflowTTS/issues/34#issuecomment-642309118) ########################################################### @@ -65,6 +65,7 @@ optimizer_params: warmup_proportion: 0.02 weight_decay: 0.001 +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|decoder_cell' ) # must separate by |. if var_train_expr is null then we # training all variable diff --git a/examples/tacotron2/conf/tacotron2.v1.yaml b/examples/tacotron2/conf/tacotron2.v1.yaml index dc3aa942..862b9c90 100755 --- a/examples/tacotron2/conf/tacotron2.v1.yaml +++ b/examples/tacotron2/conf/tacotron2.v1.yaml @@ -47,12 +47,12 @@ tacotron2_params: ########################################################### # DATA LOADER SETTING # ########################################################### -batch_size: 32 # Batch size. +batch_size: 32 # Batch size for each GPU with assuming that gradient_accumulation_steps == 1. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. mel_length_threshold: 32 # remove all targets has mel_length <= 32 is_shuffle: true # shuffle dataset after each epoch. -use_fixed_shapes: false # use_fixed_shapes for training (2x speed-up) +use_fixed_shapes: true # use_fixed_shapes for training (2x speed-up) # refer (https://github.com/dathudeptrai/TensorflowTTS/issues/34#issuecomment-642309118) ########################################################### @@ -65,6 +65,7 @@ optimizer_params: warmup_proportion: 0.02 weight_decay: 0.001 +gradient_accumulation_steps: 1 var_train_expr: null # trainable variable expr (eg. 'embeddings|decoder_cell' ) # must separate by |. if var_train_expr is null then we # training all variables. From ac4508af616ddf718b716ca5cd0f8d927b021b9f Mon Sep 17 00:00:00 2001 From: dathudeptrai Date: Thu, 19 Nov 2020 14:21:57 +0700 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=A6=9F=20Update=20global=20batch-size?= =?UTF-8?q?=20for=20train-dataloder.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/fastspeech/train_fastspeech.py | 16 +-- examples/fastspeech2/train_fastspeech2.py | 17 ++-- .../fastspeech2_libritts/train_fastspeech2.py | 97 +++++++++++-------- examples/melgan.stft/train_melgan_stft.py | 23 +++-- examples/melgan/train_melgan.py | 24 ++--- .../train_multiband_melgan.py | 4 +- .../multiband_pwgan/train_multiband_pwgan.py | 4 +- .../train_parallel_wavegan.py | 12 +-- examples/tacotron2/train_tacotron2.py | 15 +-- 9 files changed, 122 insertions(+), 90 deletions(-) diff --git a/examples/fastspeech/train_fastspeech.py b/examples/fastspeech/train_fastspeech.py index 1aacf72e..bce4e1d1 100755 --- a/examples/fastspeech/train_fastspeech.py +++ b/examples/fastspeech/train_fastspeech.py @@ -36,8 +36,7 @@ from tensorflow_tts.models import TFFastSpeech from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp from tensorflow_tts.trainers import Seq2SeqBasedTrainer -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy) +from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy class FastSpeechTrainer(Seq2SeqBasedTrainer): @@ -218,7 +217,7 @@ def main(): default="", type=str, nargs="?", - help='pretrained checkpoint file to load weights from. Auto-skips non-matching layers', + help="pretrained checkpoint file to load weights from. Auto-skips non-matching layers", ) args = parser.parse_args() @@ -302,7 +301,9 @@ def main(): ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = CharactorDurationMelDataset( @@ -335,11 +336,12 @@ def main(): ) fastspeech._build() fastspeech.summary() - + if len(args.pretrained) > 1: fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") - + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( diff --git a/examples/fastspeech2/train_fastspeech2.py b/examples/fastspeech2/train_fastspeech2.py index f4c61440..8d3c9144 100755 --- a/examples/fastspeech2/train_fastspeech2.py +++ b/examples/fastspeech2/train_fastspeech2.py @@ -33,15 +33,13 @@ from tqdm import tqdm import tensorflow_tts -from examples.fastspeech2.fastspeech2_dataset import \ - CharactorDurationF0EnergyMelDataset +from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset from examples.fastspeech.train_fastspeech import FastSpeechTrainer from tensorflow_tts.configs import FastSpeech2Config from tensorflow_tts.models import TFFastSpeech2 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp from tensorflow_tts.trainers import Seq2SeqBasedTrainer -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy) +from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy class FastSpeech2Trainer(Seq2SeqBasedTrainer): @@ -244,9 +242,8 @@ def main(): default="", type=str, nargs="?", - help='pretrained weights .h5 file to load weights from. Auto-skips non-matching layers', + help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers", ) - args = parser.parse_args() @@ -330,7 +327,9 @@ def main(): ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = CharactorDurationF0EnergyMelDataset( @@ -367,7 +366,9 @@ def main(): fastspeech.summary() if len(args.pretrained) > 1: fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( diff --git a/examples/fastspeech2_libritts/train_fastspeech2.py b/examples/fastspeech2_libritts/train_fastspeech2.py index 8643cd10..0f9367ac 100755 --- a/examples/fastspeech2_libritts/train_fastspeech2.py +++ b/examples/fastspeech2_libritts/train_fastspeech2.py @@ -33,22 +33,33 @@ import json import tensorflow_tts -from examples.fastspeech2_libritts.fastspeech2_dataset import \ - CharactorDurationF0EnergyMelDataset +from examples.fastspeech2_libritts.fastspeech2_dataset import ( + CharactorDurationF0EnergyMelDataset, +) from tensorflow_tts.configs import FastSpeech2Config from tensorflow_tts.models import TFFastSpeech2 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp from tensorflow_tts.trainers import Seq2SeqBasedTrainer -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy, TFGriffinLim) +from tensorflow_tts.utils import ( + calculate_2d_loss, + calculate_3d_loss, + return_strategy, + TFGriffinLim, +) class FastSpeech2Trainer(Seq2SeqBasedTrainer): """FastSpeech2 Trainer class based on FastSpeechTrainer.""" def __init__( - self, config, strategy, steps=0, epochs=0, is_mixed_precision=False, stats_path: str = "", - dataset_config: str = "" + self, + config, + strategy, + steps=0, + epochs=0, + is_mixed_precision=False, + stats_path: str = "", + dataset_config: str = "", ): """Initialize trainer. Args: @@ -78,7 +89,9 @@ def __init__( self.use_griffin = config.get("use_griffin", False) self.griffin_lim_tf = None if self.use_griffin: - logging.info(f"Load griff stats from {stats_path} and config from {dataset_config}") + logging.info( + f"Load griff stats from {stats_path} and config from {dataset_config}" + ) self.griff_conf = yaml.load(open(dataset_config), Loader=yaml.Loader) self.prepare_grim(stats_path, self.griff_conf) @@ -160,7 +173,9 @@ def generate_and_save_intermediate_result(self, batch): # check directory if self.use_griffin: - griff_dir_name = os.path.join(self.config["outdir"], f"predictions/{self.steps}_wav") + griff_dir_name = os.path.join( + self.config["outdir"], f"predictions/{self.steps}_wav" + ) if not os.path.exists(griff_dir_name): os.makedirs(griff_dir_name) @@ -171,23 +186,31 @@ def generate_and_save_intermediate_result(self, batch): for idx, (mel_gt, mel_before, mel_after) in enumerate( zip(mel_gts, mels_before, mels_after), 0 ): - - + if self.use_griffin: utt_id = utt_ids[idx] - grif_before = self.griffin_lim_tf(tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32) - grif_after = self.griffin_lim_tf(tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32) - grif_gt = self.griffin_lim_tf(tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32) - self.griffin_lim_tf.save_wav(grif_before, griff_dir_name, f"{utt_id}_before") - self.griffin_lim_tf.save_wav(grif_after, griff_dir_name, f"{utt_id}_after") + grif_before = self.griffin_lim_tf( + tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32 + ) + grif_after = self.griffin_lim_tf( + tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32 + ) + grif_gt = self.griffin_lim_tf( + tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32 + ) + self.griffin_lim_tf.save_wav( + grif_before, griff_dir_name, f"{utt_id}_before" + ) + self.griffin_lim_tf.save_wav( + grif_after, griff_dir_name, f"{utt_id}_after" + ) self.griffin_lim_tf.save_wav(grif_gt, griff_dir_name, f"{utt_id}_gt") - + utt_id = utt_ids[idx] mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy() # [length, 80] mel_before = tf.reshape(mel_before, (-1, 80)).numpy() # [length, 80] mel_after = tf.reshape(mel_after, (-1, 80)).numpy() # [length, 80] - # plit figure and save it figname = os.path.join(dirname, f"{utt_id}.png") fig = plt.figure(figsize=(10, 8)) @@ -229,10 +252,7 @@ def main(): "--use-norm", default=1, type=int, help="usr norm-mels for train or raw." ) parser.add_argument( - "--f0-stat", - default="./dump/stats_f0.npy", - type=str, - help="f0-stat path.", + "--f0-stat", default="./dump/stats_f0.npy", type=str, help="f0-stat path.", ) parser.add_argument( "--energy-stat", @@ -266,26 +286,20 @@ def main(): help="using mixed precision for generator or not.", ) parser.add_argument( - "--dataset_config", - default="preprocess/libritts_preprocess.yaml", - type=str, + "--dataset_config", default="preprocess/libritts_preprocess.yaml", type=str, ) parser.add_argument( - "--dataset_stats", - default="dump/stats.npy", - type=str, + "--dataset_stats", default="dump/stats.npy", type=str, ) parser.add_argument( - "--dataset_mapping", - default="dump/libritts_mapper.npy", - type=str, + "--dataset_mapping", default="dump/libritts_mapper.npy", type=str, ) parser.add_argument( "--pretrained", default="", type=str, nargs="?", - help='pretrained weights .h5 file to load weights from. Auto-skips non-matching layers', + help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers", ) args = parser.parse_args() @@ -362,7 +376,9 @@ def main(): # Check n_speakers matches number of speakers in speakers_map n_speakers = config["fastspeech2_params"]["n_speakers"] - assert n_speakers == len(speakers_map), f"Number of speakers in dataset does not match n_speakers in config" + assert n_speakers == len( + speakers_map + ), f"Number of speakers in dataset does not match n_speakers in config" # define train/valid dataset train_dataset = CharactorDurationF0EnergyMelDataset( @@ -375,11 +391,13 @@ def main(): f0_stat=args.f0_stat, energy_stat=args.energy_stat, mel_length_threshold=mel_length_threshold, - speakers_map=speakers_map + speakers_map=speakers_map, ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = CharactorDurationF0EnergyMelDataset( @@ -392,7 +410,7 @@ def main(): f0_stat=args.f0_stat, energy_stat=args.energy_stat, mel_length_threshold=mel_length_threshold, - speakers_map=speakers_map + speakers_map=speakers_map, ).create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], @@ -407,7 +425,7 @@ def main(): epochs=0, is_mixed_precision=args.mixed_precision, stats_path=args.dataset_stats, - dataset_config=args.dataset_config + dataset_config=args.dataset_config, ) with STRATEGY.scope(): @@ -417,11 +435,12 @@ def main(): ) fastspeech._build() fastspeech.summary() - + if len(args.pretrained) > 1: fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") - + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) # AdamW for fastspeech learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( diff --git a/examples/melgan.stft/train_melgan_stft.py b/examples/melgan.stft/train_melgan_stft.py index 3a0c1d4b..978eb78f 100755 --- a/examples/melgan.stft/train_melgan_stft.py +++ b/examples/melgan.stft/train_melgan_stft.py @@ -36,10 +36,8 @@ from examples.melgan.audio_mel_dataset import AudioMelDataset from examples.melgan.train_melgan import MelganTrainer, collater from tensorflow_tts.losses import TFMultiResolutionSTFT -from tensorflow_tts.models import (TFMelGANGenerator, - TFMelGANMultiScaleDiscriminator) -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy) +from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator +from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy class MultiSTFTMelganTrainer(MelganTrainer): @@ -206,7 +204,7 @@ def main(): default="", type=str, nargs="?", - help='path of .h5 melgan generator to load weights from', + help="path of .h5 melgan generator to load weights from", ) args = parser.parse_args() @@ -295,7 +293,9 @@ def main(): hop_size=tf.constant(config["hop_size"], dtype=tf.int32), ), allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = AudioMelDataset( @@ -336,7 +336,9 @@ def main(): ) discriminator = TFMelGANMultiScaleDiscriminator( - MELGAN_CONFIG.MelGANDiscriminatorConfig(**config["melgan_discriminator_params"]), + MELGAN_CONFIG.MelGANDiscriminatorConfig( + **config["melgan_discriminator_params"] + ), name="melgan_discriminator", ) @@ -344,11 +346,12 @@ def main(): fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32) y_hat = generator(fake_mels) discriminator(y_hat) - + if len(args.pretrained) > 1: generator.load_weights(args.pretrained) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") - + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) generator.summary() discriminator.summary() diff --git a/examples/melgan/train_melgan.py b/examples/melgan/train_melgan.py index 49650d54..ccb7ed3a 100755 --- a/examples/melgan/train_melgan.py +++ b/examples/melgan/train_melgan.py @@ -37,11 +37,9 @@ import tensorflow_tts.configs.melgan as MELGAN_CONFIG from examples.melgan.audio_mel_dataset import AudioMelDataset from tensorflow_tts.losses import TFMelSpectrogram -from tensorflow_tts.models import (TFMelGANGenerator, - TFMelGANMultiScaleDiscriminator) +from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator from tensorflow_tts.trainers import GanBasedTrainer -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy) +from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy class MelganTrainer(GanBasedTrainer): @@ -343,7 +341,7 @@ def main(): default="", type=str, nargs="?", - help='path of .h5 melgan generator to load weights from', + help="path of .h5 melgan generator to load weights from", ) args = parser.parse_args() @@ -432,7 +430,9 @@ def main(): hop_size=tf.constant(config["hop_size"], dtype=tf.int32), ), allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = AudioMelDataset( @@ -473,7 +473,9 @@ def main(): ) discriminator = TFMelGANMultiScaleDiscriminator( - MELGAN_CONFIG.MelGANDiscriminatorConfig(**config["melgan_discriminator_params"]), + MELGAN_CONFIG.MelGANDiscriminatorConfig( + **config["melgan_discriminator_params"] + ), name="melgan_discriminator", ) @@ -481,12 +483,12 @@ def main(): fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32) y_hat = generator(fake_mels) discriminator(y_hat) - + if len(args.pretrained) > 1: generator.load_weights(args.pretrained) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") - - + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) generator.summary() discriminator.summary() diff --git a/examples/multiband_melgan/train_multiband_melgan.py b/examples/multiband_melgan/train_multiband_melgan.py index 9bb1afab..a162db6b 100755 --- a/examples/multiband_melgan/train_multiband_melgan.py +++ b/examples/multiband_melgan/train_multiband_melgan.py @@ -406,7 +406,9 @@ def main(): hop_size=tf.constant(config["hop_size"], dtype=tf.int32), ), allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = AudioMelDataset( diff --git a/examples/multiband_pwgan/train_multiband_pwgan.py b/examples/multiband_pwgan/train_multiband_pwgan.py index 2e448207..a11daded 100644 --- a/examples/multiband_pwgan/train_multiband_pwgan.py +++ b/examples/multiband_pwgan/train_multiband_pwgan.py @@ -420,7 +420,9 @@ def main(): hop_size=tf.constant(config["hop_size"], dtype=tf.int32), ), allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = AudioMelDataset( diff --git a/examples/parallel_wavegan/train_parallel_wavegan.py b/examples/parallel_wavegan/train_parallel_wavegan.py index 4ba60e1b..c4e4bf41 100644 --- a/examples/parallel_wavegan/train_parallel_wavegan.py +++ b/examples/parallel_wavegan/train_parallel_wavegan.py @@ -379,7 +379,9 @@ def main(): hop_size=tf.constant(config["hop_size"], dtype=tf.int32), ), allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = AudioMelDataset( @@ -445,12 +447,8 @@ def main(): config["discriminator_optimizer_params"]["lr_fn"], )(**config["discriminator_optimizer_params"]["lr_params"]) - gen_optimizer = RectifiedAdam( - learning_rate=generator_lr_fn, amsgrad=False - ) - dis_optimizer = RectifiedAdam( - learning_rate=discriminator_lr_fn, amsgrad=False - ) + gen_optimizer = RectifiedAdam(learning_rate=generator_lr_fn, amsgrad=False) + dis_optimizer = RectifiedAdam(learning_rate=discriminator_lr_fn, amsgrad=False) trainer.compile( gen_model=generator, diff --git a/examples/tacotron2/train_tacotron2.py b/examples/tacotron2/train_tacotron2.py index 8100d1d0..4a81b80b 100755 --- a/examples/tacotron2/train_tacotron2.py +++ b/examples/tacotron2/train_tacotron2.py @@ -37,8 +37,7 @@ from tensorflow_tts.models import TFTacotron2 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp from tensorflow_tts.trainers import Seq2SeqBasedTrainer -from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss, - return_strategy) +from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy class Tacotron2Trainer(Seq2SeqBasedTrainer): @@ -313,7 +312,7 @@ def main(): default="", type=str, nargs="?", - help='pretrained weights .h5 file to load weights from. Auto-skips non-matching layers', + help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers", ) args = parser.parse_args() @@ -402,7 +401,9 @@ def main(): train_dataset = train_dataset.create( is_shuffle=config["is_shuffle"], allow_cache=config["allow_cache"], - batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync, + batch_size=config["batch_size"] + * STRATEGY.num_replicas_in_sync + * config["gradient_accumulation_steps"], ) valid_dataset = CharactorMelDataset( @@ -436,10 +437,12 @@ def main(): tacotron2 = TFTacotron2(config=tacotron_config, training=True, name="tacotron2") tacotron2._build() tacotron2.summary() - + if len(args.pretrained) > 1: tacotron2.load_weights(args.pretrained, by_name=True, skip_mismatch=True) - logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.") + logging.info( + f"Successfully loaded pretrained weight from {args.pretrained}." + ) # AdamW for tacotron2 learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( From 19d1f0408f2c6b840b2cbb1f9e4a3ce4b05ae6e6 Mon Sep 17 00:00:00 2001 From: dathudeptrai Date: Thu, 19 Nov 2020 14:26:50 +0700 Subject: [PATCH 4/4] =?UTF-8?q?=F0=9F=98=8F=20Update=20README.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 14323722..de98f57c 100755 --- a/README.md +++ b/README.md @@ -19,16 +19,17 @@ :zany_face: TensorFlowTTS provides real-time state-of-the-art speech synthesis architectures such as Tacotron-2, Melgan, Multiband-Melgan, FastSpeech, FastSpeech2 based-on TensorFlow 2. With Tensorflow 2, we can speed-up training/inference progress, optimizer further by using [fake-quantize aware](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide) and [pruning](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras), make TTS models can be run faster than real-time and be able to deploy on mobile devices or embedded systems. ## What's new -- 2020/08/23 **(NEW!)** Add Parallel WaveGAN tensorflow implementation. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/parallel_wavegan) -- 2020/08/23 **(NEW!)** Add MBMelGAN G + ParallelWaveGAN G example. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/multiband_pwgan) -- 2020/08/20 **(NEW!)** Add C++ inference code. Thank [@ZDisket](https://github.com/ZDisket). See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/cppwin) -- 2020/08/18 **(NEW!)** Update [new base processor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/processor/base_processor.py). Add [AutoProcessor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/inference/auto_processor.py) and [pretrained processor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/processor/pretrained/) json file. -- 2020/08/14 **(NEW!)** Support Chinese TTS. Pls see the [colab](https://colab.research.google.com/drive/1YpSHRBRPBI7cnTkQn1UcVTWEQVbsUm1S?usp=sharing). Thank [@azraelkuan](https://github.com/azraelkuan). -- 2020/08/05 **(NEW!)** Support Korean TTS. Pls see the [colab](https://colab.research.google.com/drive/1ybWwOS5tipgPFttNulp77P6DAB5MtiuN?usp=sharing). Thank [@crux153](https://github.com/crux153). -- 2020/07/17 Support MultiGPU for all Trainer. -- 2020/07/05 Support Convert Tacotron-2, FastSpeech to Tflite. Pls see the [colab](https://colab.research.google.com/drive/1HudLLpT9CQdh2k04c06bHUwLubhGTWxA?usp=sharing). Thank @jaeyoo from the TFlite team for his support. +- 2020/11/19 **(NEW!)** Add Multi-GPU gradient accumulator. See [here](https://github.com/TensorSpeech/TensorFlowTTS/pull/377) +- 2020/08/23 Add Parallel WaveGAN tensorflow implementation. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/parallel_wavegan) +- 2020/08/23 Add MBMelGAN G + ParallelWaveGAN G example. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/multiband_pwgan) +- 2020/08/20 Add C++ inference code. Thank [@ZDisket](https://github.com/ZDisket). See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/cppwin) +- 2020/08/18 Update [new base processor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/processor/base_processor.py). Add [AutoProcessor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/inference/auto_processor.py) and [pretrained processor](https://github.com/TensorSpeech/TensorFlowTTS/blob/master/tensorflow_tts/processor/pretrained/) json file +- 2020/08/14 Support Chinese TTS. Pls see the [colab](https://colab.research.google.com/drive/1YpSHRBRPBI7cnTkQn1UcVTWEQVbsUm1S?usp=sharing). Thank [@azraelkuan](https://github.com/azraelkuan) +- 2020/08/05 Support Korean TTS. Pls see the [colab](https://colab.research.google.com/drive/1ybWwOS5tipgPFttNulp77P6DAB5MtiuN?usp=sharing). Thank [@crux153](https://github.com/crux153) +- 2020/07/17 Support MultiGPU for all Trainer +- 2020/07/05 Support Convert Tacotron-2, FastSpeech to Tflite. Pls see the [colab](https://colab.research.google.com/drive/1HudLLpT9CQdh2k04c06bHUwLubhGTWxA?usp=sharing). Thank @jaeyoo from the TFlite team for his support - 2020/06/20 [FastSpeech2](https://arxiv.org/abs/2006.04558) implementation with Tensorflow is supported. -- 2020/06/07 [Multi-band MelGAN (MB MelGAN)](https://github.com/tensorspeech/TensorFlowTTS/blob/master/examples/multiband_melgan/) implementation with Tensorflow is supported. +- 2020/06/07 [Multi-band MelGAN (MB MelGAN)](https://github.com/tensorspeech/TensorFlowTTS/blob/master/examples/multiband_melgan/) implementation with Tensorflow is supported ## Features @@ -38,6 +39,7 @@ - Suitable for deployment. - Easy to implement a new model, based-on abstract class. - Mixed precision to speed-up training if possible. +- Support Single/Multi GPU gradient Accumulate. - Support both Single/Multi GPU in base trainer class. - TFlite conversion for all supported models. - Android example.