[CPU] add fp16 support to shm inference_all_reduce (#5669)
This PR adds FP16 support to DeepSpeed SHM inference_all_reduce.
Previously, only FP32 and BF16 were supported. This aligns the SHM path
with PyTorch's CPU support for the FP16 datatype.

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
3 people authored Jun 24, 2024
1 parent d89e8cd commit e16de6d
Showing 3 changed files with 72 additions and 70 deletions.
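
As an illustration only (not part of this commit), the sketch below shows how the new FP16 path might be exercised from Python. It assumes a CPU build of DeepSpeed whose inference_all_reduce dispatches to the SHM backend and a multi-process launch (for example via the deepspeed or torchrun launchers); the tensor shape and values mirror the unit test further down.

# Hypothetical usage sketch: all-reduce an FP16 tensor on CPU via the SHM path.
import torch
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()
# Each rank contributes (rank + 1); the reduced result is the sum over all ranks.
x = torch.ones(1, 3, dtype=torch.float16) * (dist.get_rank() + 1)
dist.inference_all_reduce(x)  # FP16 is now handled alongside FP32 and BF16
print(f"rank {dist.get_rank()}: {x}")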
128 changes: 60 additions & 68 deletions csrc/cpu/comm/shm.cpp
@@ -143,13 +143,19 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src)
return _mm512_cvtusepi32_epi16(t_value);
}

void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
__attribute__((target("avx512bw")));
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); }

inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src)
{
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));

void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));

void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
@@ -164,26 +170,13 @@ void reduce_all_buffers(int start_elements,
{
switch (scalar_type) {
case c10::ScalarType::BFloat16:
if (world_size == 2) {
// add the other buffer to to_buffer
reduce_2_bf16_buffers_iio(num_elements,
buffers[1 - to_buffer_idx] + start_elements * 2,
to_buffer + start_elements * 2,
to_buffer + start_elements * 2);
} else {
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
}
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Half:
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Float:
if (world_size == 2) {
reduce_2_fp32_buffers_iio(num_elements,
buffers[1 - to_buffer_idx] + start_elements * 4,
to_buffer + start_elements * 4,
to_buffer + start_elements * 4);
} else {
assert(world_size > 2);
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
}
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
break;
default: assert(!"Should not get here");
}
@@ -197,8 +190,8 @@ void reduce_all_buffers(int start_elements,

// Reduce functions down below use vectorized algorithm, the number of bytes processed each
// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
// If you change implementation of reduce_2_bf16_buffers_iio or reduce_2_fp32_buffers_iio, check
// whether this number needs to be changed
// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
// to be changed
#define VECTOR_LENGTH_IN_BYTES 32

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
@@ -227,10 +220,9 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
case 6: CVT_ADD_BF16(5);
case 5: CVT_ADD_BF16(4);
case 4: CVT_ADD_BF16(3);
case 3:
CVT_ADD_BF16(2);
CVT_ADD_BF16(1);
break;
case 3: CVT_ADD_BF16(2);
case 2: CVT_ADD_BF16(1);
case 1: break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
@@ -251,7 +243,13 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
}
}

void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
#define CVT_ADD_FP16(x) \
do { \
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)

void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
{
const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
@@ -260,19 +258,41 @@ void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out

// process aligned part
#pragma omp parallel for
for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
auto in0_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in0 + i)));
auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i)));
auto out_val = _mm512_add_ps(in0_val, in1_val);
_mm256_storeu_si256((__m256i*)((char*)out + i), cvt_fp32_to_bf16(out_val));
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (world_size) {
case 16: CVT_ADD_FP16(15);
case 15: CVT_ADD_FP16(14);
case 14: CVT_ADD_FP16(13);
case 13: CVT_ADD_FP16(12);
case 12: CVT_ADD_FP16(11);
case 11: CVT_ADD_FP16(10);
case 10: CVT_ADD_FP16(9);
case 9: CVT_ADD_FP16(8);
case 8: CVT_ADD_FP16(7);
case 7: CVT_ADD_FP16(6);
case 6: CVT_ADD_FP16(5);
case 5: CVT_ADD_FP16(4);
case 4: CVT_ADD_FP16(3);
case 3: CVT_ADD_FP16(2);
case 2: CVT_ADD_FP16(1);
case 1: break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
}
}
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
}

// process remaining part
int i = main_elements * element_size;
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float in0_val = *((at::BFloat16*)((char*)in0 + i));
float in1_val = *((at::BFloat16*)((char*)in1 + i));
*((at::BFloat16*)((char*)out + i)) = in0_val + in1_val;
float val = 0.0f;
for (int j = 0; j < world_size; j++) { val += *(at::Half*)(buffers[j] + i); }
*(at::Half*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
@@ -310,10 +330,9 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
case 6: CVT_ADD_F32(5);
case 5: CVT_ADD_F32(4);
case 4: CVT_ADD_F32(3);
case 3:
CVT_ADD_F32(2);
CVT_ADD_F32(1);
break;
case 3: CVT_ADD_F32(2);
case 2: CVT_ADD_F32(1);
case 1: break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
@@ -334,33 +353,6 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
}
}

void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
{
const int element_size = 4;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;

// process aligned part
#pragma omp parallel for
for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
auto in0_val = _mm256_loadu_ps((float*)((char*)in0 + i));
auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i));
auto out_val = _mm256_add_ps(in0_val, in1_val);
_mm256_storeu_ps((float*)((char*)out + i), out_val);
}

// process remaining part
int i = main_elements * element_size;
while (remain_elements > 0) {
float in0_val = *((float*)((char*)in0 + i));
float in1_val = *((float*)((char*)in1 + i));
*((float*)((char*)out + i)) = in0_val + in1_val;
remain_elements--;
i += element_size;
}
}

static bool is_initialized = 0;
static int world_rank;

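To make the numeric behaviour of the new reduce_fp16_buffers kernel easier to follow, here is a scalar Python model (an illustration, not code from this commit): each FP16 element is widened to FP32, accumulated across all ranks' buffers in FP32 (the role of cvt_fp16_to_fp32 and _mm512_add_ps), and rounded back to FP16 once at the end (the role of cvt_fp32_to_fp16 with round-to-nearest).

import numpy as np

def reduce_fp16_buffers_model(buffers):
    # buffers: one equally sized np.float16 array per rank
    acc = buffers[0].astype(np.float32)
    for buf in buffers[1:]:
        acc += buf.astype(np.float32)   # accumulate in FP32, like the AVX512 kernel
    return acc.astype(np.float16)       # a single rounding step on the way out

buffers = [np.full(8, rank + 1, dtype=np.float16) for rank in range(4)]
print(reduce_fp16_buffers_model(buffers))  # [10. 10. 10. 10. 10. 10. 10. 10.]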
1 change: 1 addition & 0 deletions csrc/cpu/comm/shm_interface.cpp
@@ -79,6 +79,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)

switch (data.scalar_type()) {
case c10::ScalarType::BFloat16: data_size = numel * 2; break;
case c10::ScalarType::Half: data_size = numel * 2; break;
case c10::ScalarType::Float: data_size = numel * 4; break;
default: data_type_fallback = true;
}
13 changes: 11 additions & 2 deletions tests/unit/comm/test_dist.py
@@ -127,13 +127,22 @@ def test(self):
assert torch.all(x == result)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
class TestDistInferenceAllReduce(DistributedTest):
world_size = 4
device_count = get_accelerator().device_count()
if device_count >= 4:
world_size = [1, 2, 4]
elif device_count >= 2:
world_size = [1, 2]
else:
world_size = [1]

def test(self):
def test(self, dtype):
x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
result = result.to(dtype)
x = x.to(dtype)
dist.inference_all_reduce(x)
assert torch.all(x == result)

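A note on why the test can use exact equality even for torch.float16: rank r contributes r + 1, so the reduced value is world_size * (world_size + 1) / 2, at most 136 even for 16 ranks, and small integers in that range are exactly representable in FP16, so the reduction introduces no rounding error. A quick check (illustration only):

import torch

world_size = 4
sum_of_ranks = world_size * (world_size + 1) // 2             # 10
expected = torch.full((1, 3), sum_of_ranks, dtype=torch.float16)
assert torch.all(expected == float(sum_of_ranks))             # exact in float16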
