Merge pull request deepspeedai#17 from rraminen/IFU_5_27
IFU-master-2021-05-27
jithunnair-amd authored Jun 4, 2021
2 parents 4c7a252 + 5de081e commit 1850f88
Showing 97 changed files with 58,512 additions and 924 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/torch16.yml
File mode changed: 100644 → 100755
@@ -4,9 +4,6 @@ name: Torch16

# Controls when the action will run.
on:
-#pull_request:
-# paths-ignore:
-# - 'docs/**'
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

@@ -37,6 +34,11 @@ jobs:
run: |
pip install .[dev]
ds_report
+      - name: Formatting checks
+        run: |
+          pre-commit run --all-files
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
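The new `Formatting checks` step assumes a `.pre-commit-config.yaml` at the repository root, which is not part of this diff. A minimal sketch of the kind of config such a step would pick up (the hooks below are illustrative assumptions, not DeepSpeed's actual set):

```yaml
# .pre-commit-config.yaml -- illustrative sketch only; DeepSpeed's real
# config is not shown in this diff.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
```

`pre-commit run --all-files` runs every configured hook over the whole tree rather than only over staged files, which is what a CI check wants.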
2 changes: 1 addition & 1 deletion DeepSpeedExamples
9 changes: 5 additions & 4 deletions README.md
@@ -33,10 +33,11 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
* [2021/05/24] [DeepSpeed: Accelerating large-scale model inference and training via system optimizations and compression](https://www.microsoft.com/en-us/research/blog/deepspeed-accelerating-large-scale-model-inference-and-training-via-system-optimizations-and-compression/)
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
  * [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/)
-* [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed)
+* [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/python-sdk/workflows/train/deepspeed)
* [2021/03/30] [[PyTorch Lightning Blog] Accessible Multi-Billion Parameter Model Training with PyTorch Lightning + DeepSpeed](https://medium.com/pytorch-lightning/accessible-multi-billion-parameter-model-training-with-pytorch-lightning-deepspeed-c9333ac3bb59)
* [2021/03/16] [1-bit Adam v2: NCCL-based implementation and more](https://www.deepspeed.ai/tutorials/onebit-adam/)
* [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)
@@ -153,14 +154,14 @@ All DeepSpeed documentation can be found on our website: [deepspeed.ai](https://www.deepspeed.ai/)
| Article | Description |
| ---------------------------------------------------------------------------------------------- | -------------------------------------------- |
| [DeepSpeed Features](https://www.deepspeed.ai/features/) | DeepSpeed features |
| [Getting Started](https://www.deepspeed.ai/getting-started/) | First steps with DeepSpeed |
| [DeepSpeed JSON Configuration](https://www.deepspeed.ai/docs/config-json/) | Configuring DeepSpeed |
| [API Documentation](https://deepspeed.readthedocs.io/en/latest/) | Generated DeepSpeed API documentation |
| [CIFAR-10 Tutorial](https://www.deepspeed.ai/tutorials/cifar-10) | Getting started with CIFAR-10 and DeepSpeed |
| [Megatron-LM Tutorial](https://www.deepspeed.ai/tutorials/megatron/) | Train GPT2 with DeepSpeed and Megatron-LM |
| [BERT Pre-training Tutorial](https://www.deepspeed.ai/tutorials/bert-pretraining/) | Pre-train BERT with DeepSpeed |
| [Learning Rate Range Test Tutorial](https://www.deepspeed.ai/tutorials/lrrt/) | Faster training with large learning rates |
-| [1Cycle Tutorial](https://www.deepspeed.ai/tutorials/1Cycle/) | SOTA learning schedule in DeepSpeed |
+| [1Cycle Tutorial](https://www.deepspeed.ai/tutorials/one-cycle/) | SOTA learning schedule in DeepSpeed |



26 changes: 26 additions & 0 deletions csrc/includes/custom_cuda_layers.h
File mode changed: 100755 → 100644
@@ -28,6 +28,32 @@

#define MAX_REGISTERS 256

#define MAX_REG 256

template <typename T>
void launch_qunatize_kernel(T* vals,
                            int total_count,
                            int group_num,
                            int num_bits,
                            cudaStream_t stream);
template <typename T>
void launch_sr_qunatize_kernel(T* vals,
                               int total_count,
                               int group_num,
                               int num_bits,
                               cudaStream_t stream);
template <typename T>
void launch_qunatize_kernel_asym(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);
template <typename T>
void launch_sr_qunatize_kernel_asym(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
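The four launchers added above cover symmetric (`launch_qunatize_kernel`) and asymmetric (`..._asym`) group-wise quantization, each with a stochastic-rounding (`sr`) variant. The kernel definitions live in the quantization `.cu` files and are not part of this diff; as a rough CPU reference for what the symmetric path plausibly computes (an assumption about the kernels' semantics, for orientation only):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// CPU sketch of symmetric group-wise fake-quantization: each group is scaled
// onto the signed num_bits grid by its max |value|, rounded to the nearest
// step, and rescaled in place. This is an assumption about what the CUDA
// kernels compute; their definitions are not shown in this diff.
// Assumes vals.size() is divisible by group_num.
void quantize_symmetric_ref(std::vector<float>& vals, int group_num, int num_bits)
{
    const int group_size = static_cast<int>(vals.size()) / group_num;
    const float q_max = static_cast<float>((1 << (num_bits - 1)) - 1);  // 127 for 8 bits
    for (int g = 0; g < group_num; ++g) {
        float max_abs = 0.f;
        for (int i = 0; i < group_size; ++i)
            max_abs = std::fmax(max_abs, std::fabs(vals[g * group_size + i]));
        const float scale = (max_abs > 0.f) ? max_abs / q_max : 1.f;
        for (int i = 0; i < group_size; ++i) {
            float& v = vals[g * group_size + i];
            v = std::nearbyint(v / scale) * scale;  // quantize, then dequantize
        }
    }
}

int main()
{
    std::vector<float> vals = {0.1f, -0.5f, 0.9f, -1.0f};
    quantize_symmetric_ref(vals, /*group_num=*/1, /*num_bits=*/8);
    for (float v : vals) std::printf("%f\n", v);
}
```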
9 changes: 9 additions & 0 deletions csrc/includes/quantizer.h
@@ -0,0 +1,9 @@
#pragma once

#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
6 changes: 4 additions & 2 deletions csrc/lamb/fused_lamb_cuda_kernel.cu
@@ -34,8 +34,10 @@ struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
+#ifndef _WIN32
        extern __device__ void error(void);
        error();
+#endif
        return NULL;
    }
};
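Unspecialized instantiations of this conversion operator deliberately reference an undefined device function, so using `SharedMemory<T>` with an unsupported type fails at link time; concrete element types are supported via explicit specializations elsewhere in the file. For reference, the usual shape of such a specialization (a sketch; the actual specializations are outside this hunk):

```cpp
// Sketch of the per-type specialization the primary template expects; the
// real specializations in fused_lamb_cuda_kernel.cu are not shown here.
template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];  // dynamic shared memory, sized at kernel launch
        return s_float;
    }
};
```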
@@ -325,13 +327,13 @@ __global__ void lamb_cuda_kernel_part3(

    float lamb_coeff = 1.0;

-   if (reg_w != 0 and reg_u != 0) {
+   if (reg_w != 0 && reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

-   if (blockId == 0 and threadIdInBlock == 0) {
+   if (blockId == 0 && threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
    }
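This hunk replaces the alternative token `and` with `&&`; `and` is valid ISO C++, but some toolchains reject it (notably MSVC without `/permissive-` or `<ciso646>`), which matters alongside the new `_WIN32` guard above. The logic itself is the clamped LAMB trust ratio; a plain host-side restatement (illustrative sketch, not the kernel code):

```cpp
#include <algorithm>

// Host-side restatement of the clamped LAMB trust ratio computed above:
// the ratio of the weight norm to the update norm, clamped to
// [min_coeff, max_coeff], defaulting to 1 when either norm is zero.
float lamb_trust_ratio(float w_norm, float u_norm, float min_coeff, float max_coeff)
{
    float coeff = 1.0f;
    if (w_norm != 0.0f && u_norm != 0.0f) {
        coeff = std::min(std::max(w_norm / u_norm, min_coeff), max_coeff);
    }
    return coeff;
}
```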
77 changes: 77 additions & 0 deletions csrc/quantization/pt_binding.cpp
@@ -0,0 +1,77 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"

template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_qunatize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_qunatize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_qunatize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_qunatize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");
}
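Each wrapper quantizes `vals` in place and returns it, but silently skips the kernel launch when its size guard fails: the non-`sr` paths require each group to fit in at most `MAX_REG` (256) chunks of 4096 elements, while the `sr` paths compare a different chunking against a literal 256. A standalone check of the non-`sr` guard arithmetic (sketch with made-up sizes):

```cpp
#include <cstdio>

// Standalone check of the ds_quantize launch guard above: the kernel is only
// launched when ceil(elements_per_group / 4096) <= MAX_REG (256).
int main()
{
    const int MAX_REG = 256;
    const int size = 8 * 1024 * 1024;  // total elements across all groups (illustrative)
    const int groups = 8;              // quantization groups (illustrative)
    const int chunks = ((size / groups) - 1) / 4096 + 1;  // ceiling division
    std::printf("chunks per group = %d -> %s\n",
                chunks,
                chunks <= MAX_REG ? "kernel launched" : "tensor returned unquantized");
}
```

With these numbers, `chunks` is exactly 256, so the guard just passes and the kernel is launched.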