Skip to content

Change CUDA implementation of Transpose to support all fixed size tensor types #2387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Sh
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization);
Expand Down Expand Up @@ -874,9 +872,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Tile)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, InstanceNormalization)>,
Expand Down
146 changes: 75 additions & 71 deletions onnxruntime/core/providers/cuda/tensor/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,16 @@
namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Transpose, \
kOnnxDomain, \
1, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Transpose<T>);
// Register a single untyped Transpose kernel (opset 1, CUDA EP) constrained to
// all fixed-size tensor element types. This replaces the former per-type
// float/double/MLFloat16 registrations, since the implementation now
// dispatches on element byte-size internally.
ONNX_OPERATOR_KERNEL_EX(Transpose,
kOnnxDomain,
1,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

// special case acceleration using cublas matrix transpose
std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
static std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, const TensorShape& input_shape) {
int M = 0;
int N = 0;

Expand All @@ -47,7 +44,72 @@ std::tuple<int, int> TryTransposeWithCublas(const std::vector<size_t>& perm, con
}

template <typename T>
Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
// 2-D matrix transpose through cuBLAS: output = 1 * op(A) + 0 * op(B) with
// both ops CUBLAS_OP_T. `input` supplies both A and B; since beta == 0, B's
// values never contribute. T is the tensor element type (template parameter
// declared above); it is mapped to the matching CUDA device type before the
// cuBLAS call (e.g. MLFloat16 -> half).
Status TransposeWithCublas(cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, int M, int N) {
typedef typename ToCudaType<T>::MappedType CudaT;
CudaT one = ToCudaType<T>::FromFloat(1.0f);   // alpha
CudaT zero = ToCudaType<T>::FromFloat(0.0f);  // beta: zeroes out B's contribution
const CudaT* input_data = reinterpret_cast<const CudaT*>(input.Data<T>());
CudaT* output_data = reinterpret_cast<CudaT*>(output.MutableData<T>());
// NOTE(review): the leading dimensions (N for A/B, M for C) assume a row-major
// M x N input producing an N x M output -- confirm against
// cublasTransposeHelper's contract.
CUBLAS_RETURN_IF_ERROR(
cublasTransposeHelper(cublas_handle,
CUBLAS_OP_T, CUBLAS_OP_T, M, N,
&one,
input_data,
N,
&zero,
input_data,
N,
output_data,
M));
return Status::OK();
}

// Performs the transpose described by `permutations` on `input`, writing the
// result into the pre-shaped `output` tensor.
// Fast path: float/double/MLFloat16 permutations that reduce to a plain 2-D
// matrix transpose are handed to cuBLAS.
// Generic path: a byte-size-dispatched CUDA kernel (TransposeImpl) that works
// for any fixed-size element type, since transpose is pure data movement.
// `kernel` supplies the cuBLAS handle and the allocator context for the
// async GPU buffers.
Status Transpose::DoTranspose(const Transpose& kernel,
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output) {
// special case when there is a dim value of 0 in the shape.
if (output.Shape().Size() == 0)
return Status::OK();

// cuBLAS acceleration is only available for the three floating-point types.
auto element_type = input.GetElementType();
if (element_type == utils::GetONNXTensorElementDataType<float>() ||
element_type == utils::GetONNXTensorElementDataType<double>() ||
element_type == utils::GetONNXTensorElementDataType<MLFloat16>()) {
auto mn = TryTransposeWithCublas(permutations, input.Shape());
int M = std::get<0>(mn);
int N = std::get<1>(mn);
// M == 0 / N == 0 signals the permutation is not a plain matrix transpose.
if (M != 0 && N != 0) {
if (element_type == utils::GetONNXTensorElementDataType<float>()) {
return TransposeWithCublas<float>(kernel.CublasHandle(), input, output, M, N);
} else if (element_type == utils::GetONNXTensorElementDataType<double>()) {
return TransposeWithCublas<double>(kernel.CublasHandle(), input, output, M, N);
} else {
return TransposeWithCublas<MLFloat16>(kernel.CublasHandle(), input, output, M, N);
}
}
}

const std::vector<int64_t>& input_dims = input.Shape().GetDims();
const std::vector<int64_t>& output_dims = output.Shape().GetDims();

// Stage stride/permutation metadata on the host, then copy it to the GPU
// before launching the generic kernel.
auto rank = input_dims.size();
CudaAsyncBuffer<int64_t> input_strides(&kernel, rank);
CudaAsyncBuffer<size_t> perm(&kernel, permutations);
CudaAsyncBuffer<fast_divmod> fdm_output_strides(&kernel, rank);
ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims));
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));

ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
ORT_RETURN_IF_ERROR(perm.CopyToGpu());
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());

// Dispatch on element size, not element type: the kernel only moves bytes.
size_t element_size = input.DataType()->Size();
auto status = TransposeImpl(element_size, rank, input_strides.GpuPtr(), perm.GpuPtr(), input.DataRaw(),
fdm_output_strides.GpuPtr(), output.MutableDataRaw(), output.Shape().Size());

return status;
}

Status Transpose::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* X_ptr = ctx->Input<Tensor>(0);
if (X_ptr == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
const Tensor& X = *X_ptr;
Expand All @@ -65,66 +127,8 @@ Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const {
TensorShape output_shape{output_dims};
Tensor* Y = ctx->Output(0, output_shape);

// special case when there is a dim value of 0 in the shape.
if (output_shape.Size() == 0)
return Status::OK();

auto mn = TryTransposeWithCublas(*p_perm, input_shape);
int M = std::get<0>(mn);
int N = std::get<1>(mn);
if (M != 0 && N != 0) {
typedef typename ToCudaType<T>::MappedType CudaT;
CudaT one = ToCudaType<T>::FromFloat(1.0f);
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
const CudaT* input_data = reinterpret_cast<const CudaT*>(X.template Data<T>());
CudaT* output_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
CUBLAS_RETURN_IF_ERROR(
cublasTransposeHelper(
CublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_T,
M,
N,
&one,
input_data,
N,
&zero,
input_data,
N,
output_data,
M));
return Status::OK();
}

CudaAsyncBuffer<int64_t> input_strides(this, rank);
CudaAsyncBuffer<size_t> perm(this, *p_perm);
CudaAsyncBuffer<fast_divmod> fdm_output_strides(this, rank);
ORT_ENFORCE(TensorPitches::Calculate(input_strides.CpuSpan(), input_dims));
ORT_ENFORCE(CalculateFdmStrides(fdm_output_strides.CpuSpan(), output_dims));

ORT_RETURN_IF_ERROR(input_strides.CopyToGpu());
ORT_RETURN_IF_ERROR(perm.CopyToGpu());
ORT_RETURN_IF_ERROR(fdm_output_strides.CopyToGpu());

TransposeImpl(
rank,
input_strides.GpuPtr(),
perm.GpuPtr(),
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(X.template Data<T>()),
fdm_output_strides.GpuPtr(),
reinterpret_cast<typename ToCudaType<T>::MappedType*>(Y->template MutableData<T>()),
output_shape.Size());

return Status::OK();
return DoTranspose(*this, *p_perm, X, *Y);
}

#define SPECIALIZED_COMPUTE(T) \
REGISTER_KERNEL_TYPED(T) \
template Status Transpose<T>::ComputeInternal(OpKernelContext* ctx) const;

SPECIALIZED_COMPUTE(float)
SPECIALIZED_COMPUTE(double)
SPECIALIZED_COMPUTE(MLFloat16)

} // namespace cuda
} // namespace onnxruntime
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
namespace onnxruntime {
namespace cuda {

template <typename T>
// CUDA Transpose kernel. No longer a class template over the element type:
// a single instance serves all fixed-size tensor types, dispatching on the
// element's byte size at run time.
class Transpose final : public CudaKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : CudaKernel(info), TransposeBase(info) {}

Status ComputeInternal(OpKernelContext* context) const override;

// Static helper so other CUDA kernels can reuse the transpose logic.
// `transpose_kernel` supplies the cuBLAS handle and allocator context;
// `output` must already have the permuted shape.
static Status DoTranspose(const Transpose& transpose_kernel,
const std::vector<size_t>& permutations, const Tensor& input, Tensor& output);
};

} // namespace cuda
Expand Down
56 changes: 41 additions & 15 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,49 @@ __global__ void _TransposeKernel(size_t shape_rank, const int64_t* input_strides
output_data[id] = input_data[input_index];
}

template <typename T>
void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data,
const fast_divmod* fdm_output_strides, T* output_data, size_t N) {
// Launches the transpose kernel dispatched on element byte-size rather than
// element type: transpose is pure data movement, so any 1/2/4/8-byte element
// type can reuse the matching integer instantiation of _TransposeKernel
// (e.g. float runs through the int32_t case).
Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm,
const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N) {
// One thread per output element; N is the output element count.
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
_TransposeKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
shape_rank, input_strides, perm, input_data,
fdm_output_strides, output_data, N);
}

#define SPECIALIZED_IMPL(T) \
template void TransposeImpl<T>(size_t shape_rank, const int64_t* input_strides, const size_t* perm, \
const T* input_data, const fast_divmod* fdm_output_strides, T* output_data, \
size_t N);
// Each case reinterprets the raw buffers as the same-width integer type.
switch (element_size) {
case sizeof(int8_t):
_TransposeKernel<int8_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
shape_rank, input_strides, perm,
reinterpret_cast<const ToCudaType<int8_t>::MappedType*>(input_data),
fdm_output_strides,
reinterpret_cast<ToCudaType<int8_t>::MappedType*>(output_data),
N);
break;
case sizeof(int16_t):
_TransposeKernel<int16_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
shape_rank, input_strides, perm,
reinterpret_cast<const ToCudaType<int16_t>::MappedType*>(input_data),
fdm_output_strides,
reinterpret_cast<ToCudaType<int16_t>::MappedType*>(output_data),
N);
break;
case sizeof(int32_t):
_TransposeKernel<int32_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
shape_rank, input_strides, perm,
reinterpret_cast<const ToCudaType<int32_t>::MappedType*>(input_data),
fdm_output_strides,
reinterpret_cast<ToCudaType<int32_t>::MappedType*>(output_data),
N);
break;
case sizeof(int64_t):
_TransposeKernel<int64_t><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
shape_rank, input_strides, perm,
reinterpret_cast<const ToCudaType<int64_t>::MappedType*>(input_data),
fdm_output_strides,
reinterpret_cast<ToCudaType<int64_t>::MappedType*>(output_data),
N);
break;
default:
// Any other element width is not supported by this kernel.
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for transpose on CUDA. Element size was ",
element_size);
}

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
SPECIALIZED_IMPL(half)
return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/cuda/tensor/transpose_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
namespace onnxruntime {
namespace cuda {

template <typename T>
void TransposeImpl(size_t shape_rank, const int64_t* input_strides, const size_t* perm, const T* input_data,
const fast_divmod* fdm_output_strides, T* output_data, size_t N);
// Byte-size-dispatched transpose: element_size (1/2/4/8) selects the kernel
// instantiation; input_strides/perm/fdm_output_strides are device pointers,
// N is the output element count. Returns FAIL for unsupported element sizes.
Status TransposeImpl(size_t element_size, size_t shape_rank, const int64_t* input_strides, const size_t* perm,
const void* input_data, const fast_divmod* fdm_output_strides, void* output_data, size_t N);

} // namespace cuda
} // namespace onnxruntime
85 changes: 78 additions & 7 deletions onnxruntime/test/providers/cpu/tensor/transpose_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,91 @@ TEST(TransposeOpTest, TwoDimNoAttrStr) {
// Test 2 dimensional transpose, with permutation attribute specified
TEST(TransposeOpTest, TwoDim) {
std::vector<int64_t> input_shape({2, 3});
std::vector<float> input_vals = {
1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f};
std::vector<float> input_vals = {1.0f, 2.0f, 3.0f,
4.0f, 5.0f, 6.0f};

std::vector<int64_t> perm = {1, 0};
std::vector<int64_t> expected_shape({3, 2});
auto expected_vals = {
1.0f, 4.0f,
2.0f, 5.0f,
3.0f, 6.0f};
auto expected_vals = {1.0f, 4.0f,
2.0f, 5.0f,
3.0f, 6.0f};

TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
}

// 2-D transpose with an explicit permutation attribute, double element type.
TEST(TransposeOpTest, TwoDim_double) {
  std::vector<int64_t> in_shape{2, 3};
  std::vector<double> in_data{1.0, 2.0, 3.0,
                              4.0, 5.0, 6.0};

  std::vector<int64_t> out_shape{3, 2};
  std::vector<int64_t> perm{1, 0};
  // Transposing the 2x3 input yields its columns as rows.
  std::initializer_list<double> out_data{1.0, 4.0,
                                         2.0, 5.0,
                                         3.0, 6.0};

  TransposeTest(in_shape, in_data, &perm, out_shape, out_data);
}

// 2-D transpose with an explicit permutation attribute, int32 element type.
TEST(TransposeOpTest, TwoDim_int32) {
  std::vector<int64_t> in_shape{2, 3};
  std::vector<int32_t> in_data{1, 2, 3,
                               4, 5, 6};

  std::vector<int64_t> out_shape{3, 2};
  std::vector<int64_t> perm{1, 0};
  // Transposing the 2x3 input yields its columns as rows.
  std::initializer_list<int32_t> out_data{1, 4,
                                          2, 5,
                                          3, 6};

  TransposeTest(in_shape, in_data, &perm, out_shape, out_data);
}

// 2-D transpose with an explicit permutation attribute, int16 element type.
TEST(TransposeOpTest, TwoDim_int16) {
  std::vector<int64_t> in_shape{2, 3};
  std::vector<int16_t> in_data{1, 2, 3,
                               4, 5, 6};

  std::vector<int64_t> out_shape{3, 2};
  std::vector<int64_t> perm{1, 0};
  // Transposing the 2x3 input yields its columns as rows.
  std::initializer_list<int16_t> out_data{1, 4,
                                          2, 5,
                                          3, 6};

  TransposeTest(in_shape, in_data, &perm, out_shape, out_data);
}

// 2-D transpose with an explicit permutation attribute, MLFloat16 element type.
// Bug fix: the expected output {1, 4, 2, 5, 3, 6} is the transpose of an input
// holding 1..6, but the original fill loop generated 0..5 (transpose would be
// {0, 3, 1, 4, 2, 5}) -- off by one. The loop now generates 1..6.
// The fix holds whether MLFloat16(uint16_t) interprets its argument as a value
// or as raw bits, since transpose only permutes elements.
TEST(TransposeOpTest, TwoDim_mlfloat16) {
  std::vector<int64_t> input_shape({2, 3});
  std::vector<MLFloat16> input_vals;
  // Fill with 1..6 so the values line up with expected_vals below.
  for (uint16_t i = 1; i <= 6; ++i)
    input_vals.push_back(MLFloat16(i));

  std::vector<int64_t> perm = {1, 0};
  std::vector<int64_t> expected_shape({3, 2});
  std::initializer_list<MLFloat16> expected_vals = {MLFloat16(1), MLFloat16(4),
                                                    MLFloat16(2), MLFloat16(5),
                                                    MLFloat16(3), MLFloat16(6)};

  // NOTE(review): trailing 'false' presumably skips a provider or exact-match
  // check in TransposeTest -- confirm against its signature.
  TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals, false);
}

// 2-D transpose with an explicit permutation attribute, int8 element type.
TEST(TransposeOpTest, TwoDim_int8) {
  std::vector<int64_t> in_shape{2, 3};
  std::vector<int8_t> in_data{1, 2, 3,
                              4, 5, 6};

  std::vector<int64_t> out_shape{3, 2};
  std::vector<int64_t> perm{1, 0};
  // Transposing the 2x3 input yields its columns as rows.
  std::initializer_list<int8_t> out_data{1, 4,
                                         2, 5,
                                         3, 6};

  TransposeTest(in_shape, in_data, &perm, out_shape, out_data, false);
}

TEST(TransposeOpTest, TwoDimStr) {
std::vector<int64_t> input_shape({2, 3});
std::vector<std::string> input_vals = {
Expand Down