ZeRO stage 1 refresh #1042

Merged
merged 12 commits on May 19, 2021
6 changes: 3 additions & 3 deletions deepspeed/__init__.py
@@ -3,6 +3,7 @@
'''
import sys
import types
import packaging

from . import ops

@@ -25,9 +26,8 @@

def _parse_version(version_str):
'''Parse a version string and extract the major, minor, and patch versions.'''
import re
matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str)
return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))
ver = packaging.version.parse(version_str)
return ver.major, ver.minor, ver.micro


# Export version information
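Note on the _parse_version change above: packaging.version.parse returns a Version object whose major/minor/micro attributes replace the hand-rolled regex. A minimal standalone sketch of the new behavior, assuming the packaging package is installed (the helper name and sample strings are illustrative; the sketch imports the version submodule explicitly rather than relying on "import packaging" alone):

    # Minimal sketch of the new parsing behavior (illustrative, not engine code).
    from packaging import version as pkg_version

    def parse_version(version_str):
        """Return (major, minor, patch) for a PEP 440 version string."""
        ver = pkg_version.parse(version_str)
        return ver.major, ver.minor, ver.micro

    print(parse_version("0.3.16"))         # (0, 3, 16)
    print(parse_version("0.4.0+1a2b3c4"))  # (0, 4, 0) -- local build suffixes parse cleanly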
5 changes: 0 additions & 5 deletions deepspeed/runtime/config.py
@@ -766,12 +766,7 @@ def _do_error_check(self):
GRADIENT_ACCUMULATION_STEPS)

if self.zero_enabled:
if self.zero_optimization_stage < ZERO_OPTIMIZATION_GRADIENTS:
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
#if self.zero_config.cpu_offload is True:
# assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)

def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled
79 changes: 49 additions & 30 deletions deepspeed/runtime/engine.py
@@ -44,6 +44,7 @@
from ..ops.op_builder import UtilsBuilder
from ..ops.adam import DeepSpeedCPUAdam
from ..ops.adam import FusedAdam
from ..git_version_info import version

from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler

@@ -148,6 +149,8 @@ def __init__(self,
# Configure distributed model
self._configure_distributed_model(model)

self.pipeline_parallelism = isinstance(self.module, PipelineModule)

see_memory_usage(f"DeepSpeed Engine: After configure distributed model")

# Configure wall clock timer
@@ -390,6 +393,12 @@ def zero_gather_fp16_weights_on_model_save(self):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters

def zero_grad_hooks(self):
return self._config.zero_config.grad_hooks

def zero_legacy_stage1(self):
return self._config.zero_config.legacy_stage1

def fp16_enabled(self):
return self._config.fp16_enabled

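Note on the two new accessors above: they read flags from the zero_optimization section of the user config. A hedged sketch of the corresponding config as a Python dict; the "legacy_stage1" and "grad_hooks" key names are inferred from the zero_config attributes, not taken from documentation:

    # Hypothetical config sketch; key names inferred from the new accessors.
    ds_config = {
        "train_batch_size": 8,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 1,
            "legacy_stage1": False,  # False presumably selects the refreshed stage-1 path
            "grad_hooks": True,      # presumably toggles hook-based gradient reduction
        },
    }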
@@ -780,7 +789,8 @@ def _configure_zero_optimizer(self, optimizer):
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
timers = self.timers if self.wall_clock_breakdown() else None

if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
if self.zero_legacy_stage1(
) and zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
optimizer,
static_loss_scale=self.loss_scale(),
@@ -792,8 +802,19 @@
max_elements_per_comm=self.zero_reduce_bucket_size(),
dp_process_group=self.data_parallel_group,
elastic_checkpoint=self.zero_elastic_checkpoint(),
mpu=self.mpu)
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_predivide=self.gradient_predivide)
elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
overlap_comm = self.zero_overlap_comm()
if isinstance(self.module, PipelineModule):
if overlap_comm:
logger.warning(
"Pipeline parallelism does not support overlapped communication, will be disabled."
)
overlap_comm = False

optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=timers,
@@ -806,13 +827,14 @@
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
overlap_comm=overlap_comm,
cpu_offload=self.zero_cpu_offload(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
ignore_unused_parameters=self.zero_ignore_unused_parameters())
ignore_unused_parameters=self.zero_ignore_unused_parameters(),
partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS)
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
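To summarize the new dispatch in _configure_zero_optimizer: stage 1 with legacy_stage1 enabled keeps the old FP16_DeepSpeedZeroOptimizer_Stage1, stages 1 and 2 otherwise share FP16_DeepSpeedZeroOptimizer and differ only in partition_grads, and stage 3 is unchanged. A compressed, illustrative restatement of that branch logic (not engine code):

    # Which ZeRO optimizer class is selected, per the diff above.
    def select_zero_optimizer(stage, legacy_stage1):
        if legacy_stage1 and stage == 1:
            return "FP16_DeepSpeedZeroOptimizer_Stage1"  # old stage-1 implementation
        if stage <= 2:
            partition_grads = stage == 2  # stage 1 shards optimizer states only
            return "FP16_DeepSpeedZeroOptimizer(partition_grads={})".format(partition_grads)
        return "FP16_DeepSpeedZeroOptimizer_Stage3"

    print(select_zero_optimizer(1, legacy_stage1=False))
    # FP16_DeepSpeedZeroOptimizer(partition_grads=False)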
@@ -999,18 +1021,15 @@ def forward(self, *inputs, **kwargs):
return loss

def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
#Zero stage 2 communicates during non gradient accumulation boundaries as well
# ZeRO stage 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()

#Communicate only at gradient accumulation boundaries
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage(
) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
self.optimizer.reduce_scatter_gradients(
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_average=self.gradient_average)
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
self.optimizer.reduce_gradients(
pipeline_parallel=self.pipeline_parallelism)
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

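With this change, ZeRO stage 1 reduces gradients through optimizer.reduce_gradients() instead of the old reduce_scatter_gradients() call, and reduction still only fires on gradient-accumulation boundaries. A usage-level sketch, assuming a single GPU and illustrative config values (user-facing training code is unchanged):

    # Usage sketch: ZeRO stage 1 gradient reduction happens inside
    # engine.backward() on accumulation boundaries.
    import torch
    import deepspeed

    ds_config = {
        "train_batch_size": 4,               # 2 (micro) * 2 (grad accum) * 1 GPU
        "train_micro_batch_size_per_gpu": 2,
        "gradient_accumulation_steps": 2,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 1},
    }

    model = torch.nn.Linear(8, 2)
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config_params=ds_config)

    for _ in range(4):
        x = torch.randn(2, 8, device=engine.device, dtype=torch.half)
        loss = engine(x).float().sum()
        engine.backward(loss)  # triggers allreduce_gradients() on accumulation boundaries
        engine.step()          # optimizer step only on accumulation boundaries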
@@ -1731,19 +1750,19 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
# then instead just returns None.
self._curr_ckpt_path = os.path.join(save_dir, tag)

state = dict(
module=self.module_state_dict(),
optimizer=self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None,
lr_scheduler=self.lr_scheduler.state_dict()
if self.lr_scheduler is not None else None,
csr_tensor_module_names=self.csr_tensor_module_names,
skipped_steps=self.skipped_steps,
global_steps=self.global_steps,
global_samples=self.global_samples,
dp_world_size=self.dp_world_size,
mp_world_size=self.mp_world_size,
)
state = dict(module=self.module_state_dict(),
optimizer=self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None,
lr_scheduler=self.lr_scheduler.state_dict()
if self.lr_scheduler is not None else None,
csr_tensor_module_names=self.csr_tensor_module_names,
skipped_steps=self.skipped_steps,
global_steps=self.global_steps,
global_samples=self.global_samples,
dp_world_size=self.dp_world_size,
mp_world_size=self.mp_world_size,
ds_config=self.config,
ds_version=version)
state.update(client_state)

log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
@@ -1771,10 +1790,10 @@ def _copy_recovery_script(self, save_path):

def _save_zero_checkpoint(self, save_path, tag):
zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag)
zero_sd = dict(
optimizer_state_dict=self.optimizer.state_dict(),
param_shapes=self._get_param_shapes(),
)
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
param_shapes=self._get_param_shapes(),
ds_config=self.config,
ds_version=version)
torch.save(zero_sd, zero_checkpoint_name)
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
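Both the regular checkpoint state and the ZeRO checkpoint now carry ds_config and ds_version, which makes it possible to sanity-check a checkpoint against the installed DeepSpeed. A hedged sketch of reading that metadata back (the file path is a placeholder; actual file names depend on the checkpoint layout):

    # Sketch: inspect the new checkpoint metadata fields.
    import torch
    import deepspeed
    from packaging import version as pkg_version

    state = torch.load("checkpoints/global_step1000/example_states.pt",  # placeholder path
                       map_location="cpu")

    print(state["ds_version"])  # DeepSpeed version that wrote the checkpoint
    print(state["ds_config"])   # snapshot of the runtime config at save time

    if pkg_version.parse(state["ds_version"]) > pkg_version.parse(deepspeed.__version__):
        print("warning: checkpoint was written by a newer DeepSpeed release")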
5 changes: 2 additions & 3 deletions deepspeed/runtime/pipe/engine.py
@@ -226,9 +226,8 @@ def _exec_reduce_tied_grads(self):

def _exec_reduce_grads(self):
self._force_grad_boundary = True
if self.is_data_parallel and self.pipeline_enable_backward_allreduce:
self.buffered_allreduce_fallback(
elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
if self.pipeline_enable_backward_allreduce:
self.allreduce_gradients(bucket_size=MEMORY_OPT_ALLREDUCE_SIZE)
self._force_grad_boundary = False

def _reserve_pipe_buffers(self, num_buffers):
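Routing _exec_reduce_grads through allreduce_gradients() lets the pipeline engine reuse the same reduction path as plain data parallelism, which is what allows ZeRO stage 1 to be combined with pipeline parallelism. A hedged end-to-end sketch; layer sizes, stage count, and config values are placeholders:

    # Sketch: pipeline parallelism combined with ZeRO stage 1 after this change.
    import torch
    import deepspeed
    from deepspeed.pipe import PipelineModule

    deepspeed.init_distributed()  # PipelineModule needs torch.distributed initialized

    layers = [torch.nn.Linear(32, 32) for _ in range(8)]
    net = PipelineModule(layers=layers, num_stages=2)

    ds_config = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 1},  # overlap_comm is disabled when pipelining
    }

    engine, _, _, _ = deepspeed.initialize(model=net,
                                           model_parameters=net.parameters(),
                                           config_params=ds_config)
    # engine.train_batch(data_iter=...) runs the pipeline schedule; during
    # _exec_reduce_grads it now calls allreduce_gradients(), which for ZeRO
    # stage 1 dispatches to optimizer.reduce_gradients(pipeline_parallel=True).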