Skip to content

Commit

Permalink
Changes in create_optimizer to support tensor parallelism with SMP (#16880)
Browse files: browse the repository at this point in the history

* changes in create optimizer to support tensor parallelism with SMP

* Update src/transformers/trainer.py

Convert if check to one line.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Cavdar <dcavdar@a07817b12d7e.ant.amazon.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 22, 2022
1 parent 99c8226 commit 22fc93c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,16 +843,18 @@ def create_optimizer(self):
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

if self.optimizer is None:
decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
"params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]
Expand All @@ -872,7 +874,7 @@ def create_optimizer(self):

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

for module in self.model.modules():
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
Expand Down

0 comments on commit 22fc93c

Please sign in to comment.