Merge pull request deepspeedai#20 from microsoft/samyamr/deepspeed_checkpointing_configs

Adding support for deepspeed_checkpointing through deepspeed.checkpoi…
samyam authored May 8, 2020
2 parents 8620220 + f4a4070 commit aeec738
Showing 5 changed files with 199 additions and 112 deletions.
123 changes: 71 additions & 52 deletions deepspeed/pt/deepspeed_checkpointing.py
@@ -4,7 +4,7 @@
Use to partition the activations stored for backward propagation
Therefore reduces the memory consumption
Also implements CPU checkpointing and contigious memory checkpointing
Also implements CPU checkpointing and contiguous memory checkpointing
Reduces memory consumption and memory fragmentation
Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py
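
For intuition, here is a minimal single-process sketch of the partition-and-gather idea described in the docstring above. It is illustrative only and not part of this diff; the module itself shards each saved activation across the model-parallel group and reassembles it with torch.distributed.all_gather during backward.

import torch

mp_size = 4                                    # hypothetical model-parallel world size
activation = torch.randn(8, 16)                # an activation saved for backward

# Forward: keep only a 1/mp_size slice of the flattened activation per rank.
flat = activation.flatten()
partitions = list(torch.chunk(flat, mp_size))  # rank i would store partitions[i]

# Backward: gather the slices and restore the original shape before recomputation.
restored = torch.cat(partitions).view_as(activation)
assert torch.equal(restored, activation)
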
@@ -21,7 +21,11 @@
from deepspeed.pt.deepspeed_timer import SynchronizedWallClockTimer as Timers
import torch.distributed as dist

#DeepSpeed Checkpointing Enabled or Disabled
deepspeed_checkpointing_enabled = False

#MP parameters
mpu = None
mp_rank = None
mp_size = None
mp_group = None
@@ -30,18 +34,18 @@
num_layers = None

#Checkpointing buffers
contigious_data_buffers = []
contiguous_data_buffers = []
data_offsets = []

contigious_size_buffers = []
contiguous_size_buffers = []
size_offsets = []

timers = None

#optimization flags
PARTITION_ACTIVATIONS = False
PA_TO_CPU = False
CONTIGIOUS_CHECKPOINTING = False
CONTIGUOUS_CHECKPOINTING = False
SYNCHRONIZE = False
PROFILE_TIME = False

@@ -291,7 +295,8 @@ def get_full_inputs(tensors, device=None):
if i == mp_rank:
part_i.copy_(item)
partitions.append(part_i)
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
if mp_group is not None:
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
input_tensor = flat_tensor.view(list(size.numpy()))
item.data = input_tensor.data

@@ -324,13 +329,17 @@ def forward(ctx, run_function, *args):
ctx.run_function = run_function
global num_layers
global mp_rank, mp_size, mp_group
global contigious_data_buffers, contigious_size_buffers
global contiguous_data_buffers, contiguous_size_buffers
global data_offsets, size_offsets
if mp_rank is None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()

if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
mp_group = None

global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

@@ -342,7 +351,7 @@ def forward(ctx, run_function, *args):
f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
)
print(
f"----Contigious Memory Checkpointing {CONTIGIOUS_CHECKPOINTING} with {num_layers} total layers"
f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
)
print(f"----Synchronization {SYNCHRONIZE}")
print(f"----Profiling {PROFILE_TIME}")
@@ -362,33 +371,33 @@ def forward(ctx, run_function, *args):
get_partition_start(item),
partition_size).clone()

if CONTIGIOUS_CHECKPOINTING:
if CONTIGUOUS_CHECKPOINTING:
buffer_device = torch.device(
'cpu') if PA_TO_CPU else partition.device

if i >= len(contigious_data_buffers):
if i >= len(contiguous_data_buffers):
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contigious_data_buffers.append(tensor_list)
contiguous_data_buffers.append(tensor_list)
data_offsets.append(0)
elif contigious_data_buffers[i] is None:
elif contiguous_data_buffers[i] is None:
tensor_list = [
torch.tensor(()).new_empty([partition_size],
dtype=partition.dtype,
device=buffer_device)
for i in range(num_layers)
]
contigious_data_buffers[i] = tensor_list
contiguous_data_buffers[i] = tensor_list
data_offsets[i] = 0

contigious_partition = contigious_data_buffers[i][
contiguous_partition = contiguous_data_buffers[i][
data_offsets[i]].data.copy_(partition.data)
data_offsets[i] = data_offsets[i] + 1
inputs.append(contigious_partition)
inputs.append(contiguous_partition)
else:
partition = partition.cpu() if PA_TO_CPU else partition
inputs.append(partition)
@@ -427,33 +436,33 @@ def forward(ctx, run_function, *args):
arg.data = inp.data
new_args.append(arg)

if CONTIGIOUS_CHECKPOINTING:
if CONTIGUOUS_CHECKPOINTING:
numel = size.numel()
if i >= len(contigious_size_buffers):
if i >= len(contiguous_size_buffers):
tmp = torch.tensor(())
contigious_size_buffers.append(
contiguous_size_buffers.append(
tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device))
size_offsets.append(0)
elif contigious_size_buffers[i] is None:
elif contiguous_size_buffers[i] is None:
tmp = torch.tensor(())
contigious_size_buffers[i] = tmp.new_empty([numel * num_layers],
contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
dtype=size.dtype,
device=size.device)
size_offsets[i] = 0

contigious_size = contigious_size_buffers[i].narrow(
contiguous_size = contiguous_size_buffers[i].narrow(
0,
size_offsets[i],
numel).data.copy_(size.data)
contigious_size = contigious_size.view_as(size)
contiguous_size = contiguous_size.view_as(size)
size_offsets[i] = size_offsets[i] + numel
new_args.append(contigious_size)
new_args.append(contiguous_size)
else:
new_args.append(size)
#if dist.get_rank() == 0:
# print (f"The stored tensor is {contigious_size} and orginal one is {size} ")
# print (f"The stored tensor is {contiguous_size} and orginal one is {size} ")

ctx.save_for_backward(*new_args)
else:
@@ -469,25 +478,25 @@
def backward(ctx, *args):
global timers
#see_memory_usage("In backward", force=True)
#removing pointers to the contigious buffer memory
#removing pointers to the contiguous buffer memory
#so that they can be garbage collected once the checkpoints
#have been used
if SYNCHRONIZE:
torch.cuda.synchronize()
if PROFILE_TIME:
timers('backward').start()

if CONTIGIOUS_CHECKPOINTING:
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contigious_data_buffers, contigious_size_buffers
global contiguous_data_buffers, contiguous_size_buffers

for buffers in contigious_data_buffers:
for buffers in contiguous_data_buffers:
buffers = []

#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contigious_data_buffers = []
contigious_size_buffers = []
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []

@@ -560,43 +569,53 @@ def set_num_layers(nlayers):


def reset():
if CONTIGIOUS_CHECKPOINTING:
if CONTIGUOUS_CHECKPOINTING:
global data_offsets, size_offsets
global contigious_data_buffers, contigious_size_buffers
global contiguous_data_buffers, contiguous_size_buffers

for buffers in contigious_data_buffers:
for buffers in contiguous_data_buffers:
buffers = []

#frees up all the pointers to the checkpoints except for the ones
#stored by save for backward
contigious_data_buffers = []
contigious_size_buffers = []
contiguous_data_buffers = []
contiguous_size_buffers = []
data_offsets = []
size_offsets = []


def configure(mpu_,
partition_activations=False,
contigious_checkpointing=False,
nlayers=None,
checkpoint_in_cpu=False,
synchronize=False,
profile_backward=False):
def configure(
mpu_,
enabled=False,
partition_activations=False,
contiguous_checkpointing=False,
nlayers=None,
checkpoint_in_cpu=False,
synchronize=False,
profile_backward=False,
):

global mpu, num_layers
global mpu, num_layers, deepspeed_checkpointing_enabled

global PARTITION_ACTIVATIONS, CONTIGIOUS_CHECKPOINTING, \
global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME

deepspeed_checkpointing_enabled = enabled

num_layers = nlayers
if checkpoint_in_cpu:
assert partition_activations, "CPU Checkpointing is only availble with partitioned activations"
if contigious_checkpointing:
assert num_layers is not None, "Must specify the number of layers with contigious memory checkpointing"
if checkpoint_in_cpu or contiguous_checkpointing:
assert partition_activations, "CPU Checkpointing/Contiguous Checkpointing is only availble with partitioned activations. Set partitioned activations to true in deepspeed config"
if contiguous_checkpointing:
assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing"

mpu = mpu_
PARTITION_ACTIVATIONS = partition_activations
CONTIGIOUS_CHECKPOINTING = contigious_checkpointing
CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing
PA_TO_CPU = checkpoint_in_cpu
SYNCHRONIZE = synchronize
PROFILE_TIME = profile_backward


def is_configured():
global deepspeed_checkpointing_enabled
return deepspeed_checkpointing_enabled
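
For orientation, a minimal usage sketch of the configure()/is_configured() API as it stands after this change. The import path follows this file's location; passing None for mpu_ and the specific flag values are illustrative assumptions, not part of the diff.

from deepspeed.pt import deepspeed_checkpointing as ds_checkpointing

ds_checkpointing.configure(
    None,                       # mpu_: no model-parallel unit, so mp_rank/mp_size default to 0/1
    enabled=True,
    partition_activations=True,
    contiguous_checkpointing=True,
    nlayers=24,                 # required when contiguous_checkpointing is True
    checkpoint_in_cpu=True,     # only valid together with partition_activations
    synchronize=False,
    profile_backward=False,
)

assert ds_checkpointing.is_configured()
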
110 changes: 110 additions & 0 deletions deepspeed/pt/deepspeed_checkpointing_config.py
@@ -0,0 +1,110 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""

from deepspeed.pt.deepspeed_config_utils import get_scalar_param

#########################################
# DeepSpeed Activation Checkpointing
#########################################
# Activation Checkpointing Allows to save memory by only keeping a select few
#activations for the backpropagation.
ACTIVATION_CHKPT_FORMAT = '''
Activation Checkpointing should be configured as:
"session_params": {
"activation_checkpointing": {
"partitioned_activations": [true|false],
"number_checkpoints": 100,
"contigious_memory_optimization": [true|false],
"cpu_checkpointing": [true|false]
"profile_backward": [true|false],
"synchronize_checkpoint_boundary": [true|false],
}
}
'''

ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False

ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None

ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False

ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False

ACT_CHKPT_PROFILE_BACKWARD = 'profile_backward'
ACT_CHKPT_PROFILE_BACKWARD_DEFAULT = False

ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False

ACT_CHKPT = 'activation_checkpointing'

ACT_CHKPT_DEFAULT = {
ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION:
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY:
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
ACT_CHKPT_PROFILE_BACKWARD: ACT_CHKPT_PROFILE_BACKWARD_DEFAULT,
ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
}


class DeepSpeedActivationCheckpointingConfig(object):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()

self.partition_activations = None
self.contiguous_memory_optimization = None
self.cpu_checkpointing = None
self.number_checkpoints = None
self.synchronize_checkpoint_boundary = None
self.profile_backward = None

if ACT_CHKPT in param_dict.keys():
act_chkpt_config_dict = param_dict[ACT_CHKPT]
else:
act_chkpt_config_dict = ACT_CHKPT_DEFAULT

self._initialize(act_chkpt_config_dict)

"""
For json serialization
"""

def repr(self):
return self.__dict__

def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_PARTITION_ACTIVATIONS,
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)

self.contiguous_memory_optimization = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)

self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_CPU_CHECKPOINTING,
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)

self.number_checkpoints = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_NUMBER_CHECKPOINTS,
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)

self.profile_backward = get_scalar_param(act_chkpt_config_dict,
ACT_CHKPT_PROFILE_BACKWARD,
ACT_CHKPT_PROFILE_BACKWARD_DEFAULT)

self.synchronize_checkpoint_boundary = get_scalar_param(
act_chkpt_config_dict,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)
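
To show how the new config class is consumed, here is a hedged sketch of parsing an activation_checkpointing block. The key names follow the ACT_CHKPT_* constants above (the ACTIVATION_CHKPT_FORMAT docstring spells a couple of them slightly differently); the concrete values are illustrative.

from deepspeed.pt.deepspeed_checkpointing_config import DeepSpeedActivationCheckpointingConfig

param_dict = {
    "activation_checkpointing": {
        "partition_activations": True,
        "contiguous_memory_optimization": True,
        "cpu_checkpointing": True,
        "number_checkpoints": 100,
        "synchronize_checkpoint_boundary": False,
        "profile_backward": False,
    }
}

config = DeepSpeedActivationCheckpointingConfig(param_dict)
print(config.repr())   # dict of the parsed fields, e.g. {'partition_activations': True, ...}
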
