Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.26】为 Paddle 新增 select_scatter API #57874

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -247,6 +247,8 @@
view,
view_as,
unfold,
select_scatter,
select_scatter_,
)

from .tensor.math import ( # noqa: F401
@@ -899,4 +901,6 @@
'i1e',
'polygamma',
'polygamma_',
'select_scatter',
'select_scatter_',
]
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -165,6 +165,8 @@
from .manipulation import view # noqa: F401
from .manipulation import view_as # noqa: F401
from .manipulation import unfold # noqa: F401
from .manipulation import select_scatter
from .manipulation import select_scatter_
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
@@ -704,6 +706,8 @@
'triu_',
'stft',
'istft',
'select_scatter',
'select_scatter_',
'abs_',
'acos_',
'atan_',
109 changes: 109 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
@@ -5166,3 +5166,112 @@ def unfold(x, axis, size, step, name=None):
}
for name, func in __METHODS.items():
setattr(core.eager.Tensor, name, func)


def select_scatter(x, value, dim, index):
    """
    Embed ``value`` into ``x`` at position ``index`` along dimension ``dim``.

    ``x`` itself is not modified; the result is returned as a new tensor.

    Args:
        x (Tensor): The input Tensor. Must have at least one dimension.
            Supported data types are bool, float16, float32, float64, int32, int64.
        value (Tensor): The Tensor to embed into ``x``. Its dtype must equal ``x.dtype``.
        dim (int): The dimension to insert the slice into.
        index (int): The index to select with. A negative value counts from
            the end of dimension ``dim``.

    Returns:
        Tensor, same dimension and dtype with x.

    Raises:
        ValueError: If ``x`` is a 0-dim tensor, or if ``index`` lies outside
            ``[-x.shape[dim], x.shape[dim])``.
        TypeError: If ``value.dtype`` differs from ``x.dtype``.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:GPU)
            >>> import paddle
            >>> x = paddle.to_tensor([[0, 0],
            ...                       [0, 0]])
            >>> value = paddle.to_tensor([1, 2])
            >>> out = paddle.select_scatter(x, value, 0, 0)
            >>> print(out)
            Tensor(shape=[2, 2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            [[1, 2],
             [0, 0]])
    """
    # Import the builtin explicitly -- presumably the enclosing module binds
    # its own ``slice`` name (verify against the rest of the file).
    from builtins import slice as original_slice

    if len(x.shape) == 0:
        raise ValueError(
            "select_scatter() can not be applied to a 0-dim tensor"
        )

    dim_size = x.shape[dim]
    if index < -dim_size or index >= dim_size:
        raise ValueError(
            f"select_scatter() index: {index} out of range for tensor of size: {dim_size}"
        )
    if index < 0:
        index += dim_size

    if value.dtype != x.dtype:
        raise TypeError(
            "The data type of tensor value must be same to the data type of tensor x"
        )

    # Full slices on every axis except ``dim``, which is pinned to ``index``.
    indices = [original_slice(None)] * len(x.shape)
    indices[dim] = index

    if paddle.in_dynamic_mode():
        # Write into a copy so the caller's tensor stays unchanged.
        out = x.clone()
        out[tuple(indices)] = value
    else:
        out = paddle.static.setitem(x, tuple(indices), value)
    return out


def select_scatter_(x, value, dim, index):
    """
    Inplace version of ``select_scatter`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_paddle_select_scatter`.

    Args:
        x (Tensor): The input Tensor, overwritten with the result. Must have
            at least one dimension. Supported data types are bool, float16,
            float32, float64, int32, int64.
        value (Tensor): The Tensor to embed into ``x``. Its dtype must equal ``x.dtype``.
        dim (int): The dimension to insert the slice into.
        index (int): The index to select with. A negative value counts from
            the end of dimension ``dim``.

    Returns:
        Tensor, the input ``x`` after the update.

    Raises:
        ValueError: If ``x`` is a 0-dim tensor, or if ``index`` lies outside
            ``[-x.shape[dim], x.shape[dim])``.
        TypeError: If ``value.dtype`` differs from ``x.dtype``.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:GPU)
            >>> import paddle
            >>> x = paddle.to_tensor([[0, 0],
            ...                       [0, 0]])
            >>> value = paddle.to_tensor([1, 2])
            >>> out = paddle.select_scatter_(x, value, 0, 0)
            >>> print(out)
            Tensor(shape=[2, 2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            [[1, 2],
             [0, 0]])
    """
    # Import the builtin explicitly -- presumably the enclosing module binds
    # its own ``slice`` name (verify against the rest of the file).
    from builtins import slice as original_slice

    if len(x.shape) == 0:
        raise ValueError(
            "select_scatter_() can not be applied to a 0-dim tensor"
        )

    dim_size = x.shape[dim]
    if index < -dim_size or index >= dim_size:
        raise ValueError(
            f"select_scatter_() index: {index} out of range for tensor of size: {dim_size}"
        )
    if index < 0:
        index += dim_size

    if value.dtype != x.dtype:
        raise TypeError(
            "The data type of tensor value must be same to the data type of tensor x"
        )

    # Full slices on every axis except ``dim``, which is pinned to ``index``.
    indices = [original_slice(None)] * len(x.shape)
    indices[dim] = index

    if paddle.in_dynamic_mode():
        x[tuple(indices)] = value
    else:
        # Mark the target slice with ones (not with ``value`` itself): the
        # positions are recovered from the non-zero entries of the mask, so a
        # zero inside ``value`` would otherwise drop its position and leave
        # fewer indices than there are elements in ``value.flatten()``.
        mask = paddle.static.setitem(
            paddle.zeros_like(x), tuple(indices), paddle.ones_like(value)
        )
        positions = tuple(item.squeeze() for item in paddle.where(mask))
        paddle.index_put_(x, positions, value.flatten())
    return x
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
@@ -250,6 +250,22 @@ def test_backward_success_2(self):
np.testing.assert_array_equal(grad_var_a_inplace, grad_var_a)


class TestDygraphInplaceSelectScatter(TestDygraphInplace):
    """Plug ``select_scatter`` / ``select_scatter_`` into ``TestDygraphInplace``.

    The base class is defined elsewhere in this file; presumably it compares
    the results of the two ``*_api_processing`` hooks below.
    """

    def init_data(self):
        # 2x2 random input; column 1 (dim=1, index=1) gets replaced by [1, 2].
        self.dtype = "float32"
        self.input_var_numpy = np.random.uniform(-10, 10, [2, 2])
        self.dim, self.index = 1, 1
        self.value = paddle.to_tensor([1, 2], dtype='float32')

    def non_inplace_api_processing(self, var):
        return paddle.select_scatter(
            var, self.value, dim=self.dim, index=self.index
        )

    def inplace_api_processing(self, var):
        return paddle.select_scatter_(
            var, self.value, dim=self.dim, index=self.index
        )


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
192 changes: 192 additions & 0 deletions test/legacy_test/test_select_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import base
from paddle.base import core


class TestSelectScatterAPI(unittest.TestCase):
    """Check ``paddle.select_scatter`` in static and dynamic mode.

    Subclasses override ``setUp`` to provide ``dtype``, ``x_shape``, ``x``,
    ``value_shape``, ``value``, ``dim``, ``index`` and the expected ``out``.
    Every test selects the execution mode it needs itself, so the outcome
    does not depend on the order in which the tests run.
    """

    def setUp(self):
        self.dtype = "float32"
        self.x_shape = (10, 10)
        self.x = np.arange(100).reshape(self.x_shape).astype(self.dtype)
        self.value_shape = (10,)
        self.value = np.ones(self.value_shape).astype(self.dtype)
        self.dim = 0
        self.index = 0
        # NumPy reference: copy x and overwrite the selected slice with value.
        selector = [slice(None)] * len(self.x_shape)
        selector[self.dim] = self.index
        self.out = self.x.copy()
        self.out[tuple(selector)] = self.value

    def test_static_graph(self):
        paddle.enable_static()
        try:
            main_program = base.Program()
            startup_program = base.Program()
            with base.program_guard(main_program, startup_program):
                x = paddle.static.data(
                    name='x', shape=self.x_shape, dtype=self.dtype
                )
                value = paddle.static.data(
                    name='value', shape=self.value_shape, dtype=self.dtype
                )
                out = paddle.select_scatter(
                    x, value, dim=self.dim, index=self.index
                )

            place = (
                base.CUDAPlace(0)
                if core.is_compiled_with_cuda()
                else base.CPUPlace()
            )
            exe = base.Executor(place)
            # Run the program that actually holds the ops built above.
            res = exe.run(
                main_program,
                feed={'x': self.x, 'value': self.value},
                fetch_list=[out],
            )
            np.testing.assert_allclose(
                res[0], self.out, atol=1e-5, rtol=1e-5
            )
        finally:
            # Leave dynamic mode active even if the test fails.
            paddle.disable_static()

    def test_dygraph(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.x)
        value = paddle.to_tensor(self.value)
        result = paddle.select_scatter(x, value, dim=self.dim, index=self.index)
        np.testing.assert_allclose(result.numpy(), self.out, rtol=1e-5)

    def test_negative_index(self):
        # The equivalent negative index must address the same slice.
        paddle.disable_static()
        x = paddle.to_tensor(self.x)
        value = paddle.to_tensor(self.value)
        negative_index = self.index - self.x_shape[self.dim]
        result = paddle.select_scatter(
            x, value, dim=self.dim, index=negative_index
        )
        np.testing.assert_allclose(result.numpy(), self.out, rtol=1e-5)

    def test_error(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.x)
        value = paddle.to_tensor(self.value)
        size = self.x_shape[self.dim]
        # First out-of-range index on either side, plus the original case.
        for bad_index in (size, -size - 1):
            with self.assertRaises(ValueError):
                paddle.select_scatter(x, value, self.dim, bad_index)
        with self.assertRaises(ValueError):
            paddle.select_scatter(x, value, 0, 11)


class TestSelectScatterAPI1(TestSelectScatterAPI):
    """3-D case: overwrite slice 1 along axis 0 of a (3, 4, 5) input with ones."""

    def setUp(self):
        self.dtype = "float32"
        self.dim, self.index = 0, 1
        self.x_shape = (3, 4, 5)
        self.value_shape = (4, 5)
        self.x = np.arange(60).reshape(self.x_shape).astype(self.dtype)
        self.value = np.ones(self.value_shape).astype(self.dtype)
        # Expected result: x as float64 with the whole [1, :, :] slice set to 1.
        expected = self.x.astype("float64")
        expected[1, :, :] = 1.0
        self.out = expected


class TestSelectScatterAPI2(TestSelectScatterAPI):
    """3-D case: overwrite slice 2 along axis 1 of a (3, 4, 5) input with ones."""

    def setUp(self):
        self.dtype = "float32"
        self.dim, self.index = 1, 2
        self.x_shape = (3, 4, 5)
        self.value_shape = (3, 5)
        self.x = np.arange(60).reshape(self.x_shape).astype(self.dtype)
        self.value = np.ones(self.value_shape).astype(self.dtype)
        # Expected result: x as float64 with the whole [:, 2, :] slice set to 1.
        expected = self.x.astype("float64")
        expected[:, 2, :] = 1.0
        self.out = expected


class TestSelectScatterAPI3(TestSelectScatterAPI):
    """3-D case: overwrite slice 3 along axis 2 of a (3, 4, 5) input with ones."""

    def setUp(self):
        self.dtype = "float32"
        self.dim, self.index = 2, 3
        self.x_shape = (3, 4, 5)
        self.value_shape = (3, 4)
        self.x = np.arange(60).reshape(self.x_shape).astype(self.dtype)
        self.value = np.ones(self.value_shape).astype(self.dtype)
        # Expected result: x as float64 with the whole [:, :, 3] slice set to 1.
        expected = self.x.astype("float64")
        expected[:, :, 3] = 1.0
        self.out = expected


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()