
Commit

Merge branch 'master' into zhejiang/reduce_host_overhead_moe
HeyangQin authored Jul 16, 2024
2 parents d7ce884 + 78c6c44 commit 686f511
Showing 24 changed files with 191 additions and 43 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/cpu-inference.yml
@@ -24,6 +24,8 @@ jobs:
unit-tests:
runs-on: [self-hosted, cpu]

env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions

steps:
- uses: actions/checkout@v3

2 changes: 1 addition & 1 deletion .github/workflows/nv-human-eval.yml
@@ -17,7 +17,7 @@ jobs:
options: --gpus all --shm-size "8G"

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Check container state
run: |
2 changes: 2 additions & 0 deletions .github/workflows/nv-lightning-v100.yml
@@ -21,6 +21,8 @@ jobs:
unit-tests:
runs-on: [self-hosted, nvidia, cu111, v100]

env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions

steps:
- uses: actions/checkout@v3

2 changes: 2 additions & 0 deletions .github/workflows/nv-torch110-p40.yml
@@ -17,6 +17,8 @@ jobs:
unit-tests:
runs-on: [self-hosted, nvidia, cu111, p40]

env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions

steps:
- uses: actions/checkout@v3

2 changes: 2 additions & 0 deletions .github/workflows/nv-torch110-v100.yml
@@ -17,6 +17,8 @@ jobs:
unit-tests:
runs-on: [self-hosted, nvidia, cu111, v100]

env: {ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true} # Allow using Node16 actions

steps:
- uses: actions/checkout@v3

4 changes: 2 additions & 2 deletions .github/workflows/python.yml
@@ -21,15 +21,15 @@ jobs:
unit-tests:
strategy:
matrix:
pyVersion: ["3.6", "3.7", "3.8", "3.9", "3.10"]
pyVersion: ["3.7", "3.8", "3.9", "3.10"]
fail-fast: false

runs-on: ubuntu-20.04
container:
image: deepspeed/gh-builder:py${{ matrix.pyVersion }}

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: environment
run: |
15 changes: 8 additions & 7 deletions accelerator/hpu_accelerator.py
@@ -3,6 +3,7 @@

# DeepSpeed Team

import functools
import os
import pkgutil
import importlib
@@ -196,31 +197,31 @@ def replay_graph(self, graph):
# Tensor operations
@property
def BFloat16Tensor(self):
return self.hpu.BFloat16Tensor
return functools.partial(torch.tensor, dtype=torch.bfloat16, device='hpu')

@property
def ByteTensor(self):
return self.hpu.ByteTensor
return functools.partial(torch.tensor, dtype=torch.uint8, device='hpu')

@property
def DoubleTensor(self):
return self.hpu.DoubleTensor
return functools.partial(torch.tensor, dtype=torch.double, device='hpu')

@property
def FloatTensor(self):
return self.hpu.FloatTensor
return functools.partial(torch.tensor, dtype=torch.float, device='hpu')

@property
def HalfTensor(self):
return self.hpu.HalfTensor
return functools.partial(torch.tensor, dtype=torch.half, device='hpu')

@property
def IntTensor(self):
return self.hpu.IntTensor
return functools.partial(torch.tensor, dtype=torch.int, device='hpu')

@property
def LongTensor(self):
return self.hpu.LongTensor
return functools.partial(torch.tensor, dtype=torch.long, device='hpu')

def pin_memory(self, tensor, align_bytes=1):
return tensor.pin_memory(self.device())
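The HPU accelerator's tensor-type accessors above now return functools.partial wrappers around torch.tensor instead of the torch.hpu tensor classes, so each property builds a tensor with the right dtype directly on the HPU device. A minimal sketch of the same pattern, run on CPU for illustration so it executes anywhere (the real properties bind device='hpu'):

import functools
import torch

# Same construction as the properties above, but targeting CPU for illustration.
FloatTensor = functools.partial(torch.tensor, dtype=torch.float, device='cpu')
LongTensor = functools.partial(torch.tensor, dtype=torch.long, device='cpu')

x = FloatTensor([1.0, 2.0, 3.0])   # torch.tensor([...], dtype=torch.float, device='cpu')
idx = LongTensor([0, 2])
print(x[idx])                      # tensor([1., 3.])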
2 changes: 2 additions & 0 deletions bin/deepspeed.bat
@@ -0,0 +1,2 @@
@echo off
python "%~dp0\ds" %*
2 changes: 2 additions & 0 deletions bin/ds_report.bat
@@ -0,0 +1,2 @@
@echo off
python "%~dp0\ds_report" %*
4 changes: 4 additions & 0 deletions blogs/deepspeed-fastgen/chinese/README.md
@@ -226,6 +226,10 @@ DeepSpeed-FastGen is [DeepSpeed-MII](https://github.com/microsoft/DeepSpeed-MII
* [LLaMA](https://huggingface.co/models?other=llama) and [LLaMA-2](https://huggingface.co/models?other=llama-2)
* [Mistral](https://huggingface.co/models?other=mistral)
* [OPT](https://huggingface.co/models?other=opt)
* [Falcon](https://huggingface.co/models?other=falcon)
* [Mixtral](https://huggingface.co/models?other=mixtral)
* [Phi-2](https://huggingface.co/models?other=phi-msft)
* [Qwen](https://huggingface.co/models?other=qwen)

All of the current models use the [HuggingFace](https://github.com/huggingface) API on the backend to supply model weights and the corresponding tokenizers.

1 change: 1 addition & 0 deletions build_win.bat
@@ -1,5 +1,6 @@
@echo off

set CUDA_HOME=%CUDA_PATH%
set DISTUTILS_USE_SDK=1

set DS_BUILD_AIO=0
99 changes: 82 additions & 17 deletions csrc/cpu/comm/shm_interface.cpp
@@ -46,15 +46,13 @@ void initialize(int size, int rank)
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }
void inference_all_reduce_(torch::Tensor& data, int op);

// Success - return 0
// Fail (cannot honor the request and need to fall back) - return -1
int inference_all_reduce(torch::Tensor& data, py::object op)
void inference_all_reduce_(torch::Tensor& data, int op)
{
if (!all_ranks_local_p) return -1;
assert(op == 0);
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
auto start = std::chrono::system_clock::now();
#endif

static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));

assert(py::int_(op.attr("value")) == ReduceOpSum);

auto numel = data.numel();

int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
default: data_type_fallback = true;
}

if (data_type_fallback) return -1;
if (data_type_fallback) return;

all_reduce_outer_loop(data, numel, data_size);

@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
}
}
#endif
return 0;
return;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }

TORCH_LIBRARY(deepspeed, m)
{
m.def("inference_all_reduce(Tensor self) -> Tensor");
m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
{
torch::Tensor result_ = torch::empty_like(self_);
return result_;
}

torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }

torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
{
TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
torch::Tensor self_tensor = self_.contiguous();
inference_all_reduce_(self_tensor, 0);
return self_;
}

torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
{
torch::Tensor result = self_.clone();
inference_all_reduce__cpu(result);
return result;
}

#include <ATen/FunctionalTensorWrapper.h>
// The boilerplate functionalization logic, that teaches functionalization
// how to map x_() calls into x() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// You will still need to write an out-of-place version of that op!
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
{
// We expect all tensor inputs to our op to be "functional tensors"
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
// First, sync and unwrap any functional tensors
at::functionalization::impl::sync(x);
auto x_ = at::functionalization::impl::from_functional_tensor(x);
// Grab the dispatcher entry corresponding to the out-of-place op, "x"
static auto op_handle = c10::Dispatcher::singleton()
// specify namespace::op_name, op_overload_name
.findSchemaOrThrow("deepspeed::inference_all_reduce", "")
// Specify the C++ schema of the out-of-place op.
.typed<at::Tensor(const at::Tensor&)>();
// Next, redispatch to the out-of-place op, x() (user called x_, we call x)
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = op_handle.call(x_);
}
// Finally, tell functionalization about this mutation.
at::functionalization::impl::replace_(x, tmp_output);
at::functionalization::impl::commit_update(x);
at::functionalization::impl::sync(x);
return x;
}

TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
{
m.impl("inference_all_reduce", inference_all_reduce_cpu);
m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
{
m.impl("inference_all_reduce", inference_all_reduce_meta);
m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
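This rewrite replaces the pybind11-exposed inference_all_reduce, which took a py::object ReduceOp and returned 0 or -1, with a torch custom op, deepspeed::inference_all_reduce_, registered with CPU, Meta, and Functionalize kernels so torch.compile can trace and functionalize it. A hedged sketch of what the Meta registration enables, assuming the extension above has been built and loaded so the ops exist under torch.ops.deepspeed:

import torch

# Shape/dtype propagation on meta tensors dispatches to inference_all_reduce_meta above,
# without touching real data and without any ranks being initialized.
meta_in = torch.empty(1024, dtype=torch.bfloat16, device='meta')
meta_out = torch.ops.deepspeed.inference_all_reduce(meta_in)
print(meta_out.shape, meta_out.dtype, meta_out.device)  # torch.Size([1024]) torch.bfloat16 meta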
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/ds_to_universal.py
@@ -69,7 +69,7 @@ def parse_arguments():
dest='strict',
action='store_false',
help='Do not perform validity checks on converted checkpoint.')
parser.add_argument('--inject-missing-state',
parser.add_argument('--inject_missing_state',
action='store_true',
help='Inject missing checkpoint state into the checkpoint if it is absent.')
args = parser.parse_args()
5 changes: 3 additions & 2 deletions deepspeed/comm/torch.py
@@ -151,11 +151,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op, group=None):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
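With the op registered, inference_all_reduce no longer calls the pybind wrapper and inspects its return code; it probes torch.ops.deepspeed for the in-place op and falls back to torch.distributed.all_reduce otherwise. A standalone sketch of that dispatch, assuming a SUM reduction; the helper name sum_all_reduce is illustrative, and the fallback path needs an initialized process group:

import torch
import torch.distributed as dist

def sum_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
    # Prefer the SHM custom op when the DeepSpeed extension has registered it.
    if hasattr(torch.ops, 'deepspeed') and hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
        return torch.ops.deepspeed.inference_all_reduce_(tensor)
    # Otherwise use the framework collective (requires dist.init_process_group()).
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor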
2 changes: 1 addition & 1 deletion deepspeed/env_report.py
@@ -110,7 +110,7 @@ def installed_cann_version():
def get_shm_size():
try:
shm_stats = os.statvfs('/dev/shm')
except (OSError, FileNotFoundError, ValueError):
except (OSError, FileNotFoundError, ValueError, AttributeError):
return "UNKNOWN", None

shm_size = shm_stats.f_frsize * shm_stats.f_blocks
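os.statvfs does not exist on Windows, so probing /dev/shm there raises AttributeError rather than OSError; adding it to the except clause lets the report fall back to "UNKNOWN" instead of crashing. A minimal sketch of the same guard (the helper name is illustrative):

import os

def shm_size_bytes():
    try:
        st = os.statvfs('/dev/shm')  # AttributeError on Windows, OSError if /dev/shm is absent
    except (OSError, FileNotFoundError, ValueError, AttributeError):
        return None
    return st.f_frsize * st.f_blocks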
22 changes: 22 additions & 0 deletions deepspeed/inference/engine.py
@@ -6,13 +6,15 @@
import torch
import time
import os
import deepspeed
from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

from torch.nn.modules import Module
from packaging import version as pkg_version
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
@@ -185,6 +187,7 @@ def __init__(self, model, config):

# Check if local CUDA graphs can be created in replacement modules
self.local_cuda_graph = self._local_cuda_graph_used(self.module)
self._is_compiled = False

def destroy(self):
# Have to import here because inference_module is a global, but python
@@ -634,3 +637,22 @@ def _generate(self, *inputs, **kwargs):
)

return self.module.generate(*inputs, **kwargs)

def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
"""
Compile the module using the specified backend and kwargs.
"""
if not is_compile_supported():
raise RuntimeError("compile is not supported in your version of PyTorch.")

if self._is_compiled:
return

# Avoid graph breaks
deepspeed.utils.nvtx.enable_nvtx = False
self.module.compile(backend=backend, **compile_kwargs)
self._is_compiled = True

@property
def is_compiled(self) -> bool:
return self._is_compiled
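The new compile() hook calls self.module.compile() with the accelerator's default backend, disables NVTX ranges to avoid graph breaks, and records the result in is_compiled. A hedged usage sketch; the toy module and dtype are placeholders, and it assumes a CUDA build with a compile-capable PyTorch:

import torch
import deepspeed

model = torch.nn.Linear(8, 8).half().cuda()       # toy stand-in for a real model
engine = deepspeed.init_inference(model, dtype=torch.half)
engine.compile()                                   # backend defaults to get_accelerator().get_compile_backend()
assert engine.is_compiled
y = engine(torch.randn(4, 8, dtype=torch.half, device='cuda'))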