Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

static graph autogen code support for full_like op #54698

Merged
merged 3 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 0 additions & 98 deletions paddle/fluid/operators/fill_any_like_op.cc

This file was deleted.

15 changes: 15 additions & 0 deletions paddle/fluid/operators/generator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference
ctx->GetInputDataType(framework::GradVarName("Out")));
}}
}};
"""
elif op_name == "fill_any_like":
return f"""
class {to_pascal_case(op_name)}InferVarType : public framework::VarTypeInference {{
public:
void operator()(framework::InferVarTypeContext *ctx) const override {{
auto var_data_type = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, ctx->GetAttr("dtype")));
if (var_data_type < 0) {{
ctx->SetOutputDataType("Out", ctx->GetInputDataType("X"));
}} else {{
ctx->SetOutputDataType("Out", var_data_type);
}}
}}
}};
"""
else:
return None
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ phi::KernelKey GetCheckFiniteAndUnscaleExpectedKernelType(
return phi::KernelKey(dtype, ctx.GetPlace());
}

phi::KernelKey GetFullLikeExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
phi::KernelKey kt = op_ptr->OperatorWithKernel::GetExpectedKernelType(ctx);
const auto& data_type = ctx.Attr<int>("dtype");
if (data_type >= 0) {
kt.set_dtype(phi::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(data_type)));
}
return kt;
}

phi::KernelKey GetReduceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ phi::KernelKey GetCheckFiniteAndUnscaleExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetFullLikeExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetReduceGradExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
Expand Down
8 changes: 6 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,12 @@
x : X
outputs :
out : Out
attrs :
{value: value, dtype: dtype}
scalar :
value :
data_type : float
support_tensor : true
get_expected_kernel_type :
full_like : GetFullLikeExpectedKernelType

- op : fused_conv2d
extra :
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@
param : [x, axis, keepdim, reduce_all]
backward : frobenius_norm_grad

- op : full_like
args : (Tensor x, Scalar value = 0.0, DataType dtype = DataType::UNDEFINED)
output: Tensor(out)
infer_meta :
func : FillAnyLikeInferMeta
param : [x]
kernel :
func : full_like
param : [x, value, dtype]

- op : gaussian
args : (IntArray shape = {}, float mean = .0f, float std = 1.0f, int seed = 0, DataType dtype = DataType::FLOAT32)
output: Tensor(out)
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,11 @@ void FillDiagonalInferMeta(
out->set_dtype(x.dtype());
}

void FillAnyLikeInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(x.dims());
out->share_lod(x);
}

void FFTC2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ void ExpandInferMeta(const MetaTensor& x,
void FillDiagonalInferMeta(
const MetaTensor& x, float value, int offset, bool wrap, MetaTensor* out);

void FillAnyLikeInferMeta(const MetaTensor& x, MetaTensor* out);

void FFTC2CInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axes,
const std::string& normalization,
Expand Down
28 changes: 0 additions & 28 deletions paddle/phi/ops/compat/fill_any_like_sig.cc

This file was deleted.