revert
sdtblck authored May 10, 2021
1 parent c814fca commit c7c2063
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions deepspeed/runtime/engine.py
@@ -21,7 +21,7 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
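
For context, a minimal sketch of what a predicate like is_fp16_fused_supported_optimizer might look like. The real helper lives in deepspeed/runtime/fp16/fused_optimizer.py and is not shown in this diff, so the registry below is purely illustrative (torch.optim classes are used only so the sketch runs standalone):

import torch

# Assumed registry of optimizer classes the fused fp16 wrapper knows how to handle.
# The actual FP16_FUSED_SUPPORTED_OPTIMIZERS contents are not visible in this diff.
FP16_FUSED_SUPPORTED_OPTIMIZERS = (torch.optim.Adam, torch.optim.AdamW)

def is_fp16_fused_supported_optimizer(optimizer):
    # True when the wrapped optimizer is an instance of a supported class.
    return isinstance(optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS)

# Usage: an Adam instance passes the check, a plain SGD instance would not.
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(2))])
assert is_fp16_fused_supported_optimizer(opt)
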
@@ -397,9 +397,6 @@ def zero_gather_fp16_weights_on_model_save(self):
     def fp16_enabled(self):
         return self._config.fp16_enabled
 
-    def precision(self):
-        return self._config.precision
-
     def amp_enabled(self):
         return self._config.amp_enabled
 
@@ -572,18 +569,14 @@ def is_replicated(p):
 
         for p in self.module.parameters():
             if torch.is_tensor(p) and is_replicated(p):
-                if self.precision() == torch.bfloat16:
-                    p = p.float()
                 dist.broadcast(p,
                                self.broadcast_src_rank,
                                group=self.data_parallel_group)
-                if self.precision() == torch.bfloat16:
-                    p = p.bfloat16()
 
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.to(self.precision())
+            self.module.half()
 
         if not self.dont_change_device:
             self.module.to(self.device)
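
The deleted lines above worked around communication backends that cannot broadcast bfloat16 tensors by round-tripping through fp32. A hedged sketch of that pattern as a standalone helper; it assumes torch.distributed is already initialized and writes the result back in place, whereas the reverted code reassigned the loop variable:

import torch
import torch.distributed as dist

def broadcast_params(module, src_rank, group=None):
    # Broadcast every parameter from src_rank to the rest of the group.
    # bf16 tensors are upcast to fp32 for the collective and cast back afterwards,
    # mirroring the workaround removed by this revert.
    for p in module.parameters():
        if p.dtype == torch.bfloat16:
            buf = p.data.float()                      # upcast for the backend
            dist.broadcast(buf, src_rank, group=group)
            p.data.copy_(buf.to(torch.bfloat16))      # cast back in place
        else:
            dist.broadcast(p.data, src_rank, group=group)
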
@@ -721,8 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if isinstance(optimizer,
-                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
+        if is_fp16_fused_supported_optimizer(optimizer):
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
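
This branch wires up dynamic loss scaling for the fused fp16 optimizer. As a rough illustration of what a dynamic loss scaler does; the parameter names below are illustrative, not DeepSpeed's dynamic_loss_scale_args schema:

class ToyLossScaler:
    """Grow the scale after a stretch of overflow-free steps; back off on overflow."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_overflow):
        if found_overflow:
            # Overflow: shrink the scale and skip the optimizer step.
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                self.scale *= self.growth_factor
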
@@ -780,8 +772,7 @@ def _configure_zero_optimizer(self, optimizer):
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu,
-                precision=self.precision())
+                mpu=self.mpu)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -800,8 +791,7 @@ def _configure_zero_optimizer(self, optimizer):
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                precision=self.precision())
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
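
The precision=self.precision() keyword dropped from the ZeRO optimizer constructors carried a torch dtype chosen from the config. A hypothetical sketch of that kind of accessor; the real DeepSpeedConfig field and its defaults are not shown in this diff:

import torch

# Hypothetical mapping from a config string to a torch dtype; names are illustrative.
_PRECISION_DTYPES = {
    "fp16": torch.float16,
    "bfloat16": torch.bfloat16,
    "fp32": torch.float32,
}

def resolve_precision(precision_str="fp16"):
    # Returns the dtype a module (and its gradients) would be cast to.
    return _PRECISION_DTYPES[precision_str]
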
@@ -989,7 +979,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
 
         # Communicate only at gradient accumulation boundaries
         elif self.is_gradient_accumulation_boundary():
-            # TODO: communication in fp16 / fp32
             if self.zero_optimization_stage(
             ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
                 self.optimizer.reduce_scatter_gradients(
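
The context around the removed TODO gates collectives on gradient-accumulation boundaries: gradients are only reduced on the micro-step that completes an accumulation window. A minimal sketch of that boundary test; the counter names are illustrative:

def is_gradient_accumulation_boundary(micro_steps, gradient_accumulation_steps):
    # micro_steps counts completed micro-batches, starting from 1.
    return micro_steps % gradient_accumulation_steps == 0

# e.g. with 4 accumulation steps, only every 4th micro-batch triggers communication.
assert [is_gradient_accumulation_boundary(s, 4) for s in range(1, 9)] == \
       [False, False, False, True, False, False, False, True]
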
