Commit e72fea1
change codes according to pr suggestions about transpose file

JamesLim-sy committed Dec 6, 2022
1 parent 8cf2c83
Showing 9 changed files with 227 additions and 288 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"

namespace paddle {
namespace operators {
1 change: 0 additions & 1 deletion paddle/fluid/operators/transpose_op.cc
@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"

namespace paddle {
namespace operators {
1 change: 0 additions & 1 deletion paddle/fluid/operators/transpose_op_mlu.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"

namespace paddle {
namespace operators {
1 change: 0 additions & 1 deletion paddle/fluid/operators/unique_op.h
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"

namespace paddle {
namespace operators {
2 changes: 1 addition & 1 deletion paddle/phi/kernels/autotune/auto_tune_base.h
@@ -123,7 +123,7 @@ class AutoTuneBase {
float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
// Regard 1st run as warmup, judge the compare result by the time cost
// of rest cycles.
-  constexpr int repeats = 4;
+  constexpr int repeats = 6;
phi::GpuTimer timer;
float time_cost = 0;
const auto& stream = ctx.stream();
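For context, RunAndMeasureKernel treats the first cycle as warmup and judges candidates by the remaining cycles, so raising repeats from 4 to 6 averages five measured runs instead of three. A minimal standalone sketch of that warmup-then-average scheme, using raw CUDA events instead of phi::GpuTimer (names and structure are illustrative, not the actual Paddle implementation):

#include <cuda_runtime.h>

// Warmup-then-average timing sketch: run the candidate `repeats` times,
// discard the first run, and return the mean of the rest.
template <typename F>
float MeasureKernelMs(cudaStream_t stream, F&& launch) {
  constexpr int repeats = 6;  // 1 warmup + 5 measured cycles
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  float total_ms = 0.0f;
  for (int i = 0; i < repeats; ++i) {
    cudaEventRecord(start, stream);
    launch();  // enqueue the candidate kernel on `stream`
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    if (i > 0) total_ms += ms;  // the first cycle only warms caches
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_ms / (repeats - 1);
}

More measured cycles steadies the average at the cost of slightly longer tuning.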
18 changes: 8 additions & 10 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -83,6 +83,7 @@ struct LoaderTypeClassifier {
};

#ifndef PADDLE_WITH_XPU_KP
+// Common broadcast/elementwise Loader.
template <typename T, int VecSize, int Arity, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
__device__ __forceinline__ void operator()(
@@ -107,6 +108,7 @@ struct BroadcastDataLoader {
}
};

+// Scalar elementwise Loader with consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
__device__ __forceinline__ void operator()(
@@ -117,17 +119,12 @@
const int block_offset,
const int num,
const uint32_t numel) {
-#pragma unroll
-    for (int i = 0; i < Arity; ++i) {
-#pragma unroll
-      kps::Init<T, VecSize>(args[i], static_cast<T>(1));
-    }
-
int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
#pragma unroll
for (int idx = 0; idx < VecSize; ++idx) {
+        args[i][idx] = static_cast<T>(1);
int index = thread_offset + idx;
if (index < numel) {
args[i][idx] = ins[i][index];
@@ -137,6 +134,7 @@ struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
}
};
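The rewritten specialization above drops the separate kps::Init pass and instead defaults each lane inside the guarded load loop, so every element is initialized and, when in range, overwritten in a single sweep. A simplified standalone device helper showing the same access pattern (hypothetical name; the real loader operates on Paddle's _ptr_-qualified pointers and extra arguments):

#include <cstdint>

// Fused init + bounds-checked load, mirroring the new loader body.
template <typename T, int VecSize, int Arity>
__device__ __forceinline__ void LoadWithBoundary(T args[Arity][VecSize],
                                                 const T *ins[Arity],
                                                 int block_offset,
                                                 uint32_t numel) {
  int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
  for (int i = 0; i < Arity; ++i) {
#pragma unroll
    for (int idx = 0; idx < VecSize; ++idx) {
      args[i][idx] = static_cast<T>(1);  // default for out-of-range lanes
      int index = thread_offset + idx;
      if (index < numel) {
        args[i][idx] = ins[i][index];  // overwrite with the loaded value
      }
    }
  }
}

Fusing the two loops removes one full pass over args and keeps the default store and the guarded load in the same unrolled iteration.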

+// Vectorized elementwise Loader without consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
__device__ __forceinline__ void operator()(
@@ -164,6 +162,7 @@ struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
}
};

+// Common broadcast data loader.
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, kBroadcast> {
__device__ __forceinline__ void operator()(
@@ -405,11 +404,10 @@ void LaunchBroadcastKernel(
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
auto stream = ctx.stream();
-  auto threads = gpu_config.thread_per_block;
+  auto threads = gpu_config.GetBlockSize();
auto blocks = gpu_config.block_per_grid;
-  int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize *
-                    gpu_config.GetBlockSize();
-  int tail_tid = numel % (VecSize * gpu_config.GetBlockSize());
+  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
+  int tail_tid = numel % (VecSize * threads);

if (loader_classifier.all_elementwise) {
VectorizedBroadcastKernel<Func,
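The launch-configuration change above only caches gpu_config.GetBlockSize() in threads; the main_offset/tail_tid arithmetic itself is unchanged. A worked example of the split, with illustrative values:

#include <cstdio>

int main() {
  const int numel = 1000, VecSize = 4, threads = 128;
  // One full tile is VecSize * threads = 512 elements.
  // main_offset: elements covered by fully vectorized iterations.
  int main_offset = (numel / (VecSize * threads)) * VecSize * threads;  // 512
  // tail_tid: leftover elements routed to the boundary (tail) path.
  int tail_tid = numel % (VecSize * threads);  // 488
  std::printf("main_offset=%d tail_tid=%d\n", main_offset, tail_tid);
  return 0;
}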
32 changes: 10 additions & 22 deletions paddle/phi/kernels/funcs/dims_simplifier.h
@@ -34,18 +34,6 @@ struct BroadcastDimsSimplifier {
BroadcastDimsSimplifier(const std::vector<const DenseTensor *> &ins,
const phi::DDim &dims,
int axis) {
-    if (!NeedBroadcast(ins, dims)) {
-      int64_t numel = phi::product(dims);
-      rank = 1;
-      N = ins.size();
-      out_dims = DimVector{numel};
-      in_dims.resize(N);
-      for (int64_t i = 0; i < N; ++i) {
-        in_dims[i] = DimVector{numel};
-      }
-      return;
-    }
-
N = std::max(static_cast<int>(ins.size()), 2);
in_dims.resize(N);
rank = dims.size();
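The block deleted above was a fast path: when no input actually required broadcasting, it flattened every shape to a single dimension of length numel and returned early. A tiny illustration of the collapse it used to perform (reconstructed from the removed lines):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Two same-shape inputs {4, 8}: no broadcast needed, so the old fast
  // path set rank = 1 and rewrote out_dims and every in_dims[i] as {32}.
  std::vector<int64_t> out_shape = {4, 8};
  int64_t numel = 1;
  for (int64_t d : out_shape) numel *= d;
  std::printf("rank = 1, collapsed shape = {%lld}\n",
              static_cast<long long>(numel));
  return 0;
}

With the early return gone, the no-broadcast case goes through the general simplification; judging by the loader_classifier.all_elementwise branch in broadcast_function.h, that situation now appears to be detected at kernel-launch time instead.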
@@ -273,18 +261,18 @@
};

// Simplify the input dims and permute dims if possible.
-struct DimsSimplifier {
+struct PermuteDimsSimplifier {
public:
-  explicit DimsSimplifier(const int rank,
-                          const int64_t numel,
-                          const std::vector<int32_t> &perm,
-                          const std::vector<int64_t> &dims)
+  PermuteDimsSimplifier(const int rank,
+                        const int64_t numel,
+                        const std::vector<int32_t> &perm,
+                        const std::vector<int64_t> &dims)
: perm_(rank), src_dims_(rank), count_(numel) {
SimplifyPermAndDims(rank, dims, perm);
perm_.resize(rank_);
src_dims_.resize(rank_);
dst_dims_.resize(rank_);
-    if (!is_seq_perm_) {
+    if (!is_sequence_perm_) {
for (auto i = 0; i < rank_; ++i) {
dst_dims_[i] = src_dims_[perm_[i]];
}
@@ -294,7 +282,7 @@
}
}

-  ~DimsSimplifier() = default;
+  ~PermuteDimsSimplifier() = default;

const int &GetRank() const { return rank_; }
const int64_t &GetCount() const { return count_; }
@@ -305,8 +293,8 @@
private:
int rank_{1};
int64_t count_{0};
-  bool is_seq_perm_{true};
   std::vector<int> perm_;
+  bool is_sequence_perm_{true};
std::vector<int64_t> src_dims_;
std::vector<int64_t> dst_dims_;

@@ -365,11 +353,11 @@
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
-      is_seq_perm_ &= (mapped == perm_idx);
+      is_sequence_perm_ &= (mapped == perm_idx);
perm_idx += 1;
}
}
-    rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
+    rank_ = is_sequence_perm_ ? 1 : valid_dim_idx;
}
};

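To make the renamed PermuteDimsSimplifier concrete: source dimensions that remain adjacent under the permutation can be fused, and a permutation that ends up sequential (is_sequence_perm_) collapses the whole transpose to a rank-1 copy. The toy function below re-implements just the fusing idea to show the expected effect; it is an illustration of the technique, not Paddle's code:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Toy version of the fusing idea behind PermuteDimsSimplifier: runs of
// source dims that stay contiguous under the permutation are merged.
// Illustrative only; the real class also drops size-1 dims, etc.
void MergeDims(std::vector<int64_t> *dims, std::vector<int> *perm) {
  const int rank = static_cast<int>(perm->size());
  std::vector<std::pair<int, int64_t>> runs;  // {source start, fused size}
  for (int i = 0; i < rank; ++i) {
    if (!runs.empty() && (*perm)[i] == (*perm)[i - 1] + 1) {
      runs.back().second *= (*dims)[(*perm)[i]];  // extend the current run
    } else {
      runs.push_back({(*perm)[i], (*dims)[(*perm)[i]]});
    }
  }
  // Runs ordered by source offset give the merged dims; the new perm is
  // each run's position in that order.
  std::vector<std::pair<int, int64_t>> by_src = runs;
  std::sort(by_src.begin(), by_src.end());
  dims->clear();
  for (const auto &r : by_src) dims->push_back(r.second);
  perm->clear();
  for (const auto &r : runs) {
    perm->push_back(static_cast<int>(
        std::lower_bound(by_src.begin(), by_src.end(), r) - by_src.begin()));
  }
}

int main() {
  std::vector<int64_t> dims = {8, 16, 4, 4};
  std::vector<int> perm = {0, 2, 3, 1};
  MergeDims(&dims, &perm);
  // Dims 2 and 3 are adjacent in both source and perm, so they fuse:
  // expected dims {8, 16, 16} and perm {0, 2, 1}.
  for (int64_t d : dims) std::printf("%lld ", static_cast<long long>(d));
  std::printf("| ");
  for (int p : perm) std::printf("%d ", p);
  std::printf("\n");
  return 0;
}

An identity permutation reduces to a single run, so dims become {numel} and perm becomes {0}, which is exactly the case the rank_ = 1 short-circuit exploits to treat the transpose as a flat copy.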
