Skip to content

Commit

Permalink
modify tensor_split, hsplit, dsplit, vsplit API documentation and add…
Browse files Browse the repository at this point in the history
… legends
  • Loading branch information
fufu0615 committed Aug 4, 2024
1 parent 037a293 commit 7ad2888
Showing 1 changed file with 55 additions and 5 deletions.
60 changes: 55 additions & 5 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,19 +2810,30 @@ def tensor_split(
Examples:
.. code-block:: python
:name: code-example-1
>>> import paddle
>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split-1.png
.. code-block:: python
:name: code-example-2
>>> import paddle
>>> # not evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
Expand All @@ -2831,7 +2842,16 @@ def tensor_split(
>>> print(out2.shape)
[2]
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split-2.png
.. code-block:: python
:name: code-example-3
>>> import paddle
>>> # split with indices
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
Expand All @@ -2840,6 +2860,13 @@ def tensor_split(
>>> print(out2.shape)
[5]
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split-3.png
.. code-block:: python
:name: code-example-4
>>> import paddle
>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
Expand All @@ -2849,6 +2876,16 @@ def tensor_split(
>>> print(out1.shape)
[7, 4]
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split-4.png
.. code-block:: python
:name: code-example-5
>>> import paddle
>>> # split along axis with indice
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
Expand All @@ -2857,6 +2894,8 @@ def tensor_split(
>>> print(out2.shape)
[7, 5]
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split-5.png
"""
if x.ndim <= 0 or x.ndim <= axis:
raise ValueError(
Expand Down Expand Up @@ -2910,8 +2949,11 @@ def hsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.tensor_split`` with ``axis=1``
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.tensor_split`` with ``axis=0`` when ``x`` 's dimension is 1.
``hsplit`` Full name Horizontal Split, splits the input Tensor into multiple sub-Tensors along the horizontal axis, in the following two cases:
1. When the dimension of x is equal to 1, it is equivalent to ``paddle.tensor_split`` with ``axis=0``;
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/hsplit/hsplit-1.png
2. when the dimension of x is greater than 1, it is equivalent to ``paddle.tensor_split`` with ``axis=1``.
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/hsplit/hsplit-2.png
Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
Expand Down Expand Up @@ -2966,7 +3008,11 @@ def dsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.
``dsplit`` Full name Depth Split, splits the input Tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.
Note: The number of Tensor dimensions transformed using ``paddle.dsplit`` must be no less than 3.
In the following figure, Tenser ``x`` has shape [4, 4, 4], and after ``paddle.dsplit(x, num_or_indices=2)`` transformation, we get ``out0`` and ``out1`` sub-Tensors whose shapes are both [4, 4, 2] :
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/dsplit/dsplit.png
Args:
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
Expand Down Expand Up @@ -3010,7 +3056,11 @@ def vsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.
``vsplit`` Full name Vertical Split, splits the input Tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.
1. When the number of Tensor dimensions is equal to 1:
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/vsplit/vsplit-1.png
2. When the number of Tensor dimensions is greater than 1:
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/vsplit/vsplit-2.png
Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
Expand Down

0 comments on commit 7ad2888

Please # to comment.