From 6c35392d62fe888219aa4abcc5a86185fe304284 Mon Sep 17 00:00:00 2001 From: lizhiyu02 <1528794076@qq.com> Date: Fri, 9 Dec 2022 11:11:55 +0000 Subject: [PATCH 1/5] generate static graph code of some ops by yaml --- paddle/fluid/operators/scatter_nd_add_op.cc | 141 ------------------ paddle/fluid/operators/scatter_op.cc | 126 ---------------- paddle/fluid/operators/selu_op.cc | 138 ----------------- paddle/fluid/operators/shard_index_op.cc | 109 -------------- paddle/fluid/operators/triangular_solve_op.cc | 135 ----------------- paddle/fluid/operators/viterbi_decode_op.cc | 72 --------- paddle/fluid/operators/where_op.cc | 134 ----------------- paddle/phi/api/yaml/backward.yaml | 54 +++++++ paddle/phi/api/yaml/legacy_backward.yaml | 53 ------- paddle/phi/api/yaml/legacy_ops.yaml | 66 -------- paddle/phi/api/yaml/op_compat.yaml | 47 ++++++ paddle/phi/api/yaml/ops.yaml | 67 +++++++++ paddle/phi/ops/compat/gather_scatter_sig.cc | 18 --- paddle/phi/ops/compat/selu_sig.cc | 26 ---- paddle/phi/ops/compat/triangular_solve_sig.cc | 30 ---- paddle/phi/ops/compat/where_grad_sig.cc | 28 ---- 16 files changed, 168 insertions(+), 1076 deletions(-) delete mode 100644 paddle/fluid/operators/scatter_nd_add_op.cc delete mode 100644 paddle/fluid/operators/scatter_op.cc delete mode 100644 paddle/fluid/operators/selu_op.cc delete mode 100644 paddle/fluid/operators/shard_index_op.cc delete mode 100644 paddle/fluid/operators/triangular_solve_op.cc delete mode 100644 paddle/fluid/operators/viterbi_decode_op.cc delete mode 100644 paddle/fluid/operators/where_op.cc delete mode 100644 paddle/phi/ops/compat/selu_sig.cc delete mode 100644 paddle/phi/ops/compat/triangular_solve_sig.cc delete mode 100644 paddle/phi/ops/compat/where_grad_sig.cc diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc deleted file mode 100644 index 4ed08a387f2a0b..00000000000000 --- a/paddle/fluid/operators/scatter_nd_add_op.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/ternary.h" - -namespace paddle { -namespace operators { - -class ScatterNdAddOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - OperatorWithKernel::IndicateVarDataType(ctx, "Updates"), - platform::errors::InvalidArgument( - "Ref and Updates must have same type")); - return framework::OpKernelType( - framework::TransToProtoVarType( - ctx.Input("X")->type()), - ctx.device_context()); - } -}; - -class ScatterNdAddGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); - } -}; - -class ScatterNdAddOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The source input of scatter_nd_add op"); - AddInput("Index", - "The index input of scatter_nd_add op where X will be updated"); - AddInput("Updates", "The updated value of scatter_nd_add op"); - AddOutput("Out", "The output of scatter_nd_add op"); - AddComment(R"DOC( -Scatter_nd_add Operator. - -Output is obtained by applying sparse addition to a single value or slice in a Variable. - - Given: - * Case 1: - ref = [0, 1, 2, 3, 4, 5] - index = [[1], [2], [3], [1]] - updates = [9, 10, 11, 12] - - we get: - - output = [0, 22, 12, 14, 4, 5] - - * Case 2: - ref = [[65, 17], [-14, -25]] - index = [[], []] - updates = [[[-1, -2], [1, 2]], - [[3, 4], [-3, -4]]] - ref.shape = (2, 2) - index.shape = (2, 0) - updates.shape = (2, 2, 2) - - we get: - - output = [[67, 19], [-16, -27]] -)DOC"); - } -}; - -template -class ScatterNdAddGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("scatter_nd_add_grad"); - op->SetInput("Index", this->Input("Index")); - op->SetInput("Updates", this->Input("Updates")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("Updates"), - this->InputGrad("Updates")); - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterNdAddGradNoNeedBufferVarsInferer, - "Updates"); - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(scatter_nd_add, - ScatterNdAddInferShapeFunctor, - PD_INFER_META(phi::ScatterNdAddInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(scatter_nd_add_grad, - ScatterNdAddGradInferShapeFunctor, - PD_INFER_META(phi::ScatterNdAddGradInferMeta)); - -REGISTER_OPERATOR(scatter_nd_add, - ops::ScatterNdAddOp, - ops::ScatterNdAddOpMaker, - ops::ScatterNdAddGradMaker, - ops::ScatterNdAddGradMaker, - ScatterNdAddInferShapeFunctor); - -REGISTER_OPERATOR(scatter_nd_add_grad, - ops::ScatterNdAddGradOp, - ops::ScatterNdAddGradNoNeedBufferVarsInferer, - ScatterNdAddGradInferShapeFunctor); diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc deleted file mode 100644 index dd758fcbe39cd5..00000000000000 --- a/paddle/fluid/operators/scatter_op.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/ternary.h" - -namespace paddle { -namespace operators { - -class ScatterOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -class ScatterGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); - } -}; - -class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The source input of scatter op"); - AddInput("Ids", "The index input of scatter op where X will be updated"); - AddInput("Updates", "The updated value of scatter op"); - AddOutput("Out", "The output of scatter op"); - AddAttr("overwrite", - "(bool, default: True) " - "The mode that updating the output when has same index," - "If True, use the overwrite mode to update the output" - "of the same index, if False, use the accumulate mode to" - "update the output of the same index,Default value is True." - "You can set overwrite=False to implement scatter_add.") - .SetDefault(true); - AddComment(R"DOC( -Scatter Operator. - -This operator obtains output by updating the input on selected indices on the first axis: - -$$ -Out = X \\ -Out[Ids] = Updates -$$ - -)DOC"); - } -}; - -template -class ScatterGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("scatter_grad"); - op->SetInput("Ids", this->Input("Ids")); - op->SetInput("Updates", this->Input("Updates")); - op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("Updates"), - this->InputGrad("Updates")); - op->SetAttrMap(this->Attrs()); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer, - "Updates"); - -DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"}); - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(scatter, - ScatterInferShapeFunctor, - PD_INFER_META(phi::ScatterInferMeta)); - -DECLARE_INFER_SHAPE_FUNCTOR(scatter_grad, - ScatterGradInferShapeFunctor, - PD_INFER_META(phi::ScatterGradInferMeta)); - -namespace ops = paddle::operators; -REGISTER_OPERATOR(scatter, - ops::ScatterOp, - ops::ScatterOpMaker, - ops::ScatterGradMaker, - ops::ScatterGradMaker, - ops::ScatterInplaceInferer, - ScatterInferShapeFunctor); -REGISTER_OPERATOR(scatter_grad, - ops::ScatterGradOp, - ops::ScatterGradNoNeedBufferVarsInferer, - ScatterGradInferShapeFunctor); diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc deleted file mode 100644 index 0bf180e27d1424..00000000000000 --- a/paddle/fluid/operators/selu_op.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class SeluOp : public framework::OperatorWithKernel { - public: - SeluOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class SeluOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map &GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -class SeluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "The input tensor of selu operator."); - AddOutput("Out", "The output tensor of selu operator."); - AddAttr("scale", - "(float) the default value is 1.0507~. For more " - "information about this value, please refer to:" - "https://arxiv.org/abs/1706.02515.") - .SetDefault(1.0507009873554804934193349852946); - AddAttr("alpha", - "(float) the default value is 1.6732~. For more " - "information about this value, please refer to:" - "https://arxiv.org/abs/1706.02515.") - .SetDefault(1.6732632423543772848170429916717); - AddComment(R"DOC( -Selu Operator. - -The equation is: -$$ -f(x) =\lambda* -\begin{cases} - \quad \quad x, \quad \quad \quad \text{if} \ x > 0 \\ - \alpha * e^x - \alpha, \qquad \text{if} \ x <= 0 -\end{cases} -$$ - -The input `X` can carry the LoD (Level of Details) information, -or not. And the output shares the LoD information with input `X`. -)DOC"); - } -}; - -template -class SeluGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("selu_grad"); - grad_op->SetInput("Out", this->Output("Out")); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - grad_op->SetAttrMap(this->Attrs()); - } -}; - -class SeluGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "selu_grad"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "selu_grad"); - auto x_grad_name = framework::GradVarName("X"); - ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Out")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(selu, - SeluInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -REGISTER_OPERATOR(selu, - ops::SeluOp, - ops::SeluOpMaker, - ops::SeluOpInferVarType, - ops::SeluGradMaker, - ops::SeluGradMaker, - SeluInferShapeFunctor); - -REGISTER_OPERATOR(selu_grad, ops::SeluGradOp); diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc deleted file mode 100644 index 4c22efc2af2993..00000000000000 --- a/paddle/fluid/operators/shard_index_op.cc +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class ShardIndexOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); - } -}; - -class ShardIndexOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(phi::DenseTensor, phi::DenseTensor) Input variable. " - "Each value " - "of X is an index."); - AddOutput( - "Out", - "(Tensor, Tensor) Output tensor with same shape as X. " - "The tensor consists of sharding representations of values in X."); - AddAttr("index_num", - "A positive integer to specify the range of the input X."); - - AddAttr("nshards", - "A positive integer to specify the number of shards."); - AddAttr("shard_id", "The current shard id"); - AddAttr("ignore_value", "An integer value out of sharded range") - .SetDefault(-1); - AddComment(R"DOC( -This layer creates the sharded index for input. This layers is used in -model- and data- parallel mixed training generally, in which the index -data (usually the label) should be recaculated in each trainer according -to - -.. math:: - - assert index_num % nshards == 0 - - shard_size = index_num / nshards - - y = x % shard_size if x / shard_size == shard_id else ignore_value - -We take the distributed one-hot representation to show what this layer is -used for. The distributed one-hot representation is separated into multiple -shards, and each shard is filling zeros except the one with the index -inside. In order to create these sharded representation in each trainer, -the original index should be recalculated (i.e. sharded) before. - -Examples: - - X is a Tensor of integer values: - X.shape = [4, 1] - X.data = [[1], [6], [12], [19]] - - suppose index_num = 20 and nshards = 2, then we get shard_size = 10 - - if shard_id == 0, we get the Out: - Out.shape = [4, 1] - Out.data = [[1], [6], [-1], [-1]] - - if shard_id == 1, we get the Out: - Out.shape = [4, 1] - Out.data = [[-1], [-1], [2], [9]] - - the default `ignore_value` -1 is used in this example. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(shard_index, - ShardIndexInferShapeFunctor, - PD_INFER_META(phi::ShardIndexInferMeta)); -REGISTER_OPERATOR( - shard_index, - ops::ShardIndexOp, - ops::ShardIndexOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker, - ShardIndexInferShapeFunctor); diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc deleted file mode 100644 index 62dc419fd02627..00000000000000 --- a/paddle/fluid/operators/triangular_solve_op.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/infermeta/binary.h" - -namespace paddle { -namespace operators { - -class TriangularSolveOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class TriangularSolveOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "(Tensor), The first input tensor of triangular solve op, which " - "is the triangular coefficient matrix."); - AddInput("Y", - "(Tensor), The second input tensor of triangular solve op, which " - "is multiple right-hand."); - AddOutput("Out", "(Tensor), The solution tensor of triangular solve op."); - AddAttr("upper", - "whether to solve the upper-triangular or the " - "lower-triangular system of equations") - .SetDefault(true); - AddAttr("transpose", "whether X should be transposed firstly.") - .SetDefault(false); - AddAttr("unitriangular", "whether X is unit triangular.") - .SetDefault(false); - AddComment(R"DOC( - Triangular Solve Operator. - This operator is used to computes the solution of equations with a triangular coefficient matrix. - - The equation is: - $$Out = X^-1 * Y$$ -)DOC"); - } -}; - -class TriangularSolveOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map& GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -class TriangularSolveGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "triangular_solve"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "triangular_solve"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "triangular_solve"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "triangular_solve"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); - } - } -}; - -template -class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("triangular_solve_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Y", this->Input("Y")); - retv->SetInput("Out", this->Output("Out")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - retv->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -DECLARE_INFER_SHAPE_FUNCTOR(triangular_solve, - TriangularSolveInferShapeFunctor, - PD_INFER_META(phi::TriangularSolveInferMeta)); - -REGISTER_OPERATOR(triangular_solve, - ops::TriangularSolveOp, - ops::TriangularSolveOpMaker, - ops::TriangularSolveOpInferVarType, - ops::TriangularSolveOpGradMaker, - ops::TriangularSolveOpGradMaker, - TriangularSolveInferShapeFunctor); - -REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp); diff --git a/paddle/fluid/operators/viterbi_decode_op.cc b/paddle/fluid/operators/viterbi_decode_op.cc deleted file mode 100644 index 13c25a80dd731c..00000000000000 --- a/paddle/fluid/operators/viterbi_decode_op.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/ternary.h" - -namespace paddle { -namespace operators { - -class ViterbiDecodeOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); - } -}; - -class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "The unary emission tensor. The shape of Input must be (batch_size," - "sequence_length, num_tags). "); - AddInput("Transition", - "The transition matrix. The shape of Transition must be ( " - "num_tags, num_tags). "); - AddInput("Length", - "The input length tensor storing real length of each sequence for " - "correctness. The shape of Length MUST be (batch_size)."); - AddOutput("Scores", - "The scores tensor containing the score for the Viterbi " - "sequence. The shape of Scores MUST be (batch_size)."); - AddOutput("Path", - "The paths tensor containing the highest scoring tag indices. " - "The shape of Scores MUST be (batch_size, sequence_length)."); - AddAttr("include_bos_eos_tag", - "If set to True, the last row and the last column of " - "transitions will be considered as start tag.") - .SetDefault(true); - AddComment(R"DOC( - )DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace platform = paddle::platform; -DECLARE_INFER_SHAPE_FUNCTOR(viterbi_decode, - ViterbiDecodeInferShapeFunctor, - PD_INFER_META(phi::ViterbiDecodeInferMeta)); -REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, - ops::ViterbiDecodeOp, - ops::ViterbiDecodeOpMaker, - ViterbiDecodeInferShapeFunctor); diff --git a/paddle/fluid/operators/where_op.cc b/paddle/fluid/operators/where_op.cc deleted file mode 100644 index 420ef74b830806..00000000000000 --- a/paddle/fluid/operators/where_op.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/multiary.h" -namespace paddle { -namespace operators { - -class WhereOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } -}; - -class WhereGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "Where"); - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Where"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Where"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - framework::GradVarName("Out"), - "Where"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); - } - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); - } -}; - -class WhereOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Condition", - "(Tensor) A bool tensor whose rank is at least 1. When Condition " - "is True, yield x, otherwise yield y"); - AddInput("X", - "(Tensor), The first input tensor of where op. When the " - "corresponding position of the condition is true, the output " - "takes the element of X."); - AddInput("Y", - "(Tensor), The second input tensor of where op. When the " - "corresponding position of condition is false, the output takes " - "the element of Y."); - AddOutput("Out", "(Tensor), The output tensor of where op."); - AddComment(R"DOC( - Where Operator. - Return a tensor of elements selected from either $X$ or $Y$, depending on condition. - The equation is: - $$ - Out_i = - \begin{cases} - \X_i, \quad \text{if} \ cond_i is True \\ - \Y_i, \quad \text{if} \ cond_i is False \\ - \end{cases} - $$ -)DOC"); - } -}; - -template -class WhereOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad) const override { - grad->SetType("where_grad"); - grad->SetInput("Condition", this->Input("Condition")); - grad->SetInput("X", this->Input("X")); - grad->SetInput("Y", this->Input("Y")); - grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - } -}; - -DECLARE_NO_NEED_BUFFER_VARS_INFERER(WhereGradNoNeedBufferVarsInferer, "X", "Y"); -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(where, - WhereInferShapeFunctor, - PD_INFER_META(phi::WhereInferMeta)); -REGISTER_OPERATOR(where, - ops::WhereOp, - ops::WhereOpMaker, - ops::WhereOpGradMaker, - ops::WhereOpGradMaker, - WhereInferShapeFunctor); - -REGISTER_OPERATOR(where_grad, - ops::WhereGradOp, - ops::WhereGradNoNeedBufferVarsInferer); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 2e5ca9ff4916e2..34f2b33e0fde26 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -867,6 +867,39 @@ backward : rsqrt_double_grad inplace : (out_grad -> x_grad) +- backward_op : scatter_grad + forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite=true) -> Tensor(out) + args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite) + output : Tensor(x_grad), Tensor(updates_grad) + infer_meta : + func : ScatterGradInferMeta + param : [index, updates, out_grad, overwrite] + kernel : + func : scatter_grad + no_need_buffer : updates + +- backward_op : scatter_nd_add_grad + forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out) + args : (Tensor index, Tensor updates, Tensor out_grad) + output : Tensor(x_grad), Tensor(updates_grad) + infer_meta : + func : ScatterNdAddGradInferMeta + param : [index, updates, out_grad] + kernel : + func : scatter_nd_add_grad + no_need_buffer : updates + +- backward_op : selu_grad + forward : selu (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) -> Tensor(out) + args : (Tensor out, Tensor out_grad, float scale, float alpha) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] + kernel : + func : selu_grad + data_type : out + - backward_op : send_uv_grad forward : send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out) args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD") @@ -1154,6 +1187,16 @@ data_type : out_grad no_need_buffer : x +- backward_op : triangular_solve_grad + forward : triangular_solve (Tensor x, Tensor y, bool upper = true, bool tranpose = false, bool unitriangular = false) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : triangular_solve_grad + - backward_op : trunc_grad forward : trunc (Tensor input) -> Tensor(out) args : (Tensor out_grad) @@ -1175,3 +1218,14 @@ func : unfold_grad data_type : out_grad no_need_buffer : x + +- backward_op : where_grad + forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out) + args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : where_grad + no_need_buffer : x, y diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 4001a75d0fa1d4..332407f53caa57 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1333,28 +1333,6 @@ output : Tensor(x_grad) invoke : scale(out_grad, scale, 0.0, bias_after_scale) -- backward_op : scatter_grad - forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite) -> Tensor(out) - args : (Tensor index, Tensor updates, Tensor out_grad, bool overwrite) - output : Tensor(x_grad), Tensor(updates_grad) - infer_meta : - func : ScatterGradInferMeta - param : [index, updates, out_grad, overwrite] - kernel : - func : scatter_grad - no_need_buffer : updates - -- backward_op : scatter_nd_add_grad - forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out) - args : (Tensor index, Tensor updates, Tensor out_grad) - output : Tensor(x_grad), Tensor(updates_grad) - infer_meta : - func : ScatterNdAddGradInferMeta - param : [index, updates, out_grad] - kernel : - func : scatter_nd_add_grad - no_need_buffer : updates - - backward_op : segment_pool_grad forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype) -> Tensor(out), Tensor(summed_ids) args : (Tensor x, Tensor segment_ids, Tensor out, Tensor summed_ids, Tensor out_grad, str pooltype) @@ -1367,16 +1345,6 @@ data_type : x optional : summed_ids -- backward_op : selu_grad - forward : selu (Tensor x, float scale, float alpha) -> Tensor(out) - args : (Tensor out, Tensor out_grad, float scale, float alpha) - output : Tensor(x_grad) - infer_meta : - func : UnchangedInferMeta - param : [out] - kernel : - func : selu_grad - - backward_op : send_u_recv_grad forward : send_u_recv (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) -> Tensor(out), Tensor(dst_count) args : (Tensor x, Tensor src_index, Tensor dst_index, Tensor out, Tensor dst_count, Tensor out_grad, str reduce_op = "SUM") @@ -1662,16 +1630,6 @@ func : transpose_grad backward : transpose_double_grad -- backward_op : triangular_solve_grad - forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : GeneralBinaryGradInferMeta - param : [x, y] - kernel : - func : triangular_solve_grad - - backward_op : tril_grad forward : tril(Tensor x, int diagonal) -> Tensor(out) args : (Tensor out_grad, int diagonal) @@ -1761,17 +1719,6 @@ optional : logits_length no_need_buffer : logits -- backward_op : where_grad - forward : where (Tensor condition, Tensor x, Tensor y) -> Tensor(out) - args : (Tensor condition, Tensor x, Tensor y, Tensor out_grad) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : GeneralBinaryGradInferMeta - param : [x, y] - kernel : - func : where_grad - no_need_buffer : x, y - - backward_op : yolo_loss_grad forward : yolo_loss(Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 2382739377eece..12818204730b47 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1729,27 +1729,6 @@ inplace : (x -> out) backward : scale_grad -- op : scatter - args : (Tensor x, Tensor index, Tensor updates, bool overwrite) - output : Tensor(out) - infer_meta : - func : ScatterInferMeta - dtype : x - kernel : - func : scatter - inplace : (x -> out) - backward : scatter_grad - -- op : scatter_nd_add - args : (Tensor x, Tensor index, Tensor updates) - output : Tensor - infer_meta : - func : ScatterNdAddInferMeta - dtype : x - kernel : - func : scatter_nd_add - backward : scatter_nd_add_grad - - op : searchsorted args : (Tensor sorted_sequence, Tensor values, bool out_int32, bool right) output : Tensor(out) @@ -1769,16 +1748,6 @@ data_type : x backward : segment_pool_grad -- op : selu - args : (Tensor x, float scale, float alpha) - output : Tensor - infer_meta : - func : UnchangedInferMeta - param : [x] - kernel : - func : selu - backward : selu_grad - - op : send_u_recv args : (Tensor x, Tensor src_index, Tensor dst_index, str reduce_op = "SUM", IntArray out_size = {0}) output : Tensor(out), Tensor(dst_count) @@ -1827,14 +1796,6 @@ data_transform: skip_transform : input -- op : shard_index - args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value) - output : Tensor(out) - infer_meta : - func : ShardIndexInferMeta - kernel : - func : shard_index - - op : sigmoid_cross_entropy_with_logits args : (Tensor x, Tensor label, bool normalize, int ignore_index) output : Tensor @@ -2036,15 +1997,6 @@ func : transpose backward : transpose_grad -- op : triangular_solve - args : (Tensor x, Tensor y, bool upper, bool transpose, bool unitriangular) - output : Tensor - infer_meta : - func : TriangularSolveInferMeta - kernel : - func : triangular_solve - backward : triangular_solve_grad - - op : tril args : (Tensor x, int diagonal) output : Tensor(out) @@ -2214,15 +2166,6 @@ data_type : x inplace : (x -> out), (prev_loss_scaling -> loss_scaling), (in_good_steps -> out_good_steps), (in_bad_steps -> out_bad_steps) -- op : viterbi_decode - args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag) - output : Tensor(scores), Tensor(path) - infer_meta : - func : ViterbiDecodeInferMeta - kernel : - func : viterbi_decode - data_type : potentials - - op : warpctc args : (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) output : Tensor(loss), Tensor(warpctcgrad) @@ -2235,15 +2178,6 @@ intermediate: warpctcgrad backward : warpctc_grad -- op : where - args : (Tensor condition, Tensor x, Tensor y) - output : Tensor - infer_meta : - func : WhereInferMeta - kernel : - func : where - backward : where_grad - - op : yolo_box args : (Tensor x, Tensor img_size, int[] anchors, int class_num, float conf_thresh, int downsample_ratio, bool clip_bbox, float scale_x_y=1.0, bool iou_aware=false, float iou_aware_factor=0.5) output : Tensor(boxes), Tensor(scores) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 4945794dc58f5c..30c9ae872699aa 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1029,10 +1029,31 @@ extra : attrs : [bool use_mkldnn = false] +- op : scatter + backward : scatter_grad + inputs : + {x : X, index : Ids, updates : Updates} + outputs : + out : Out + +- op : scatter_nd_add + backward : scatter_nd_add_grad + inputs : + {x : X, index : Index, updates : Updates} + outputs : + out : Out + - op : seed extra : attrs : [bool deterministic = false, str rng_name = "", bool force_cpu = false] +- op : selu + backward : selu_grad + inputs : + x : X + outputs : + out : Out + - op : send_uv (graph_send_uv) backward : send_uv_grad (graph_send_uv_grad) @@ -1052,6 +1073,12 @@ out : Out xout : XOut +- op : shard_index + inputs : + input : X + outputs : + out : Out + - op : shuffle_channel backward : shuffle_channel_grad extra : @@ -1242,6 +1269,13 @@ attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false, str mkldnn_data_type = "float32"] +- op : triangular_solve + backward : triangular_solve_grad + inputs : + {x : X, y : Y} + outputs : + out : Out + - op : trilinear_interp (trilinear_interp_v2) backward : trilinear_interp_grad (trilinear_interp_v2_grad) extra : @@ -1259,6 +1293,19 @@ outputs : out : Y +- op : viterbi_decode + inputs : + {potentials : Input, transition_params : Transition, lengths : Length} + outputs : + {scores : Scores, path : Path} + +- op : where + backward : where_grad + inputs : + {condition : Condition, x : X, y : Y} + outputs : + out : Out + - op : while backward : while_grad extra : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e51f23dda220fa..6f46ea623b090a 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -800,6 +800,37 @@ inplace : (x -> out) backward : rsqrt_grad +- op : scatter + args : (Tensor x, Tensor index, Tensor updates, bool overwrite=true) + output : Tensor(out) + infer_meta : + func : ScatterInferMeta + dtype : x + kernel : + func : scatter + inplace : (x -> out) + backward : scatter_grad + +- op : scatter_nd_add + args : (Tensor x, Tensor index, Tensor updates) + output : Tensor + infer_meta : + func : ScatterNdAddInferMeta + dtype : x + kernel : + func : scatter_nd_add + backward : scatter_nd_add_grad + +- op : selu + args : (Tensor x, float scale=1.0507009873554804934193349852946, float alpha=1.6732632423543772848170429916717) + output : Tensor + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : selu + backward : selu_grad + - op : send_uv args : (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") output : Tensor(out) @@ -810,6 +841,14 @@ data_type : x backward : send_uv_grad +- op : shard_index + args : (Tensor input, int index_num, int nshards, int shard_id, int ignore_value=-1) + output : Tensor(out) + infer_meta : + func : ShardIndexInferMeta + kernel : + func : shard_index + - op : sigmoid args : (Tensor x) output : Tensor @@ -964,6 +1003,16 @@ func : trace backward : trace_grad +- op : triangular_solve + args : (Tensor x, Tensor y, bool upper = true, bool transpose = false, bool unitriangular = false) + output : Tensor + infer_meta : + func : TriangularSolveInferMeta + kernel : + func : triangular_solve + data_type : x + backward : triangular_solve_grad + - op : trunc args : (Tensor input) output : Tensor @@ -989,3 +1038,21 @@ func : ShareBufferInferMeta kernel : func : share_buffer + +- op : viterbi_decode + args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag = true) + output : Tensor(scores), Tensor(path) + infer_meta : + func : ViterbiDecodeInferMeta + kernel : + func : viterbi_decode + data_type : potentials + +- op : where + args : (Tensor condition, Tensor x, Tensor y) + output : Tensor + infer_meta : + func : WhereInferMeta + kernel : + func : where + backward : where_grad diff --git a/paddle/phi/ops/compat/gather_scatter_sig.cc b/paddle/phi/ops/compat/gather_scatter_sig.cc index a942ebb44086f5..e37ba0ff401eba 100644 --- a/paddle/phi/ops/compat/gather_scatter_sig.cc +++ b/paddle/phi/ops/compat/gather_scatter_sig.cc @@ -21,24 +21,6 @@ KernelSignature GatherNdGradArgumentMapping(const ArgumentMappingContext& ctx) { "gather_nd_grad", {"X", "Index", "Out@GRAD"}, {}, {"X@GRAD"}); } -KernelSignature ScatterGradArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("scatter_grad", - {"Ids", "Updates", "Out@GRAD"}, - {"overwrite"}, - {"X@GRAD", "Updates@GRAD"}); -} - -KernelSignature ScatterNdAddGradArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("scatter_nd_add_grad", - {"Index", "Updates", "Out@GRAD"}, - {}, - {"X@GRAD", "Updates@GRAD"}); -} - } // namespace phi PD_REGISTER_ARG_MAPPING_FN(gather_nd_grad, phi::GatherNdGradArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(scatter_grad, phi::ScatterGradArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(scatter_nd_add_grad, - phi::ScatterNdAddGradArgumentMapping); diff --git a/paddle/phi/ops/compat/selu_sig.cc b/paddle/phi/ops/compat/selu_sig.cc deleted file mode 100644 index 08087584a10945..00000000000000 --- a/paddle/phi/ops/compat/selu_sig.cc +++ /dev/null @@ -1,26 +0,0 @@ - -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature SeluGradGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "selu_grad", {"Out", "Out@GRAD"}, {"scale", "alpha"}, {"X@GRAD"}); -} -} // namespace phi -PD_REGISTER_ARG_MAPPING_FN(selu_grad, phi::SeluGradGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/triangular_solve_sig.cc b/paddle/phi/ops/compat/triangular_solve_sig.cc deleted file mode 100644 index 851db32a032d65..00000000000000 --- a/paddle/phi/ops/compat/triangular_solve_sig.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature TriangularSolveGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature("triangular_solve_grad", - {"X", "Y", "Out", "Out@GRAD"}, - {"upper", "transpose", "unitriangular"}, - {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(triangular_solve_grad, - phi::TriangularSolveGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/where_grad_sig.cc b/paddle/phi/ops/compat/where_grad_sig.cc deleted file mode 100644 index e0c380672c895c..00000000000000 --- a/paddle/phi/ops/compat/where_grad_sig.cc +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature WhereGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("where_grad", - {"Condition", "X", "Y", "Out@GRAD"}, - {}, - {"X@GRAD", "Y@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(where_grad, phi::WhereGradOpArgumentMapping); From 9bc4b8c9cb92e7e501854394da36f15f76969764 Mon Sep 17 00:00:00 2001 From: lizhiyu02 <1528794076@qq.com> Date: Fri, 9 Dec 2022 12:46:39 +0000 Subject: [PATCH 2/5] fix the code-style of yaml --- paddle/phi/api/yaml/op_compat.yaml | 10 +++++----- paddle/phi/api/yaml/ops.yaml | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 30c9ae872699aa..eded3b07a556d1 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1066,18 +1066,18 @@ extra : attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"] -- op : share_buffer +- op : shard_index inputs : - x : X + input : X outputs : out : Out - xout : XOut -- op : shard_index +- op : share_buffer inputs : - input : X + x : X outputs : out : Out + xout : XOut - op : shuffle_channel backward : shuffle_channel_grad diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 6f46ea623b090a..6f0c189b4a23f7 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1031,14 +1031,6 @@ func : unfold backward : unfold_grad -- op: share_buffer - args : (Tensor[] x, bool[] share_dims_and_dtype={}) - output : Tensor[](out){x.size()}, Tensor[](xout){x.size()} - infer_meta : - func : ShareBufferInferMeta - kernel : - func : share_buffer - - op : viterbi_decode args : (Tensor potentials, Tensor transition_params, Tensor lengths, bool include_bos_eos_tag = true) output : Tensor(scores), Tensor(path) @@ -1056,3 +1048,11 @@ kernel : func : where backward : where_grad + +- op: share_buffer + args : (Tensor[] x, bool[] share_dims_and_dtype={}) + output : Tensor[](out){x.size()}, Tensor[](xout){x.size()} + infer_meta : + func : ShareBufferInferMeta + kernel : + func : share_buffer From da71e701583da6a909dbe624ea381c247832675c Mon Sep 17 00:00:00 2001 From: lizhiyu02 <1528794076@qq.com> Date: Fri, 9 Dec 2022 16:15:58 +0000 Subject: [PATCH 3/5] fix the framework_ci for triangular_solve --- paddle/fluid/operators/triangular_solve_op.cc | 135 ++++++++++++++++++ paddle/phi/api/yaml/backward.yaml | 10 -- paddle/phi/api/yaml/legacy_backward.yaml | 10 ++ paddle/phi/api/yaml/legacy_ops.yaml | 10 ++ paddle/phi/api/yaml/op_compat.yaml | 7 - paddle/phi/api/yaml/ops.yaml | 10 -- paddle/phi/ops/compat/triangular_solve_sig.cc | 30 ++++ 7 files changed, 185 insertions(+), 27 deletions(-) create mode 100644 paddle/fluid/operators/triangular_solve_op.cc create mode 100644 paddle/phi/ops/compat/triangular_solve_sig.cc diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc new file mode 100644 index 00000000000000..62dc419fd02627 --- /dev/null +++ b/paddle/fluid/operators/triangular_solve_op.cc @@ -0,0 +1,135 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/binary.h" + +namespace paddle { +namespace operators { + +class TriangularSolveOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class TriangularSolveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), The first input tensor of triangular solve op, which " + "is the triangular coefficient matrix."); + AddInput("Y", + "(Tensor), The second input tensor of triangular solve op, which " + "is multiple right-hand."); + AddOutput("Out", "(Tensor), The solution tensor of triangular solve op."); + AddAttr("upper", + "whether to solve the upper-triangular or the " + "lower-triangular system of equations") + .SetDefault(true); + AddAttr("transpose", "whether X should be transposed firstly.") + .SetDefault(false); + AddAttr("unitriangular", "whether X is unit triangular.") + .SetDefault(false); + AddComment(R"DOC( + Triangular Solve Operator. + This operator is used to computes the solution of equations with a triangular coefficient matrix. + + The equation is: + $$Out = X^-1 * Y$$ +)DOC"); + } +}; + +class TriangularSolveOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +class TriangularSolveGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), + "Input", + "Out@GRAD", + "triangular_solve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +template +class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("triangular_solve_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput("Out", this->Output("Out")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(triangular_solve, + TriangularSolveInferShapeFunctor, + PD_INFER_META(phi::TriangularSolveInferMeta)); + +REGISTER_OPERATOR(triangular_solve, + ops::TriangularSolveOp, + ops::TriangularSolveOpMaker, + ops::TriangularSolveOpInferVarType, + ops::TriangularSolveOpGradMaker, + ops::TriangularSolveOpGradMaker, + TriangularSolveInferShapeFunctor); + +REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 34f2b33e0fde26..9d7ca3c9af2564 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1187,16 +1187,6 @@ data_type : out_grad no_need_buffer : x -- backward_op : triangular_solve_grad - forward : triangular_solve (Tensor x, Tensor y, bool upper = true, bool tranpose = false, bool unitriangular = false) -> Tensor(out) - args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular) - output : Tensor(x_grad), Tensor(y_grad) - infer_meta : - func : GeneralBinaryGradInferMeta - param : [x, y] - kernel : - func : triangular_solve_grad - - backward_op : trunc_grad forward : trunc (Tensor input) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 332407f53caa57..8060234979119f 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1630,6 +1630,16 @@ func : transpose_grad backward : transpose_double_grad +- backward_op : triangular_solve_grad + forward : triangular_solve (Tensor x, Tensor y, bool upper, bool tranpose, bool unitriangular) -> Tensor(out) + args : (Tensor x, Tensor y, Tensor out, Tensor out_grad, bool upper, bool tranpose, bool unitriangular) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [x, y] + kernel : + func : triangular_solve_grad + - backward_op : tril_grad forward : tril(Tensor x, int diagonal) -> Tensor(out) args : (Tensor out_grad, int diagonal) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 12818204730b47..d2171ad82a56a5 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1997,6 +1997,16 @@ func : transpose backward : transpose_grad +- op : triangular_solve + args : (Tensor x, Tensor y, bool upper, bool transpose, bool unitriangular) + output : Tensor + infer_meta : + func : TriangularSolveInferMeta + kernel : + func : triangular_solve + data_type : x + backward : triangular_solve_grad + - op : tril args : (Tensor x, int diagonal) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index eded3b07a556d1..5cd00b81e14384 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1269,13 +1269,6 @@ attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false, str mkldnn_data_type = "float32"] -- op : triangular_solve - backward : triangular_solve_grad - inputs : - {x : X, y : Y} - outputs : - out : Out - - op : trilinear_interp (trilinear_interp_v2) backward : trilinear_interp_grad (trilinear_interp_v2_grad) extra : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 6f0c189b4a23f7..e37946a28bc185 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1003,16 +1003,6 @@ func : trace backward : trace_grad -- op : triangular_solve - args : (Tensor x, Tensor y, bool upper = true, bool transpose = false, bool unitriangular = false) - output : Tensor - infer_meta : - func : TriangularSolveInferMeta - kernel : - func : triangular_solve - data_type : x - backward : triangular_solve_grad - - op : trunc args : (Tensor input) output : Tensor diff --git a/paddle/phi/ops/compat/triangular_solve_sig.cc b/paddle/phi/ops/compat/triangular_solve_sig.cc new file mode 100644 index 00000000000000..851db32a032d65 --- /dev/null +++ b/paddle/phi/ops/compat/triangular_solve_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature TriangularSolveGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("triangular_solve_grad", + {"X", "Y", "Out", "Out@GRAD"}, + {"upper", "transpose", "unitriangular"}, + {"X@GRAD", "Y@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(triangular_solve_grad, + phi::TriangularSolveGradOpArgumentMapping); From 69184cec1ba8d5e62acb01a5ba67aca63c49cb38 Mon Sep 17 00:00:00 2001 From: lizhiyu02 <1528794076@qq.com> Date: Mon, 12 Dec 2022 02:45:57 +0000 Subject: [PATCH 4/5] change the 'data_type' of scatter --- paddle/phi/api/yaml/ops.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e37946a28bc185..c125fe0ac30712 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -805,9 +805,9 @@ output : Tensor(out) infer_meta : func : ScatterInferMeta - dtype : x kernel : func : scatter + data_type : x inplace : (x -> out) backward : scatter_grad @@ -816,9 +816,9 @@ output : Tensor infer_meta : func : ScatterNdAddInferMeta - dtype : x kernel : func : scatter_nd_add + data_type : x backward : scatter_nd_add_grad - op : selu From aec9df463f2430bf4af6c92278a5d9d8f650b2c3 Mon Sep 17 00:00:00 2001 From: lizhiyu02 <1528794076@qq.com> Date: Mon, 12 Dec 2022 15:13:08 +0000 Subject: [PATCH 5/5] add the 'out: Out' of scatter_nd_add --- paddle/phi/api/yaml/op_compat.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 2b0abed2cfdb5b..5033a1932862ad 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1056,6 +1056,8 @@ backward : scatter_nd_add_grad inputs : {x : X, index : Index, updates : Updates} + outputs : + out : Out - op : searchsorted inputs :