[PHI] OneDNN version of Copy (PaddlePaddle#48539)
* OneDNN version of Copy, transpose kernels adjusted

* style fixes in transpose_grad

* redundant headers deleted
paulinagacek authored Dec 12, 2022
1 parent 30b1c1a commit 321b719
Showing 3 changed files with 31 additions and 23 deletions.
43 changes: 27 additions & 16 deletions paddle/phi/core/tensor_utils.cc
@@ -56,6 +56,9 @@ void Copy(const Context& dev_ctx,
void* dst_ptr = nullptr;
if (paddle::platform::is_cpu_place(dst_place)) {
dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
+ #ifdef PADDLE_WITH_MKLDNN
+ dst->set_layout(src.layout());
+ #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (paddle::platform::is_gpu_place(dst_place) ||
paddle::platform::is_cuda_pinned_place(dst_place)) {
@@ -81,7 +84,7 @@ void Copy(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(
dst->place(),
dst_place,
- phi::errors::Unavailable(
+ errors::Unavailable(
"The Dst Tensor's place and dst_place do not match, Tensor's place "
"place is %s, dst_place is %s.",
dst->place(),
@@ -112,13 +115,13 @@ void Copy(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(ctx_place),
true,
- phi::errors::PreconditionNotMet(
+ errors::PreconditionNotMet(
"Context place error, excepted GPUPlace, but actually %s.",
ctx_place));
auto ctx_gpu_place = ctx_place;
PADDLE_ENFORCE_EQ(src_gpu_place,
ctx_gpu_place,
- phi::errors::Unavailable(
+ errors::Unavailable(
"Source place and context place do not match, source "
"place is %s, context place is %s.",
src_gpu_place,
@@ -137,17 +140,17 @@ void Copy(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(ctx_place),
true,
- phi::errors::PreconditionNotMet(
+ errors::PreconditionNotMet(
"Context place error, excepted GPUPlace, but actually %s.",
ctx_place));
auto ctx_gpu_place = ctx_place;
- PADDLE_ENFORCE_EQ(dst_gpu_place,
- ctx_gpu_place,
- phi::errors::Unavailable(
- "Destination place and context place do not match, "
- "destination place is %s, context place is %s.",
- dst_gpu_place,
- ctx_gpu_place));
+ PADDLE_ENFORCE_EQ(
+ dst_gpu_place,
+ ctx_gpu_place,
+ errors::Unavailable("Destination place and context place do not match, "
+ "destination place is %s, context place is %s.",
+ dst_gpu_place,
+ ctx_gpu_place));
auto stream =
blocking ? nullptr
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
@@ -161,7 +164,7 @@ void Copy(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(ctx_place),
true,
- phi::errors::PreconditionNotMet(
+ errors::PreconditionNotMet(
"Context place error, excepted GPUPlace, but actually %s.",
ctx_place));
auto stream =
@@ -184,7 +187,7 @@ void Copy(const Context& dev_ctx,
paddle::memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
} else {
- PADDLE_THROW(phi::errors::Unavailable(
+ PADDLE_THROW(errors::Unavailable(
"Context place dose not match the source and destination place."));
}
}
@@ -196,13 +199,13 @@ void Copy(const Context& dev_ctx,
PADDLE_ENFORCE_EQ(
paddle::platform::is_gpu_place(ctx_place),
true,
- phi::errors::PreconditionNotMet(
+ errors::PreconditionNotMet(
"Context place error, excepted GPUPlace, but actually %s.",
ctx_place));
auto ctx_gpu_place = ctx_place;
PADDLE_ENFORCE_EQ(src_gpu_place,
ctx_gpu_place,
- phi::errors::Unavailable(
+ errors::Unavailable(
"Source place and context place do not match, source "
"place is %s, context place is %s.",
src_gpu_place,
@@ -259,7 +262,7 @@ void Copy(const Context& dev_ctx,
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
#endif
} else {
- PADDLE_THROW(phi::errors::Unimplemented(
+ PADDLE_THROW(errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
}
@@ -411,4 +414,12 @@ template void Copy(const CustomContext& dev_ctx,
bool blocking,
DenseTensor* dst);
#endif

+ #ifdef PADDLE_WITH_MKLDNN
+ template void Copy(const OneDNNContext& dev_ctx,
+                    const DenseTensor& src,
+                    Place dst_place,
+                    bool blocking,
+                    DenseTensor* dst);
+ #endif
} // namespace phi
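
With the OneDNNContext instantiation above in place, PHI code can copy oneDNN tensors without reaching into fluid. A minimal caller might look like the sketch below (a hypothetical helper, not part of this commit); under PADDLE_WITH_MKLDNN, a copy whose destination is a CPU place now also propagates the source layout, so a DataLayout::ONEDNN tensor stays tagged as such after the copy:

#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/core/tensor_utils.h"

// Hypothetical helper (illustration only): clone a possibly
// oneDNN-formatted tensor into host memory.
void CloneToHost(const phi::OneDNNContext& dev_ctx,
                 const phi::DenseTensor& src,
                 phi::DenseTensor* dst) {
  // blocking = true requests a synchronous copy; with the change above,
  // the CPU destination inherits src.layout() in addition to dtype/dims.
  phi::Copy(dev_ctx, src, phi::CPUPlace(), /*blocking=*/true, dst);
}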
8 changes: 3 additions & 5 deletions paddle/phi/kernels/onednn/transpose_grad_kernel.cc
@@ -13,8 +13,6 @@
// limitations under the License.

#include "paddle/phi/kernels/transpose_grad_kernel.h"

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -24,16 +22,16 @@ void TransposeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& axis,
DenseTensor* x_grad) {
- PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU,
+ PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType() == AllocationType::CPU,
true,
errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace"));
"oneDNN TransposeGrad kernel must use CPUPlace"));
if (!x_grad) return;

const auto& onednn_engine = dev_ctx.GetEngine();

if (axis.size() == 1) {
- paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad);
+ Copy<Context>(dev_ctx, out_grad, out_grad.place(), false, x_grad);
x_grad->set_mem_desc(out_grad.mem_desc());
return;
}
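
For context, the axis.size() == 1 branch is the degenerate transpose: a one-element permutation can only be the identity, so the gradient reduces to a plain copy. The annotated sketch below (a hypothetical standalone helper; it assumes Copy propagates dtype, dims and layout but not the oneDNN memory descriptor, which is why set_mem_desc is still called) restates the new fast path:

#include "paddle/phi/core/tensor_utils.h"

template <typename Context>
void IdentityTransposeGrad(const Context& dev_ctx,
                           const phi::DenseTensor& out_grad,
                           phi::DenseTensor* x_grad) {
  // PHI-native copy, replacing the fluid-era call
  //   paddle::framework::TensorCopy(out_grad, out_grad.place(), x_grad);
  // blocking = false is fine here: source and destination share a CPU place.
  phi::Copy(dev_ctx, out_grad, out_grad.place(), /*blocking=*/false, x_grad);
  // Re-attach the oneDNN blocked-format descriptor after the copy.
  x_grad->set_mem_desc(out_grad.mem_desc());
}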
3 changes: 1 addition & 2 deletions paddle/phi/kernels/onednn/transpose_kernel.cc
@@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -80,7 +79,7 @@ void TransposeKernel(const Context& dev_ctx,
dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc());

if (axis.size() == 1) {
- paddle::framework::TensorCopy(x, x.place(), out);
+ Copy<Context>(dev_ctx, x, x.place(), false, out);
out->set_mem_desc(x.mem_desc());
return;
}
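
The forward kernel takes the same fast path. As a usage note, a 1-D tensor's only valid axis order is {0}, so a call like the hypothetical driver below exercises exactly this branch; it assumes the usual PHI TransposeKernel<T, Context> signature declared in transpose_kernel.h:

#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/kernels/transpose_kernel.h"

// Hypothetical driver: a 1-D transpose is the identity permutation, so the
// oneDNN kernel copies x into out and re-attaches x's memory descriptor.
void IdentityTranspose(const phi::OneDNNContext& dev_ctx,
                       const phi::DenseTensor& x,
                       phi::DenseTensor* out) {
  phi::TransposeKernel<float>(dev_ctx, x, /*axis=*/{0}, out);
}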
