[ROCm] Enable unsafe fp atomics and cleanup gpu_device_functions.h #2853

Open · wants to merge 1 commit into base: r2.18-rocm-enhanced
228 changes: 95 additions & 133 deletions tensorflow/core/util/gpu_device_functions.h
@@ -36,6 +36,8 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#else
#include "rocm/include/hip/hip_complex.h"
#include "rocm/include/hip/hip_fp16.h"
#include "rocm/include/hip/hip_bf16.h"
#endif

#include "tensorflow/core/platform/types.h"
@@ -350,22 +352,12 @@ __device__ T GpuShuffleSync(unsigned mask, T value, int src_lane,
// See b/69446944.
__device__ inline double GpuShuffleSync(unsigned mask, double value,
int src_lane, int width = warpSize) {
#if GOOGLE_CUDA
auto tmp = __double_as_longlong(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = GpuShuffleSync(mask, hi, src_lane, width);
lo = GpuShuffleSync(mask, lo, src_lane, width);
return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
auto tmp = static_cast<uint64_t>(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = __shfl(static_cast<int>(hi), src_lane, width);
lo = __shfl(static_cast<int>(lo), src_lane, width);
return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleSync, CudaShuffleSync);

@@ -390,22 +382,12 @@ __device__ inline T GpuShuffleUpSync(unsigned mask, T value, unsigned delta,
__device__ inline double GpuShuffleUpSync(unsigned mask, double value,
unsigned delta,
int width = warpSize) {
#if GOOGLE_CUDA
auto tmp = __double_as_longlong(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = GpuShuffleUpSync(mask, hi, delta, width);
lo = GpuShuffleUpSync(mask, lo, delta, width);
return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
auto tmp = static_cast<uint64_t>(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = __shfl_up(static_cast<int>(hi), delta, width);
lo = __shfl_up(static_cast<int>(lo), delta, width);
return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpSync, CudaShuffleUpSync);

@@ -430,22 +412,12 @@ __device__ inline T GpuShuffleDownSync(unsigned mask, T value, unsigned delta,
__device__ inline double GpuShuffleDownSync(unsigned mask, double value,
unsigned delta,
int width = warpSize) {
#if GOOGLE_CUDA
auto tmp = __double_as_longlong(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = GpuShuffleDownSync(mask, hi, delta, width);
lo = GpuShuffleDownSync(mask, lo, delta, width);
return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
auto tmp = static_cast<uint64_t>(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = __shfl_down(static_cast<int>(hi), delta, width);
lo = __shfl_down(static_cast<int>(lo), delta, width);
return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownSync, CudaShuffleDownSync);

@@ -465,43 +437,18 @@ __device__ T GpuShuffleXorSync(unsigned mask, T value, int lane_mask,
#endif
}

#if TENSORFLOW_USE_ROCM
__device__ inline Eigen::half GpuShuffleXorSync(unsigned mask,
Eigen::half value,
int lane_mask,
int width = warpSize) {
assert(!(width & width - 1));
assert(detail::GpuValidateShuffleSyncMask(
mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width)));
// TODO(rocm): This doesn't preserve NaN payload and flushes denorms to zero,
// maybe this should be implemented differently?
return static_cast<Eigen::half>(
__shfl_xor(static_cast<float>(value), lane_mask, width));
}
#endif

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double GpuShuffleXorSync(unsigned mask, double value,
int lane_mask,
int width = warpSize) {
#if GOOGLE_CUDA
auto tmp = __double_as_longlong(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = GpuShuffleXorSync(mask, hi, lane_mask, width);
lo = GpuShuffleXorSync(mask, lo, lane_mask, width);
return __longlong_as_double(static_cast<uint64_t>(hi) << 32 | lo);
#elif TENSORFLOW_USE_ROCM
auto tmp = static_cast<uint64_t>(value);
auto lo = static_cast<unsigned>(tmp);
auto hi = static_cast<unsigned>(tmp >> 32);
hi = __shfl_xor(static_cast<int>(hi), lane_mask, width);
lo = __shfl_xor(static_cast<int>(lo), lane_mask, width);
return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
static_cast<uint64_t>(lo));
#endif
}
CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorSync, CudaShuffleXorSync);

@@ -567,9 +514,22 @@ __global__ void SetToValue(const int count, T* __restrict__ ptr, Tvalue value) {
}

namespace detail {

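// Reinterprets `ptr` as a pointer in AMDGPU address space N (1 = global
// memory, 3 = shared/LDS) so that atomic builtins can lower to the matching
// global/LDS instructions under ROCm; on CUDA builds the hint is a no-op.
// Illustrative use (mirrors the call sites below):
//   atomicAdd(AddressSpaceHint<1>(ptr), value);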
template <int N, typename T>
__device__ T* AddressSpaceHint(T* ptr) {
#if defined(TENSORFLOW_USE_ROCM)
using AS = __attribute__((address_space(N))) T*;
auto ptr_ = reinterpret_cast<AS>(reinterpret_cast<uintptr_t>(ptr));
return (T*)(ptr_);
#else
return ptr; // NOOP
#endif
}

// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T GpuAtomicCasHelper(T* ptr, F accumulate) {
ptr = detail::AddressSpaceHint<1>(ptr);
T old = *ptr;
T assumed;
do {
@@ -591,24 +551,11 @@ __device__ float GpuAtomicCasHelper(float* ptr, F accumulate) {
}
template <typename F>
__device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
#if TENSORFLOW_USE_ROCM
// FIXME: remove the workaround below once bug is fixed.
// HIP has a bug in the implementation of __longlong_as_double
// So workaround it by using reinterpret_cast<double*>.
uint64_t result =
GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(
accumulate(*(reinterpret_cast<double*>(&a))));
});
return *(reinterpret_cast<double*>(&result));
#else
return __longlong_as_double(GpuAtomicCasHelper(
reinterpret_cast<unsigned long long*>(ptr),
[accumulate](tensorflow::uint64 a) {
return __double_as_longlong(accumulate(__longlong_as_double(a)));
}));
#endif
}

// Overload of above function for half. Note that we don't have
@@ -628,31 +575,20 @@ __device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
uintptr_t intptr = reinterpret_cast<uintptr_t>(ptr);
uint32_t shift = (intptr & 0x2) * 8U;
uint32_t mask = 0xFFFF0000U >> shift;
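// `shift` is 0 when the half occupies the low 16 bits of the enclosing
// 4-byte word and 16 when it occupies the high 16 bits; `mask` preserves
// the other, untouched 16 bits.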

assert(!(intptr & 0x1)); // should be 2-aligned.
if (intptr & 0x2) {
// The half is in the second part of the uint32 (upper 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr - 2);
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short high = static_cast<unsigned short>(arg >> 16);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
return (static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc)) << 16) |
(arg & 0xffff);
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16>(result >> 16));
} else {
// The half is in the first part of the uint32 (lower 16 bits).
uint32* address = reinterpret_cast<uint32*>(intptr);
uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
unsigned short low = static_cast<unsigned short>(arg & 0xffff);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(low));
return (arg & 0xffff0000) |
static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc));
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16>(result & 0xffff));
}
uint32* address = reinterpret_cast<uint32*>(intptr & ~0x3);
uint32 result = GpuAtomicCasHelper(address, [accumulate, shift, mask](uint32 arg) {
uint16_t high = static_cast<uint16_t>(arg >> shift);
Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
return (static_cast<uint32>(Eigen::numext::bit_cast<uint16_t>(acc)) << shift) |
(arg & mask);
});
return Eigen::numext::bit_cast<Eigen::half>(
static_cast<uint16_t>(result >> shift));
}

template <typename F>
@@ -720,10 +656,11 @@ __device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
return atomicAdd(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}


#if !defined(TENSORFLOW_USE_ROCM)
__device__ inline Eigen::bfloat16 GpuAtomicAdd(Eigen::bfloat16* ptr,
Eigen::bfloat16 value) {
return detail::GpuAtomicCasHelper(
@@ -735,26 +672,73 @@ __device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
}
#endif

#if (__CUDA_ARCH__ < 600) || TENSORFLOW_USE_ROCM
#if (__CUDA_ARCH__ < 600)
__device__ inline double GpuAtomicAdd(double* ptr, double value) {
return detail::GpuAtomicCasHelper(ptr,
[value](double a) { return a + value; });
}
#endif

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__ || __gfx1101__ || __gfx1102__ || __gfx1200__ || __gfx1201__
#if TENSORFLOW_USE_ROCM
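// Atomic add on a destination known to live in shared memory (LDS); the
// address-space-3 hint lets the compiler select LDS atomics directly.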
template <typename T>
__device__ T GpuAtomicAddShared(T* dst, T val) {
return atomicAdd(detail::AddressSpaceHint<3>(dst), val);
}

#define ADDRSP1 __attribute__((address_space(1)))
__device__ float
#if __clang_major__ < 16
__llvm_amdgcn_global_atomic_add_f32(ADDRSP1 float* dst, float val) __asm("llvm.amdgcn.global.atomic.fadd.f32.p1f32.f32");
#else
__llvm_amdgcn_global_atomic_add_f32(ADDRSP1 float* dst, float val) __asm("llvm.amdgcn.global.atomic.fadd.f32.p1.f32");
#endif // clang_major
#endif // gfx
namespace detail {

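// Adds a single 16-bit value (half or bfloat16) by issuing a packed 2x16-bit
// atomic fadd on the enclosing 4-byte-aligned word, with the neighbouring
// element contributing +0. P is the packed element type (_Float16 or short),
// T the Eigen scalar type, and `add` wraps the target-specific builtin
// supplied by the caller.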
template <typename P, typename T, typename F>
__device__ inline T GpuAtomicAddHalfHelper(T* ptr, T value, F add) {
typedef P __attribute__((ext_vector_type(2))) P2;
auto ptr2 = (__attribute__((address_space(1)))
P2*)(reinterpret_cast<uintptr_t>(ptr) & ~0x3);
uintptr_t shift = ((reinterpret_cast<uintptr_t>(ptr) & 0x2) * 8);
// Eigen::numext::bit_cast on ext_vector produces redundant inlined memcpy.
// Use union instead.
union {
P2 v2;
uint32_t i;
} u;

u.i = static_cast<uint32_t>(Eigen::numext::bit_cast<uint16_t>(value)) << shift;

// Performs + (T)0 on the adjacent location, so this is not idempotent with
// regard to its bit pattern. Should be fine as long as that location is
// used to hold T.
u.v2 = add(ptr2, u.v2);
return Eigen::numext::bit_cast<T>(static_cast<uint16_t>(u.i >> shift));
}

} // namespace detail

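// ROCm fast paths for 16-bit atomic adds: use the packed global fadd builtin
// when the compiler provides it, otherwise fall back to the CAS helper above.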
__device__ inline Eigen::bfloat16 GpuAtomicAdd(Eigen::bfloat16* ptr,
Eigen::bfloat16 value) {
#if __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
return detail::GpuAtomicAddHalfHelper<short>(
ptr, value, [](auto p, auto v) {
return __builtin_amdgcn_global_atomic_fadd_v2bf16(p, v);
});
#else
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::bfloat16 a) { return a + value; });
#endif
}

__device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
Eigen::half value) {
#if __has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16)
return detail::GpuAtomicAddHalfHelper<_Float16>(
ptr, value, [](auto p, auto v) {
return __builtin_amdgcn_global_atomic_fadd_v2f16(p, v);
});
#else
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::half a) { return a + value; });
#endif
}
#endif

// GpuAtomicAdd
// Specializations of GpuAtomicAdd for complex types, which GpuAtomicAdd does
@@ -783,7 +767,7 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicAdd, CudaAtomicAdd);
// GpuAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicSub(T* ptr, U value) {
return atomicSub(ptr, value);
return atomicSub(detail::AddressSpaceHint<1>(ptr), value);
}

// Specializations of subtraction which add the negative value.
@@ -806,22 +790,21 @@ __device__ inline tensorflow::uint64 GpuAtomicSub(tensorflow::uint64* ptr,

__device__ inline Eigen::half GpuAtomicSub(Eigen::half* ptr,
Eigen::half value) {
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::half a) { return a - value; });
return GpuAtomicAdd(ptr, -value);
}

__device__ inline Eigen::bfloat16 GpuAtomicSub(Eigen::bfloat16* ptr,
Eigen::bfloat16 value) {
return detail::GpuAtomicCasHelper(
ptr, [value](Eigen::bfloat16 a) { return a - value; });
return GpuAtomicAdd(ptr, -value);
}

CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);

// GpuAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
return atomicMax(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}

#if TENSORFLOW_USE_ROCM
@@ -894,7 +877,8 @@ CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);
// GpuAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
return atomicMin(detail::ToCudaSupportedPtr(detail::AddressSpaceHint<1>(ptr)),
value);
}

#if TENSORFLOW_USE_ROCM
@@ -963,28 +947,6 @@ __device__ inline int64_t GpuAtomicMin(int64_t* ptr, int64_t value) {
}
#endif

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__ || __gfx1101__ || __gfx1102__ || __gfx1200__ || __gfx1201__
// Low level instructions don't return. For now, assume that return value
// is always unused.
__device__ float GpuAtomicAdd(float* dst, float val) {
ADDRSP1 float* p = (ADDRSP1 float*) dst;
__llvm_amdgcn_global_atomic_add_f32(p, val);
return val;
}
#endif

template <typename T>
__device__ inline T GpuAtomicAddShared(T* ptr, T value) {
return GpuAtomicAdd(ptr, value);
}

#if __gfx908__ || __gfx90a__ || __gfx940__ || __gfx941__ || __gfx942__ || __gfx1101__ || __gfx1102__ || __gfx1200__ || __gfx1201__
__device__ float GpuAtomicAddShared(float* dst, float val) {
atomicAdd(dst, val);
return val;
}
#endif

CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMin, CudaAtomicMin);

// GpuAtomicMul
@@ -183,7 +183,7 @@ def InvokeHipcc(argv, log=False):
# of link time. This allows the default host compiler (gcc) be used as the
# linker for TensorFlow on ROCm platform.
hipccopts += ' -fno-gpu-rdc '
hipccopts += ' -fcuda-flush-denormals-to-zero '
hipccopts += ' -fcuda-flush-denormals-to-zero -munsafe-fp-atomics '
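# Note: -munsafe-fp-atomics lets hipcc lower floating-point atomic adds to
# native hardware instructions (where the target supports them) instead of
# CAS loops, at the cost of strict IEEE denormal handling for those atomics.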
hipccopts += undefines
hipccopts += defines
hipccopts += std_options