forked from deepspeedai/DeepSpeed
Commit
* bf16 initial commit
* Update engine.py
* update split_half_float_double_csr dtypes
* update to bf16 communication (make flag optional)
* Update requirements-sparse_attn.txt
* add compressed bf16 allreduce
* add compressed bf16 allreduce
* Update __init__.py
* Update engine.py
* Update __init__.py
* Update engine.py
* zero1 + bf16
* zero 2 + bf16
* pipe parallel + bf16
* pipe parallel + bf16
* partition activations + bf16
Showing 10 changed files with 268 additions and 106 deletions.
@@ -0,0 +1 @@
from .compressed_ar import compressed_all_reduce
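The new compressed_ar module below implements a 24-bit all-reduce: each value is upcast to fp32 and split with frexp into a mantissa kept as fp16 and an int8 exponent, the two pieces are all-reduced separately, and ldexp recombines them. A minimal round-trip sketch of that decomposition (standalone, not part of the commit; sample values are illustrative and it assumes torch >= 1.9):

import torch

x = torch.tensor([6.0, -0.375])
m, e = torch.frexp(x)                 # 6.0 = 0.75 * 2**3, -0.375 = -0.75 * 2**-1
m, e = m.half(), e.to(torch.int8)     # 16-bit mantissa + 8-bit exponent = 24 bits per element
y = torch.ldexp(m.float(), e)         # recombine: mantissa * 2**exponent
assert torch.equal(x, y)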
@@ -0,0 +1,54 @@
# python -m torch.distributed.launch --nproc_per_node=1 24_bit_allreduce.py

import torch
import os
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack

# Parse the installed torch version; torch.frexp needs torch >= 1.9,
# so older installs fall back to the cupy-based decomposition.
version = torch.__version__.split('.')
TORCH_VERSION_MAJOR = int(version[0])
TORCH_VERSION_MINOR = int(version[1])


def torch2cupy(tensor):
    # Zero-copy view of a CUDA torch tensor as a cupy array via DLPack.
    return cupy.fromDlpack(to_dlpack(tensor))


def cupy2torch(cupy_tensor):
    # Zero-copy view of a cupy array as a torch tensor via DLPack.
    return from_dlpack(cupy_tensor.toDlpack())


def decompose_cupy(tensor):
    # Split fp32 values into an fp16 mantissa and an int8 exponent using cupy.
    mantissa, exponent = cupy.frexp(torch2cupy(tensor.float()))
    return cupy2torch(mantissa).half(), cupy2torch(exponent).to(torch.int8)


def decompose(t):
    if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR == 1 and TORCH_VERSION_MINOR < 9):
        raise Exception('Torch version >= 1.9.0 needed for 24_bit_allreduce.decompose')
    mantissa, exponent = torch.frexp(t.float())
    return mantissa.half(), exponent.to(torch.int8)


def reconstruct(mantissa, exponent, original_dtype=torch.bfloat16):
    # Recombine mantissa * 2**exponent and cast back to the original dtype.
    return torch.ldexp(mantissa, exponent).to(original_dtype)


def compressed_all_reduce_torch(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    original_dtype = tensor.dtype
    m, e = decompose(tensor)
    # All-reduce the 16-bit mantissas and 8-bit exponents separately
    # (24 bits per element on the wire instead of 32).
    torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    return reconstruct(m, e, original_dtype)


def compressed_all_reduce_cupy(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    original_dtype = tensor.dtype
    m, e = decompose_cupy(tensor)
    torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    return reconstruct(m, e, original_dtype)


# Pick the implementation after both functions are defined: torch.frexp is only
# available in torch >= 1.9, so older versions use the cupy path.
if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR == 1 and TORCH_VERSION_MINOR < 9):
    compressed_all_reduce = compressed_all_reduce_cupy
else:
    compressed_all_reduce = compressed_all_reduce_torch
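A hypothetical usage sketch (not part of this commit): the import path and process-group setup below are assumptions, shown only to illustrate how the exported compressed_all_reduce would be called under torch.distributed.

# Launch with e.g.: python -m torch.distributed.launch --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist
from compressed_ar import compressed_all_reduce  # illustrative import; the commit re-exports it from the package __init__

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

grads = torch.randn(1024, dtype=torch.bfloat16, device='cuda')
# Communicates 24 bits per element (fp16 mantissa + int8 exponent) instead of 32-bit fp32.
reduced = compressed_all_reduce(grads)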