diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index faa60f20efa3..c5c4a1cd9ba1 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -253,7 +253,7 @@ def get_partition_size(item): return int(partition_size) -def get_full_inputs(tensors, device=None): +def get_full_inputs(tensors, device=None, fp32_comm=False): inputs = [] num_args = int(len(tensors) / 2) for i in range(num_args - 1): @@ -274,9 +274,14 @@ def get_full_inputs(tensors, device=None): part_i = flat_tensor.narrow(0, partition_size * i, partition_size) if i == mp_rank: part_i.copy_(item) + if fp32_comm: + part_i = part_i.float() partitions.append(part_i) if mp_group is not None: dist.all_gather(partitions, partitions[mp_rank], group=mp_group) + if fp32_comm: + for i in range(mp_size): + partitions[i] = partitions[i].to(item.dtype) input_tensor = flat_tensor.view(list(size.numpy())) item.data = input_tensor.data @@ -599,9 +604,14 @@ def backward(ctx, *grads): global cuda_device, transport_stream, PARTITION_ACTIVATIONS if PARTITION_ACTIVATIONS: + if ctx.saved_tensors and ctx.saved_tensors[0].dtype == torch.bfloat16: + FP32_COMM = True + else: + FP32_COMM = False # with torch.cuda.stream(transport_stream): inputs = get_full_inputs(ctx.saved_tensors, - device=cuda_device if PA_TO_CPU else None) + device=cuda_device if PA_TO_CPU else None, + fp32_comm=FP32_COMM) detached_inputs = detach_variable(inputs) else: inputs = ctx.saved_tensors diff --git a/deepspeed/runtime/comm/__init__.py b/deepspeed/runtime/comm/__init__.py index e69de29bb2d1..f321901e2381 100644 --- a/deepspeed/runtime/comm/__init__.py +++ b/deepspeed/runtime/comm/__init__.py @@ -0,0 +1 @@ +from .compressed_ar import compressed_all_reduce diff --git a/deepspeed/runtime/comm/compressed_ar.py b/deepspeed/runtime/comm/compressed_ar.py new file mode 100644 index 000000000000..24a01ffcb9e3 --- /dev/null +++ b/deepspeed/runtime/comm/compressed_ar.py @@ -0,0 +1,54 @@ +# python -m torch.distributed.launch --nproc_per_node=1 24_bit_allreduce.py + +import torch +import os +import cupy +from torch.utils.dlpack import to_dlpack +from torch.utils.dlpack import from_dlpack + +version = torch.__version__.split('.') +TORCH_VERSION_MAJOR = int(version[0]) +TORCH_VERSION_MINOR = int(version[1]) +if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR >= 1 and TORCH_VERSION_MINOR < 9): + compressed_all_reduce = compressed_all_reduce_cupy +else: + compressed_all_reduce = compressed_all_reduce_torch + +def torch2cupy(tensor): + return cupy.fromDlpack(to_dlpack(tensor)) + + +def cupy2torch(cupy_tensor): + return from_dlpack(cupy_tensor.toDlpack()) + + +def decompose_cupy(tensor): + mantissa, exponent = cupy.frexp(torch2cupy(tensor.float())) + return cupy2torch(mantissa).half(), cupy2torch(exponent).to(torch.int8) + + +def decompose(t): + if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR >= 1 and TORCH_VERSION_MINOR < 9): + raise Exception('Torch version >= 1.9.0 needed for 24_bit_allreduce.decompose') + mantissa, exponent = torch.frexp(t.float()) + return mantissa.half(), exponent.to(torch.int8) + + +def reconstruct(mantissa, exponent, original_dtype=torch.bfloat16): + return torch.ldexp(mantissa, exponent).to(original_dtype) + + +def compressed_all_reduce_torch(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): + original_dtype = tensor.dtype + m, e = decompose(tensor) + torch.distributed.all_reduce(m, 
op=op, group=group, async_op=async_op) + torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op) + return reconstruct(m, e, original_dtype) + + +def compressed_all_reduce_cupy(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False): + original_dtype = tensor.dtype + m, e = decompose_cupy(tensor) + torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op) + torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op) + return reconstruct(m, e, original_dtype) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 3fa0b32a6032..841a69e99046 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -94,8 +94,18 @@ def get_fp16_enabled(param_dict): return False +def get_fp16_type(param_dict): + if get_fp16_enabled(param_dict): + return get_scalar_param(param_dict[FP16], FP16_TYPE, FP16_TYPE_DEFAULT) + else: + return "fp32" + + def get_loss_scale(param_dict): if get_fp16_enabled(param_dict): + if get_fp16_type(param_dict) == "bfloat16": + # default loss scale to 1.0 if dtype == bf16, as loss scaling isn't needed + return 1.0 return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) @@ -111,7 +121,7 @@ def get_initial_dynamic_scale(param_dict): else: initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT - return 2**initial_scale_power + return 2 ** initial_scale_power def get_dynamic_loss_scale_args(param_dict): @@ -138,7 +148,7 @@ def get_dynamic_loss_scale_args(param_dict): FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT) loss_scale_args = { - INITIAL_LOSS_SCALE: 2**init_scale, + INITIAL_LOSS_SCALE: 2 ** init_scale, SCALE_WINDOW: scale_window, DELAYED_SHIFT: delayed_shift, MIN_LOSS_SCALE: min_loss_scale @@ -168,6 +178,9 @@ def get_zero_reduce_scatter(param_dict): def get_allreduce_always_fp32(param_dict): + if get_fp16_type(param_dict) == "bfloat16": + # default allreduce_always_fp32 to True if dtype == bf16, as nccl can't communicate bf16 tensors + return get_scalar_param(param_dict, FP32_ALLREDUCE, FP32_ALLREDUCE_DEFAULT_BF16) return get_scalar_param(param_dict, FP32_ALLREDUCE, FP32_ALLREDUCE_DEFAULT) @@ -409,7 +422,7 @@ def get_optimizer_gradient_clipping(param_dict): def get_optimizer_legacy_fusion(param_dict): if OPTIMIZER in param_dict.keys() and \ - LEGACY_FUSION in param_dict[OPTIMIZER].keys(): + LEGACY_FUSION in param_dict[OPTIMIZER].keys(): return param_dict[OPTIMIZER][LEGACY_FUSION] else: return LEGACY_FUSION_DEFAULT @@ -496,7 +509,7 @@ def get_checkpoint_tag_validation_mode(checkpoint_params): return tag_validation_mode else: raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \ - f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}") + f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}") '''Write deepspeed config files by modifying basic templates. @@ -568,11 +581,11 @@ def __init__(self, json_file, mpu=None, param_dict=None): ] if any(map(lambda t: t in self._param_dict, batch_params)): raise ElasticityConfigError("One or more batch related parameters were found in your " \ - f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ - f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \ - "elastic training is enabled, which takes control of these parameters. 
" \ - "If you want to supress this error (the parameters will be silently ignored) " \ - f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") + f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ + f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \ + "elastic training is enabled, which takes control of these parameters. " \ + "If you want to supress this error (the parameters will be silently ignored) " \ + f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") # micro_bsz * world_size * gas = total_batch_size # gas = total_batch_size // (micro_bsz * world_size) @@ -581,13 +594,13 @@ def __init__(self, json_file, mpu=None, param_dict=None): if TRAIN_BATCH_SIZE in self._param_dict: logger.warning("[Elasticity] overriding training_batch_size: " \ - f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") + f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \ - f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") + f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") if GRADIENT_ACCUMULATION_STEPS in self._param_dict: - logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\ - f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") + logger.warning("[Elasticity] overriding gradient_accumulation_steps: " \ + f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") @@ -622,6 +635,9 @@ def _initialize_params(self, param_dict): self.gradient_clipping = get_gradient_clipping(param_dict) self.fp16_enabled = get_fp16_enabled(param_dict) + self.fp16_type = get_fp16_type(param_dict) + self.precision = PRECISION_TYPES[self.fp16_type] + self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) self.loss_scale = get_loss_scale(param_dict) @@ -630,7 +646,7 @@ def _initialize_params(self, param_dict): self.optimizer_name = get_optimizer_name(param_dict) if self.optimizer_name is not None and \ - self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS: + self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS: self.optimizer_name = self.optimizer_name.lower() self.optimizer_params = get_optimizer_params(param_dict) @@ -678,9 +694,9 @@ def _batch_assertion(self): f'Gradient accumulation steps: {grad_acc} has to be greater than 0' assert train_batch == micro_batch * grad_acc * self.world_size, \ - (f'Check batch related parameters. train_batch_size is not equal' - ' to micro_batch_per_gpu * gradient_acc_step * world_size' - f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}') + (f'Check batch related parameters. 
train_batch_size is not equal' + ' to micro_batch_per_gpu * gradient_acc_step * world_size' + f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}') def _set_batch_related_parameters(self): @@ -688,44 +704,44 @@ def _set_batch_related_parameters(self): micro_batch = self.train_micro_batch_size_per_gpu grad_acc = self.gradient_accumulation_steps - #all values are provided nothing needs to be set + # all values are provided nothing needs to be set if train_batch is not None and \ - micro_batch is not None and \ - grad_acc is not None: + micro_batch is not None and \ + grad_acc is not None: return - #global_accumulation_steps needs to be set + # global_accumulation_steps needs to be set elif train_batch is not None and \ - micro_batch is not None: + micro_batch is not None: grad_acc = train_batch // micro_batch grad_acc //= self.world_size self.gradient_accumulation_steps = grad_acc - #micro_batch_per_gpu needs to be set + # micro_batch_per_gpu needs to be set elif train_batch is not None and \ - grad_acc is not None: + grad_acc is not None: micro_batch = train_batch // self.world_size micro_batch //= grad_acc self.train_micro_batch_size_per_gpu = micro_batch - #train_batch_size needs to be set + # train_batch_size needs to be set elif micro_batch is not None and \ - grad_acc is not None: + grad_acc is not None: train_batch_size = micro_batch * grad_acc train_batch_size *= self.world_size self.train_batch_size = train_batch_size - #gradient_accumulation_steps and micro_batch_per_gpus is set + # gradient_accumulation_steps and micro_batch_per_gpus is set elif train_batch is not None: self.gradient_accumulation_steps = 1 self.train_micro_batch_size_per_gpu = train_batch // self.world_size - #train_batch_size and gradient_accumulation_step is set + # train_batch_size and gradient_accumulation_step is set elif micro_batch is not None: self.train_batch_size = micro_batch * self.world_size self.gradient_accumulation_steps = 1 - #either none of the three parameters are provided or just gradient_accumulation_step is provided + # either none of the three parameters are provided or just gradient_accumulation_step is provided else: assert False, \ 'Either train_batch_size or micro_batch_per_gpu needs to be provided' @@ -755,17 +771,19 @@ def print(self, name): ':')))) def _do_error_check(self): - assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU) + assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format( + TRAIN_MICRO_BATCH_SIZE_PER_GPU) assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format( GRADIENT_ACCUMULATION_STEPS) if self.zero_enabled: assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled" - assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) - #if self.zero_config.cpu_offload is True: + assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format( + MAX_STAGE_ZERO_OPTIMIZATION) + # if self.zero_config.cpu_offload is True: # assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS) - #assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD) + # assert 
self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD) def _do_warning_check(self): fp16_enabled = self.fp16_enabled or self.zero_enabled @@ -774,21 +792,21 @@ def _do_warning_check(self): if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: logger.warning( "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may import tensor core utilization." - .format(vocabulary_size, - TENSOR_CORE_ALIGN_SIZE)) + .format(vocabulary_size, + TENSOR_CORE_ALIGN_SIZE)) if self.optimizer_params is not None and \ - MAX_GRAD_NORM in self.optimizer_params.keys() and \ + MAX_GRAD_NORM in self.optimizer_params.keys() and \ self.optimizer_params[MAX_GRAD_NORM] > 0: if fp16_enabled: if self.global_rank == 0: logger.warning( 'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper' - .format(MAX_GRAD_NORM, - self.optimizer_params[MAX_GRAD_NORM])) + .format(MAX_GRAD_NORM, + self.optimizer_params[MAX_GRAD_NORM])) else: if self.global_rank == 0: logger.warning( 'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero' - .format(self.optimizer_params[MAX_GRAD_NORM])) + .format(self.optimizer_params[MAX_GRAD_NORM])) self.optimizer_params[MAX_GRAD_NORM] = 0.0 diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 2f5916df753a..e9690105de3c 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -2,6 +2,7 @@ Copyright (c) Microsoft Corporation Licensed under the MIT license. """ +import torch ############################################# # Routes @@ -128,6 +129,18 @@ FP16_ENABLED = "enabled" FP16_ENABLED_DEFAULT = False +FP16_TYPE = "type" +FP16_TYPE_DEFAULT = "fp16" +PRECISION_TYPES = { + "fp32": torch.float32, + "float32": torch.float32, + "float": torch.float32, + "fp16": torch.half, + "float16": torch.half, + "half": torch.half, + "bfloat16": torch.bfloat16 +} + # FP16 loss scale, zero means using dynamic scaling FP16_LOSS_SCALE = "loss_scale" FP16_LOSS_SCALE_DEFAULT = 0 @@ -189,6 +202,7 @@ ''' FP32_ALLREDUCE = "fp32_allreduce" FP32_ALLREDUCE_DEFAULT = False +FP32_ALLREDUCE_DEFAULT_BF16 = True # if dtype is bf16 - default to fp32 communication ######################################### # Scale/predivide gradients before allreduce diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 4704f7f6817f..38a62a812cea 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -26,6 +26,7 @@ from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT +from deepspeed.runtime.comm import compressed_all_reduce from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ @@ -61,6 +62,7 @@ def split_half_float_double_csr(tensors): "torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", + "torch.cuda.BFloat16Tensor", CSRTensor.type() ] buckets = [] @@ -100,6 +102,7 @@ def print_configuration(args, name): class DeepSpeedEngine(Module): r"""DeepSpeed engine for training. 
""" + def __init__(self, args, model, @@ -136,6 +139,7 @@ def __init__(self, self.store_gradients = False self.store_gradients_cpu = False self.stored_gradients = None + self.bf16_compressed_allreduce = False # hardcode for now - it's not really working if dist_init_required is None: dist_init_required = not dist.is_initialized() @@ -397,6 +401,9 @@ def zero_gather_fp16_weights_on_model_save(self): def fp16_enabled(self): return self._config.fp16_enabled + def precision(self): + return self._config.precision + def amp_enabled(self): return self._config.amp_enabled @@ -530,9 +537,10 @@ def _do_args_sanity_check(self, args): args.deepspeed_config = args.deepscale_config assert "LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment variable, it is set by the deepspeed launcher, " \ - "deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." + "deepspeed.init_distributed, or the torch.distributed launcher. If using a different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." if hasattr(args, 'local_rank') and args.local_rank != None: - assert isinstance(args.local_rank, int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" + assert isinstance(args.local_rank, + int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" if args.local_rank >= 0: env_local_rank = int(os.environ.get("LOCAL_RANK")) assert env_local_rank == args.local_rank, \ @@ -569,14 +577,19 @@ def is_replicated(p): for p in self.module.parameters(): if torch.is_tensor(p) and is_replicated(p): + if self.precision() == torch.bfloat16 and self.allreduce_always_fp32(): + p.data = p.float().data + dist.broadcast(p, self.broadcast_src_rank, group=self.data_parallel_group) + if self.precision() == torch.bfloat16 and self.allreduce_always_fp32(): + p.data = p.to(self.precision()).data def _configure_distributed_model(self, model): self.module = model if self.fp16_enabled(): - self.module.half() + self.module.to(self.precision()) if not self.dont_change_device: self.module.to(self.device) @@ -772,7 +785,8 @@ def _configure_zero_optimizer(self, optimizer): max_elements_per_comm=self.zero_reduce_bucket_size(), dp_process_group=self.data_parallel_group, elastic_checkpoint=self.zero_elastic_checkpoint(), - mpu=self.mpu) + mpu=self.mpu, + precision=self.precision()) elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS: optimizer = FP16_DeepSpeedZeroOptimizer( optimizer, @@ -791,7 +805,8 @@ def _configure_zero_optimizer(self, optimizer): mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), - gradient_accumulation_steps=self.gradient_accumulation_steps()) + gradient_accumulation_steps=self.gradient_accumulation_steps(), + precision=self.precision()) elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS: print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 @@ -979,6 +994,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Communicate only at gradient accumulation boundaries elif self.is_gradient_accumulation_boundary(): + # TODO: communication in fp16 / fp32 if self.zero_optimization_stage( ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter(): self.optimizer.reduce_scatter_gradients( @@ -1143,7 +1159,7 @@ def _take_model_step(self, lr_kwargs): 
self.lr_scheduler.step(**(lr_kwargs or {})) if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: - self._report_progress(self.global_steps + 1) + self._report_progress(self.global_steps + 1) self.timers('_step_check_overflow').stop() self.global_steps += 1 @@ -1276,14 +1292,16 @@ def allreduce_bucket(self, bucket): tensor_to_allreduce = tensor - if self.allreduce_always_fp32(): + if self.allreduce_always_fp32() and not self.bf16_compressed_allreduce: tensor_to_allreduce = tensor.float() if self.postscale_gradients(): if self.gradient_predivide_factor() != 1.0: tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor()) - - dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) + if self.bf16_compressed_allreduce and self.precision() == torch.bfloat16: + compressed_all_reduce(tensor_to_allreduce, group=self.data_parallel_group) + else: + dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) if self.gradient_average: if self.gradient_predivide_factor() != self.dp_world_size: @@ -1291,9 +1309,12 @@ def allreduce_bucket(self, bucket): self.dp_world_size) else: tensor_to_allreduce.div_(self.dp_world_size) - dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) + if self.bf16_compressed_allreduce and self.precision() == torch.bfloat16: + compressed_all_reduce(tensor_to_allreduce, group=self.data_parallel_group) + else: + dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) - if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce: + if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce and not self.bf16_compressed_allreduce: tensor.copy_(tensor_to_allreduce) return tensor @@ -1813,8 +1834,8 @@ def get_layer_state_dict(module, prefix=""): else: state_dict[key] = param.detach().cpu() shared_weights[data_ptr_id] = key - #print(f"param {name} {param.shape}") - #print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}") + # print(f"param {name} {param.shape}") + # print(f"param {key} {param.shape} {state_dict[key].storage().data_ptr()}") # now buffers - not sure if need to take care of potentially shared weights here for name, buf in module.named_buffers(recurse=False): diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 29f16181d5e3..42d4a0b15278 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -480,6 +480,9 @@ def inference_batch(self, data_iter): presents_shape = presents_shape_tensor.tolist() if self.is_last_stage(): + if self.precision() == torch.bfloat16 and self.allreduce_always_fp32(): + logits = logits.to(torch.float) + presents = presents.to(torch.float) dist.broadcast(tensor=logits, src=self.global_rank, group=self.mpu.get_pipe_parallel_group()) @@ -488,9 +491,9 @@ def inference_batch(self, data_iter): group=self.mpu.get_pipe_parallel_group()) else: - logits = torch.zeros(logits_shape, dtype=torch.half if self.fp16_enabled() else torch.float32).to( + logits = torch.zeros(logits_shape, dtype=self.precision() if self.precision() != torch.bfloat16 else torch.float32).to( self.device) - presents = torch.zeros(presents_shape, dtype=torch.half if self.fp16_enabled() else torch.float32).to( + presents = torch.zeros(presents_shape, dtype=self.precision() if self.precision() != torch.bfloat16 else torch.float32).to( self.device) src_rank = self.grid.stage_to_global(self.num_stages - 1) assert src_rank in self.grid.pp_group @@ -500,6 +503,8 @@ def inference_batch(self, data_iter): 
dist.broadcast(tensor=presents, src=src_rank, group=self.grid.get_pipe_parallel_group()) + if self.precision() == torch.bfloat16 and self.allreduce_always_fp32(): + logits, presents = logits.to(self.precision()), presents.to(self.precision()) logits = logits.clone().detach() presents = presents.clone().detach() @@ -924,10 +929,10 @@ def _exec_send_activations(self, buffer_id): self._send_tensor_meta(outputs, self.next_stage) if isinstance(outputs, torch.Tensor): - p2p.send(outputs, self.next_stage) + p2p.send(outputs, self.next_stage, fp32_comm=self.allreduce_always_fp32()) elif isinstance(outputs, tuple): for idx, buffer in enumerate(outputs): - p2p.send(buffer, self.next_stage) + p2p.send(buffer, self.next_stage, fp32_comm=self.allreduce_always_fp32()) else: raise NotImplementedError('Could not send output of type ' f'{type(outputs)}') @@ -970,13 +975,13 @@ def _exec_send_grads(self, buffer_id): if isinstance(inputs, torch.Tensor): assert inputs.grad is not None - p2p.send(inputs.grad, self.prev_stage) + p2p.send(inputs.grad, self.prev_stage, fp32_comm=self.allreduce_always_fp32()) else: # XXX terrible hacky branch if self.is_grad_partitioned: # First two sends are partitioned gradient - p2p.send(inputs[0], self.prev_stage) - p2p.send(inputs[1], self.prev_stage) + p2p.send(inputs[0], self.prev_stage, fp32_comm=self.allreduce_always_fp32()) + p2p.send(inputs[1], self.prev_stage, fp32_comm=self.allreduce_always_fp32()) # XXX hack hack hack # p2p.send(inputs[2].grad, self.prev_stage) else: @@ -986,7 +991,7 @@ def _exec_send_grads(self, buffer_id): assert buffer.grad is None continue assert buffer.grad is not None - p2p.send(buffer.grad, self.prev_stage) + p2p.send(buffer.grad, self.prev_stage, fp32_comm=self.allreduce_always_fp32()) # We can free up the input buffer now self.pipe_buffers['inputs'][buffer_id] = None @@ -1004,7 +1009,7 @@ def _exec_recv_activations(self, buffer_id): self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage) if isinstance(self.pipe_recv_buf, torch.Tensor): - p2p.recv(self.pipe_recv_buf, self.prev_stage) + p2p.recv(self.pipe_recv_buf, self.prev_stage, fp32_comm=self.allreduce_always_fp32()) recvd = self.pipe_recv_buf.clone().detach() recvd.requires_grad = recvd.is_floating_point() else: @@ -1020,7 +1025,7 @@ def _exec_recv_activations(self, buffer_id): device=self.device) buffer = self.meta_buffer - p2p.recv(buffer, self.prev_stage) + p2p.recv(buffer, self.prev_stage, fp32_comm=self.allreduce_always_fp32()) recvd[idx] = buffer.clone().detach() # NCCL does not like to send torch.BoolTensor types, so un-cast the @@ -1066,7 +1071,7 @@ def _exec_recv_grads(self, buffer_id): self.grad_layer = self._allocate_buffers(sizes, num_buffers=1)[0] if isinstance(self.grad_layer, torch.Tensor): - p2p.recv(self.grad_layer, self.next_stage) + p2p.recv(self.grad_layer, self.next_stage, fp32_comm=self.allreduce_always_fp32()) else: assert isinstance(outputs, tuple) for idx, buffer in enumerate(self.grad_layer): @@ -1075,7 +1080,7 @@ def _exec_recv_grads(self, buffer_id): buffer.data = torch.zeros(buffer.size(), dtype=torch.long, device=self.device) - p2p.recv(buffer, self.next_stage) + p2p.recv(buffer, self.next_stage, fp32_comm=self.allreduce_always_fp32()) if self.wall_clock_breakdown(): self.timers('pipe_recv_grad').stop() @@ -1128,13 +1133,9 @@ def _allocate_zeros(self, shape, fp16=None, **kwargs): A tensor from torch.zeros() allocated on self.device. 
""" - if fp16 is None: - fp16 = self.fp16_enabled() + precision = self.precision() if self.precision() != torch.bfloat16 else torch.float32 + return torch.zeros(shape, dtype=precision, device=self.device, **kwargs) - if fp16: - return torch.zeros(shape, dtype=torch.half, device=self.device, **kwargs) - else: - return torch.zeros(shape, device=self.device, **kwargs) def _allocate_buffer(self, shape, num_buffers=-1, **kwargs): buffers = [] diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py index 24c0f250a4b9..c323cd21a02d 100644 --- a/deepspeed/runtime/pipe/p2p.py +++ b/deepspeed/runtime/pipe/p2p.py @@ -28,31 +28,37 @@ def _is_valid_send_recv(src_stage, dest_stage): "Functionality currently limited to send and receive between adjacent ranks only" -def send(tensor, dest_stage, async_op=False): +def send(tensor, dest_stage, async_op=False, fp32_comm=False): global _groups async_op = False src_stage = _grid.get_stage_id() _is_valid_send_recv(src_stage, dest_stage) - + tensor_to_broadcast = tensor + if fp32_comm: + tensor_to_broadcast = tensor_to_broadcast.float() group = _get_send_recv_group(src_stage, dest_stage) src_rank = _grid.stage_to_global(stage_id=src_stage) - - return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) + dist.broadcast(tensor_to_broadcast, src_rank, group=group, async_op=async_op) + if fp32_comm and tensor is not tensor_to_broadcast: + tensor.copy_(tensor_to_broadcast) -def recv(tensor, src_stage, async_op=False): +def recv(tensor, src_stage, async_op=False, fp32_comm=False): global _groups async_op = False dest_stage = _grid.get_stage_id() _is_valid_send_recv(src_stage, dest_stage) - + tensor_to_broadcast = tensor + if fp32_comm: + tensor_to_broadcast = tensor_to_broadcast.float() group = _get_send_recv_group(src_stage, dest_stage) src_rank = _grid.stage_to_global(stage_id=src_stage) - - return dist.broadcast(tensor, src_rank, group=group, async_op=async_op) + dist.broadcast(tensor_to_broadcast, src_rank, group=group, async_op=async_op) + if fp32_comm and tensor is not tensor_to_broadcast: + tensor.copy_(tensor_to_broadcast) def barrier(stage_id): diff --git a/deepspeed/runtime/zero/stage1.py b/deepspeed/runtime/zero/stage1.py index b75d5b4b1fd3..776cd6fe0975 100755 --- a/deepspeed/runtime/zero/stage1.py +++ b/deepspeed/runtime/zero/stage1.py @@ -120,13 +120,21 @@ def __init__(self, allgather_size=500000000, clip_grad=0.0, max_elements_per_comm=5e8, - elastic_checkpoint=True): + elastic_checkpoint=True, + precision=torch.half): # Load pre-built or JIT compile (un)flatten ops util_ops = UtilsBuilder().load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten + # set precision + self.precision = precision + if self.precision == torch.bfloat16: + self.fp32_allreduce = True + else: + self.fp32_allreduce = False + if dp_process_group is not None and partition_size is not None: raise ValueError("Cannot specify both dp_process_group " "and partition size") @@ -622,6 +630,7 @@ def reduce_scatter_gradients(self, postscale_gradients, gradient_predivide_factor, gradient_average): + world_size = dist.get_world_size(group=self.dp_process_group) local_rank = dist.get_rank(group=self.dp_process_group) @@ -634,7 +643,7 @@ def reduce_scatter_gradients(self, comm_tensor_list=self.params_in_rank_sub_partitions[i][rank], comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i] [rank], - dtype=torch.half, + dtype=self.precision, default_device=self.default_device, sub_partition_size=self.sub_partition_sizes[i], 
num_comm_intervals=self.num_comm_intervals_per_group[i]) @@ -646,7 +655,10 @@ def reduce_scatter_gradients(self, for comm_idx in range(num_comm_intervals): single_comm_all_partitions = [] for rank in range(world_size): - single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx]) + if self.fp32_allreduce: + single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx].float()) + else: + single_comm_all_partitions.append(all_sub_partitions[rank][comm_idx]) if postscale_gradients: if gradient_predivide_factor != 1.0: @@ -669,6 +681,9 @@ def reduce_scatter_gradients(self, dist.reduce_scatter(output=single_comm_all_partitions[local_rank], input_list=single_comm_all_partitions, group=self.dp_process_group) + if self.fp32_allreduce: + for rank in range(world_size): + all_sub_partitions[rank][comm_idx] = all_sub_partitions[rank][comm_idx].to(self.precision) def step(self, closure=None, comms_timer=None): # First compute norm for all group so we know if there is overflow @@ -747,9 +762,15 @@ def step(self, closure=None, comms_timer=None): #gather the updated weights from everyone for fp16_all_sub_partitions in self.parallel_comm_sub_partitioned_fp16_groups: for comm_id, sub_partitions in enumerate(fp16_all_sub_partitions): + if self.fp32_allreduce: + for i in range(len(sub_partitions)): + sub_partitions[i] = sub_partitions[i].float() dist.all_gather(sub_partitions, sub_partitions[partition_id], group=self.dp_process_group) + if self.fp32_allreduce: + for i in range(len(sub_partitions)): + sub_partitions[i] = sub_partitions[i].to(self.precision) if comms_timer is not None: comms_timer.stop() diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index 39d780e55574..e98a374902ce 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -95,7 +95,8 @@ def __init__(self, allreduce_always_fp32=False, postscale_gradients=True, gradient_predivide_factor=1.0, - gradient_accumulation_steps=1): + gradient_accumulation_steps=1, + precision=torch.half): if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") @@ -113,6 +114,9 @@ def __init__(self, if not torch.cuda.is_available: raise SystemError("Cannot use fp16 without CUDA.") self.optimizer = init_optimizer + self.precision = precision + self.fp32_allreduce = True if self.precision == torch.bfloat16 else allreduce_always_fp32 + # Load pre-built or JIT compile (un)flatten ops util_ops = UtilsBuilder().load() @@ -144,14 +148,12 @@ def __init__(self, self.overflow = False self.clip_grad = clip_grad - self.allreduce_always_fp32 = allreduce_always_fp32 self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps self.micro_step_id = 0 if self.reduce_scatter: - assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled" assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled" assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" @@ -306,10 +308,10 @@ def __init__(self, self.grad_position = {} self.temp_grad_buffer_for_cpu_offload = torch.zeros( largest_param_numel, - device=self.device).half().pin_memory() + device=self.device).to(self.precision).pin_memory() self.temp_grad_buffer_for_gpu_offload = torch.zeros( largest_param_numel, - 
device=torch.cuda.current_device()).half() + device=torch.cuda.current_device()).to(self.precision) for i, params_group in enumerate(self.fp16_groups): self.get_grad_position(i, @@ -654,13 +656,13 @@ def gradient_reduction_w_predivide(self, tensor): tensor_to_allreduce = tensor - if self.allreduce_always_fp32: + if self.fp32_allreduce: tensor_to_allreduce = tensor.float() if self.postscale_gradients: if self.gradient_predivide_factor != 1.0: tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) - + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: @@ -669,7 +671,7 @@ def gradient_reduction_w_predivide(self, tensor): tensor_to_allreduce.div_(dp_world_size) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce: + if self.fp32_allreduce and tensor is not tensor_to_allreduce: tensor.copy_(tensor_to_allreduce) return tensor @@ -730,6 +732,8 @@ def average_tensor(self, tensor): for dst, bucket_offset, numel in rank_and_offsets: grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) dst_rank = _get_global_rank(self.dp_process_group, dst) + if self.fp32_allreduce: + grad_slice = grad_slice.float() async_handle = dist.reduce(grad_slice, dst=dst_rank, group=self.dp_process_group, @@ -738,6 +742,7 @@ def average_tensor(self, tensor): for handle in async_handles: handle.wait() + grad_slice = grad_slice.to(self.precision) ############################################################################## ############################# CPU Offload Methods############################# @@ -890,12 +895,12 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.SUM, group=self.dp_process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) + op=dist.ReduceOp.SUM) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -1101,7 +1106,7 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None): stream = torch.cuda.current_stream() with torch.cuda.stream(stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log, allreduce_always_fp32=self.fp32_allreduce) if rank is None or rank == dist.get_rank(group=self.dp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) @@ -1214,9 +1219,13 @@ def _model_parallel_all_reduce(self, tensor, op): if self.model_parallel_group is None: pass else: - torch.distributed.all_reduce(tensor=tensor, + if self.fp32_allreduce: + tensor = tensor.float() + dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group) + if self.fp32_allreduce: + tensor = tensor.to(self.precision) def get_grad_norm_direct(self, gradients, params, norm_type=2): """Clips gradient norm of an iterable of parameters. 
@@ -1239,13 +1248,13 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.MAX, group=self.dp_process_group) # Take max across all GPUs. self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) + op=dist.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 @@ -1258,12 +1267,12 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, + dist.all_reduce(total_norm_cuda, + op=dist.ReduceOp.SUM, group=self.dp_process_group) self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) + op=dist.ReduceOp.SUM) total_norm = total_norm_cuda[0].item()**(1. / norm_type) @@ -1494,11 +1503,18 @@ def step(self, closure=None): 0, shard_id * shard_size, num_elements).detach() + if self.fp32_allreduce: + curr_shard = curr_shard.float() shard_list.append(curr_shard) dist.all_gather(shard_list, shard_list[partition_id], group=self.dp_process_group) + + if self.fp32_allreduce: + for i in range(len(shard_list)): + shard_list[i] = shard_list[i].to(self.precision) + self.stop_timers([OPTIMIZER_ALLGATHER]) # TODO: we probably don't need this? just to be safe @@ -1558,8 +1574,8 @@ def has_overflow(self, partition_gradients=True): overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial( ) overflow_gpu = torch.cuda.ByteTensor([overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, + dist.all_reduce(overflow_gpu, + op=dist.ReduceOp.MAX, group=self.dp_process_group) else: @@ -1574,7 +1590,7 @@ def has_overflow(self, partition_gradients=True): # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs self._model_parallel_all_reduce(tensor=overflow_gpu, - op=torch.distributed.ReduceOp.MAX) + op=dist.ReduceOp.MAX) overflow = overflow_gpu[0].item() return bool(overflow) @@ -1866,7 +1882,7 @@ def load_state_dict(self, def _handle_overflow(cpu_sum, x, i): import math - rank = torch.distributed.get_rank() + rank = dist.get_rank() if rank == 0: t_i = -1 for v_i, v in enumerate(x.data.contiguous().view(-1)):
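Taken together, the config.py and constants.py changes select bfloat16 through a new "type" field inside the existing "fp16" block; when it is set to "bfloat16", the loss scale defaults to 1.0 (no scaling needed) and fp32_allreduce defaults to True. A minimal sketch of the relevant config fragment, written as a Python dict; the batch-size field is a placeholder and not part of this patch:

ds_config = {
    "train_micro_batch_size_per_gpu": 1,   # placeholder, unrelated to this patch
    "fp16": {
        "enabled": True,        # existing FP16_ENABLED switch, unchanged
        "type": "bfloat16",     # new FP16_TYPE field, mapped to torch.bfloat16 via PRECISION_TYPES
    },
}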
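Several hunks (activation checkpointing, pipeline p2p send/recv, the parameter broadcast in the engine, the ZeRO collectives) apply the same workaround: the targeted NCCL builds cannot communicate torch.bfloat16 tensors, so a bf16 buffer is upcast to fp32, the collective runs on the fp32 copy, and the result is cast or copied back. A minimal sketch of that round trip for an all-reduce; the helper name is illustrative and not part of the patch:

import torch
import torch.distributed as dist

def all_reduce_maybe_fp32(tensor, group=None, fp32_comm=True):
    # Upcast to fp32 for the collective when the backend cannot handle bf16,
    # then copy the result back so callers keep their original bf16 buffer.
    if fp32_comm and tensor.dtype == torch.bfloat16:
        buf = tensor.float()
        dist.all_reduce(buf, group=group)
        tensor.copy_(buf)
    else:
        dist.all_reduce(tensor, group=group)
    return tensor

The in-place copy back is what keeps the workaround transparent to callers; the p2p.send/recv changes do the same thing around dist.broadcast, copying the fp32 broadcast buffer back into the caller's tensor on the receiving side.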
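get_full_inputs() and the ZeRO stage 1/2 step() paths apply the same idea to all_gather over a list of partitions. A minimal sketch, assuming the caller passes its list of bf16 partitions and its own rank; the function name and the in-place copy back are illustrative:

import torch
import torch.distributed as dist

def all_gather_bf16_via_fp32(partitions, rank, group=None):
    # Gather bf16 partitions through temporary fp32 buffers, then copy the
    # gathered values back into the original bf16 partitions in place.
    fp32_parts = [p.float() for p in partitions]
    dist.all_gather(fp32_parts, fp32_parts[rank], group=group)
    for dst, src in zip(partitions, fp32_parts):
        dst.copy_(src)
    return partitions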
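compressed_ar.py takes a different route: it splits each value with frexp into an fp16 mantissa and an int8 exponent (the "24-bit" all-reduce named in its header comment), all-reduces both pieces, and recombines them with ldexp; the engine keeps bf16_compressed_allreduce hard-coded to False, so this path is experimental. A self-contained sketch of the torch >= 1.9 variant, with the dispatch assignment placed after the function definitions so the module imports cleanly (the patch assigns compressed_all_reduce at the top of the file, before either implementation exists):

import torch
import torch.distributed as dist

def decompose(t):
    # torch.frexp (torch >= 1.9) splits each value into mantissa * 2**exponent.
    mantissa, exponent = torch.frexp(t.float())
    return mantissa.half(), exponent.to(torch.int8)

def reconstruct(mantissa, exponent, original_dtype=torch.bfloat16):
    # Cast up before ldexp for safety, then return to the original dtype.
    return torch.ldexp(mantissa.float(), exponent.to(torch.int32)).to(original_dtype)

def compressed_all_reduce_torch(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
    original_dtype = tensor.dtype
    m, e = decompose(tensor)
    dist.all_reduce(m, op=op, group=group, async_op=async_op)
    dist.all_reduce(e, op=op, group=group, async_op=async_op)
    return reconstruct(m, e, original_dtype)

# Choose the implementation only after it has been defined.
compressed_all_reduce = compressed_all_reduce_torch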