Remove fluid matmul (PaddlePaddle#47988)
* remove layers.matmul in nets.py

* remove layers.matmul in rnn_impl/test_quantization_pass/auto_parallel_gpt_model/test_auto_parallel_completion_gpt

* remove layers.matmul in other files

* fix

* fix

* remove layers.matmul itself

* remove ref in CMakeLists.txt and tools directory

* remove matmul in fluid.layers.nn.py

* remove matmul in fluid.dygraph.rnn.py && restore test_matmul_op.py

* replace matmul in fluid.dygraph.rnn.py && clean api_test in test_matmul_op.py

* fix error && restore empty test_auto_search_dist_matmul_op.py

* fix check in test_auto_parallel_partitioner.py

* fix test_dist_matmul && test_flags_mkldnn_ops_on_off

* fix test_fused_attention_op_xpu.py && test_matmul_op_xpu.py

* remove test_auto_search_dist_matmul_op.py

* remove layers.matmul in auto_parallel_gpt_model.py && fix doc in fluid/io.py

* fix for matmul_grad

* fix codestyle

* fix codestyle

* resolve conflict errors

* restore unit test file but do not compile it, pending later removal

* fix codestyle

* fix wrong unittest skip

* fix unittest delete

* fix scale cost

* fix scale cost

* resolve conflict errors

* resolve conflict errors

Co-authored-by: jakpiase <jakpia21@gmail.com>
kangguangli and jakpiase authored Dec 6, 2022
1 parent ab38fef commit ecebf45
Showing 61 changed files with 183 additions and 591 deletions.
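
Most call sites change mechanically from `fluid.layers.matmul` to `paddle.matmul`. A minimal before/after sketch of the migration (the shapes and the `alpha` value here are illustrative, not taken from any file below):

```python
import paddle

x = paddle.rand([2, 3])
y = paddle.rand([2, 3])

# Old (removed by this commit):
#   out = fluid.layers.matmul(x, y, transpose_x=False, transpose_y=True, alpha=2.0)
# New: paddle.matmul keeps the transpose flags but has no `alpha` argument,
# so any non-default scale is folded in with an explicit paddle.scale.
out = paddle.scale(paddle.matmul(x, y, transpose_y=True), scale=2.0)
print(out.shape)  # [2, 2]
```
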
8 changes: 7 additions & 1 deletion paddle/phi/kernels/onednn/matmul_grad_kernel.cc
@@ -101,9 +101,13 @@ void MatmulGradKernel(const Context &dev_ctx,
 
   if (x_dims.size() != ndims) {
     x_dims = ExtendDimsWithOnes(x_dims, ndims);
-  } else if (y_dims.size() != ndims) {
+  }
+  if (y_dims.size() != ndims) {
     y_dims = ExtendDimsWithOnes(y_dims, ndims);
   }
+  if (dout_dims.size() != ndims) {
+    dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
+  }
 
   // in broadcasting scenario new memory is required because
   // reduce sum must be calculated upon broadcasted dims
@@ -150,7 +154,9 @@ void MatmulGradKernel(const Context &dev_ctx,
   }
 
   dx->Resize(x.dims());
+  dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
   dy->Resize(y.dims());
+  dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
 }
 
 template <typename T, typename Context>
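
The kernel fix makes the rank extension of `x_dims`, `y_dims`, and `dout_dims` independent; previously the `else if` skipped `y_dims` whenever `x_dims` needed extending, and `dout_dims` was never extended. A small Python sketch of the idea behind `ExtendDimsWithOnes` (illustrative only; the real logic is the oneDNN C++ kernel above):

```python
def extend_dims_with_ones(shape, ndims):
    # Prepend leading 1s until the shape has `ndims` entries,
    # e.g. (3, 4) with ndims=3 -> (1, 3, 4).
    return (1,) * (ndims - len(shape)) + tuple(shape)

x_dims, y_dims, dout_dims = (3, 4), (2, 4, 5), (2, 3, 5)
ndims = max(map(len, (x_dims, y_dims, dout_dims)))

# Each shape is padded independently; with an `else if`, y_dims and dout_dims
# would be skipped whenever x_dims already needed padding.
x_dims, y_dims, dout_dims = (
    extend_dims_with_ones(s, ndims) for s in (x_dims, y_dims, dout_dims)
)
print(x_dims, y_dims, dout_dims)  # (1, 3, 4) (2, 4, 5) (2, 3, 5)
```
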
6 changes: 3 additions & 3 deletions python/paddle/fluid/contrib/layers/rnn_impl.py
@@ -151,7 +151,7 @@ def _build_once(self, input, pre_hidden):
def forward(self, input, pre_hidden):
concat_input_hidden = layers.concat([input, pre_hidden], 1)

gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = paddle.matmul(x=concat_input_hidden, y=self._gate_weight)

gate_input = paddle.add(gate_input, self._gate_bias)

@@ -160,7 +160,7 @@ def forward(self, input, pre_hidden):

r_hidden = r * pre_hidden

candidate = layers.matmul(
candidate = paddle.matmul(
layers.concat([input, r_hidden], 1), self._candidate_weight
)
candidate = paddle.add(candidate, self._candidate_bias)
@@ -874,7 +874,7 @@ def _build_once(self, input, pre_hidden, pre_cell):

def forward(self, input, pre_hidden, pre_cell):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = paddle.matmul(x=concat_input_hidden, y=self._weight)

gate_input = paddle.add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
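
For context, a self-contained sketch of the fused gate computation these units perform, written against the public 2.x API only (the shapes, the sigmoid activation, and the parameter names are illustrative stand-ins, not the actual `BasicGRUUnit` attributes):

```python
import paddle

batch, input_size, hidden_size = 4, 8, 16
inp = paddle.rand([batch, input_size])
pre_hidden = paddle.rand([batch, hidden_size])

# One fused weight/bias pair produces both GRU gates at once.
gate_weight = paddle.rand([input_size + hidden_size, 2 * hidden_size])
gate_bias = paddle.zeros([2 * hidden_size])

concat_input_hidden = paddle.concat([inp, pre_hidden], axis=1)
gate_input = paddle.matmul(concat_input_hidden, gate_weight)  # was layers.matmul
gate_input = paddle.nn.functional.sigmoid(paddle.add(gate_input, gate_bias))
r, u = paddle.split(gate_input, num_or_sections=2, axis=1)
print(r.shape, u.shape)  # [4, 16] [4, 16]
```
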
test_quantization_pass.py
@@ -76,7 +76,7 @@ def conv_bn_layer(
matmul_weight = paddle.create_parameter(
shape=[1, 16, 32, 32], dtype='float32'
)
hidden = fluid.layers.matmul(hidden, matmul_weight, True, True)
hidden = paddle.matmul(hidden, matmul_weight, True, True)
if quant_skip_pattern:
with fluid.name_scope(quant_skip_pattern):
pool = fluid.layers.pool2d(
@@ -724,7 +724,7 @@ def conv_bn_layer(
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = paddle.nn.functional.relu(paddle.add(x=conv, y=short))
hidden = fluid.layers.matmul(hidden, data2, True, True)
hidden = paddle.matmul(hidden, data2, True, True)
if isinstance(quant_skip_pattern, str):
with fluid.name_scope(quant_skip_pattern):
pool1 = fluid.layers.pool2d(
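
These call sites pass the transpose flags positionally. In both the removed `fluid.layers.matmul` and `paddle.matmul`, the third and fourth positional parameters are `transpose_x` and `transpose_y`, so the calls port directly. A quick shape check (random tensors stand in for `hidden` and `matmul_weight`):

```python
import paddle

hidden = paddle.rand([1, 16, 32, 32])
weight = paddle.rand([1, 16, 32, 32])

# Equivalent to the replaced fluid.layers.matmul(hidden, weight, True, True):
out = paddle.matmul(hidden, weight, True, True)  # positional flags
out_kw = paddle.matmul(hidden, weight, transpose_x=True, transpose_y=True)
print(out.shape, out_kw.shape)  # [1, 16, 32, 32] [1, 16, 32, 32]
```
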
24 changes: 14 additions & 10 deletions python/paddle/fluid/dygraph/rnn.py
@@ -17,11 +17,11 @@
from ..layers import (
concat,
fill_constant,
matmul,
elementwise_mul,
split,
)
import copy
import paddle

__all__ = ['LSTMCell', 'GRUCell']

@@ -215,11 +215,12 @@ def __init__(
def forward(self, input, pre_hidden, pre_cell):

if self._use_cudnn_impl:
igates = matmul(input, y=self._weight_ih, transpose_y=True)
igates = paddle.matmul(input, y=self._weight_ih, transpose_y=True)
igates = paddle.add(igates, self._bias_ih)
hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
hgates = paddle.matmul(
pre_hidden, self._weight_hh, transpose_y=True
)
hgates = paddle.add(hgates, self._bias_hh)

chunked_igates = split(igates, num_or_sections=4, dim=1)
chunked_hgates = split(hgates, num_or_sections=4, dim=1)

@@ -241,7 +242,7 @@ def forward(self, input, pre_hidden, pre_cell):
else:

concat_input_hidden = concat([input, pre_hidden], 1)
gate_input = matmul(x=concat_input_hidden, y=self._weight)
gate_input = paddle.matmul(x=concat_input_hidden, y=self._weight)

gate_input = paddle.add(gate_input, self._bias)
i, j, f, o = split(gate_input, num_or_sections=4, dim=-1)
@@ -461,10 +462,11 @@ def __init__(
def forward(self, input, pre_hidden):

if self._use_cudnn_impl:

igates = matmul(input, y=self._weight_ih, transpose_y=True)
igates = paddle.matmul(input, y=self._weight_ih, transpose_y=True)
igates = paddle.add(igates, self._bias_ih)
hgates = matmul(pre_hidden, self._weight_hh, transpose_y=True)
hgates = paddle.matmul(
pre_hidden, self._weight_hh, transpose_y=True
)
hgates = paddle.add(hgates, self._bias_hh)

chunked_igates = split(igates, num_or_sections=3, dim=1)
@@ -486,15 +488,17 @@ def forward(self, input, pre_hidden):

concat_input_hidden = concat([input, pre_hidden], 1)

gate_input = matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = paddle.matmul(
x=concat_input_hidden, y=self._gate_weight
)

gate_input = paddle.add(gate_input, self._gate_bias)
gate_input = self._gate_activation(gate_input)
r, u = split(gate_input, num_or_sections=2, dim=1)

r_hidden = r * pre_hidden

candidate = matmul(
candidate = paddle.matmul(
concat([input, r_hidden], 1), self._candidate_weight
)
candidate = paddle.add(candidate, self._candidate_bias)
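
On the cuDNN-style path above, the weights are laid out with the gate dimension first, which is why the calls pass `transpose_y=True`. A standalone sketch under that assumption (shapes illustrative; `4 * hidden_size` corresponds to the four LSTM gates):

```python
import paddle

batch, input_size, hidden_size = 4, 8, 16
inp = paddle.rand([batch, input_size])
pre_hidden = paddle.rand([batch, hidden_size])

# cuDNN-style layout: rows are stacked gates, columns are features.
weight_ih = paddle.rand([4 * hidden_size, input_size])
weight_hh = paddle.rand([4 * hidden_size, hidden_size])
bias_ih = paddle.zeros([4 * hidden_size])
bias_hh = paddle.zeros([4 * hidden_size])

igates = paddle.matmul(inp, weight_ih, transpose_y=True)         # [batch, 4*hidden]
hgates = paddle.matmul(pre_hidden, weight_hh, transpose_y=True)  # [batch, 4*hidden]
gates = paddle.add(paddle.add(igates, bias_ih), paddle.add(hgates, bias_hh))
i, j, f, o = paddle.split(gates, num_or_sections=4, axis=1)
print(i.shape)  # [4, 16]
```
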
149 changes: 0 additions & 149 deletions python/paddle/fluid/layers/nn.py
@@ -73,7 +73,6 @@
'dropout',
'split',
'l2_normalize',
'matmul',
'row_conv',
'layer_norm',
'spectral_norm',
@@ -2589,154 +2588,6 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
return out


@deprecated(since="2.0.0", update_to="paddle.matmul")
def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
"""
Applies matrix multiplication to two tensors.
Currently, the input tensors' rank can be any, but when the rank of any
inputs is bigger than 3, this two inputs' rank should be equal.
The actual behavior depends on the shapes of :math:`x`, :math:`y` and the
flag values of :attr:`transpose_x`, :attr:`transpose_y`. Specifically:
- If a transpose flag is specified, the last two dimensions of the tensor
are transposed. If the tensor is rank-1 of shape :math:`[D]`, then for
:math:`x` it is treated as :math:`[1, D]` in nontransposed form and as
:math:`[D, 1]` in transposed form, whereas for :math:`y` it is the
opposite: It is treated as :math:`[D, 1]` in nontransposed form and as
:math:`[1, D]` in transposed form.
- After transpose, the two tensors are 2-D or n-D and matrix multiplication
performs in the following way.
- If both are 2-D, they are multiplied like conventional matrices.
- If either is n-D, it is treated as a stack of matrices residing in the
last two dimensions and a batched matrix multiply supporting broadcast
applies on the two tensors.
Also note that if the raw tensor :math:`x` or :math:`y` is rank-1 and
nontransposed, the prepended or appended dimension :math:`1` will be
removed after matrix multiplication.
Args:
x (Variable): The input variable which is a Tensor or LoDTensor.
y (Variable): The input variable which is a Tensor or LoDTensor.
transpose_x (bool): Whether to transpose :math:`x` before multiplication.
transpose_y (bool): Whether to transpose :math:`y` before multiplication.
alpha (float): The scale of output. Default 1.0.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The product Tensor (or LoDTensor) variable.
Examples:
.. code-block:: python
# Examples to clarify shapes of the inputs and output
# x: [B, ..., M, K], y: [B, ..., K, N]
# fluid.layers.matmul(x, y) # out: [B, ..., M, N]
# x: [B, M, K], y: [B, K, N]
# fluid.layers.matmul(x, y) # out: [B, M, N]
# x: [B, M, K], y: [K, N]
# fluid.layers.matmul(x, y) # out: [B, M, N]
# x: [M, K], y: [K, N]
# fluid.layers.matmul(x, y) # out: [M, N]
# x: [B, M, K], y: [K]
# fluid.layers.matmul(x, y) # out: [B, M]
# x: [K], y: [K]
# fluid.layers.matmul(x, y) # out: [1]
# x: [M], y: [N]
# fluid.layers.matmul(x, y, True, True) # out: [M, N]
import paddle
import paddle.fluid as fluid
paddle.enable_static()
x = fluid.layers.data(name='x', shape=[2, 3], dtype='float32')
y = fluid.layers.data(name='y', shape=[3, 2], dtype='float32')
out = fluid.layers.matmul(x, y, True, True)
"""
if _non_static_mode():
out = _varbase_creator(dtype=x.dtype)
_legacy_C_ops.matmul(
x,
y,
out,
'transpose_X',
transpose_x,
'transpose_Y',
transpose_y,
'alpha',
float(alpha),
)
return out

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'matmul'
)
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]

# check the inner 2 dimensions
if transpose_x:
x_shape[-2], x_shape[-1] = x_shape[-1], x_shape[-2]
if transpose_y:
y_shape[-2], y_shape[-1] = y_shape[-1], y_shape[-2]
if x_shape[-1] != y_shape[-2]:
assert (x_shape[-1] == -1) or (y_shape[-2] == -1), (
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: %s, Y's shape: %s\n"
% (x_shape, y_shape)
)

if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape)
)

attrs = {
'transpose_X': transpose_x,
'transpose_Y': transpose_y,
'alpha': float(alpha),
}

__check_input(x, y)

helper = LayerHelper('matmul', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='matmul',
inputs={'X': x, 'Y': y},
outputs={'Out': out},
attrs=attrs,
)
return out


@templatedoc()
def row_conv(input, future_context_size, param_attr=None, act=None):
"""
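
The shape rules in the removed docstring carry over to `paddle.matmul` (minus the `alpha` scale). A brief sketch checking two of the documented cases (shapes follow the docstring's examples, values are random):

```python
import paddle

# x: [M, K], y: [K, N]  ->  out: [M, N]
a = paddle.rand([4, 3])
b = paddle.rand([3, 5])
print(paddle.matmul(a, b).shape)  # [4, 5]

# x: [B, M, K], y: [K, N]  ->  out: [B, M, N]  (batched matmul, broadcasting y)
x = paddle.rand([2, 4, 3])
y = paddle.rand([3, 5])
print(paddle.matmul(x, y).shape)  # [2, 4, 5]
```
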
4 changes: 2 additions & 2 deletions python/paddle/fluid/nets.py
@@ -621,7 +621,7 @@ def __combine_heads(x):

key_dim_per_head = keys.shape[-1] // num_heads
scaled_q = paddle.scale(x=q, scale=key_dim_per_head**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
product = paddle.matmul(x=scaled_q, y=k, transpose_y=True)

x = paddle.reshape(x=product, shape=[-1, product.shape[-1]])
x = paddle.nn.functional.softmax(x)
@@ -631,5 +631,5 @@ def __combine_heads(x):
weights = layers.dropout(
weights, dropout_prob=dropout_rate, is_test=False
)
ctx_multiheads = layers.matmul(weights, v)
ctx_multiheads = paddle.matmul(weights, v)
return __combine_heads(ctx_multiheads)
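
The surrounding helper in nets.py computes scaled dot-product attention. A minimal single-head sketch of the same computation with the new calls (head splitting/combining and dropout omitted; dimensions illustrative):

```python
import paddle
import paddle.nn.functional as F

batch, seq_len, key_dim = 2, 5, 8
q = paddle.rand([batch, seq_len, key_dim])
k = paddle.rand([batch, seq_len, key_dim])
v = paddle.rand([batch, seq_len, key_dim])

scaled_q = paddle.scale(q, scale=key_dim**-0.5)
product = paddle.matmul(scaled_q, k, transpose_y=True)  # [batch, seq, seq]
weights = F.softmax(product)                            # softmax over the last axis
context = paddle.matmul(weights, v)                     # [batch, seq, key_dim]
print(context.shape)  # [2, 5, 8]
```
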