diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index e1cbd58753ef37..93ed4074b31f20 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -890,21 +890,38 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   if (need_cast) {
     x_cast = cast<T>(x, DataType::FLOAT32);
   }
-
-  auto x_dim = x.shape();
-  std::vector<int64_t> one_axis(1, 1);
-
-  std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
-  x_cast = reshape<T>(x_cast, x_shape);
-  auto mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
-  auto var_tmp_ =
-      mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_;
-  auto var_ =
-      maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
-  auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
-  auto res = (x_cast - mean_) * var_inv;
-  auto out = reshape<T>(res, x_dim);
-
+  Tensor out, mean_, var_;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    std::vector<int64_t> one_axis(1, 1);
+    Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
+    Tensor dim_1 = full<T>({1}, -1, x_dim.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    x_cast = backend::reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    Tensor var_tmp_ =
+        mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+        mean_ * mean_;
+    var_ = maximum<T>(
+        var_tmp_,
+        backend::full_with_tensor<T>(shape<T>(var_tmp_), 0, var_tmp_.dtype()));
+    Tensor var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    Tensor res = (x_cast - mean_) * var_inv;
+    out = backend::reshape<T>(res, x_dim);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> one_axis(1, 1);
+
+    std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
+    x_cast = reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    auto var_tmp_ = mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+                    mean_ * mean_;
+    var_ = maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
+    auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    auto res = (x_cast - mean_) * var_inv;
+    out = reshape<T>(res, x_dim);
+  }
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();
 
@@ -933,11 +950,20 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     }
     out = out + bias_cast;
   }
-
-  std::vector<int64_t> res_shape{x_dim[0], groups};
-  auto mean_out = reshape<T>(mean_, res_shape);
-  auto var_out = reshape<T>(var_, res_shape);
-
+  Tensor mean_out, var_out;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    Tensor x_shape = get_slice<T>(x_dim, 0);
+    Tensor dim_1 = full<T>({1}, groups, x_shape.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    mean_out = backend::reshape<T>(mean_, x_shape);
+    var_out = backend::reshape<T>(var_, x_shape);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> res_shape{x_dim[0], groups};
+    mean_out = reshape<T>(mean_, res_shape);
+    var_out = reshape<T>(var_, res_shape);
+  }
   if (need_cast) {
     out = cast<T>(out, org_dtype);
   }
diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
index d5762d1fc1f9ba..54fc95319b9094 100644
--- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
+++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -92,6 +92,35 @@ def swiglu_net2(x):
     return paddle.incubate.nn.functional.swiglu(x)
 
 
+def group_norm_net1(x):
+    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
+    return group_norm(x)
+
+
+def group_norm_net2(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, weight_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net3(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, bias_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net4(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1],
+        num_groups=32,
+        weight_attr=False,
+        bias_attr=False,
+    )
+    return group_norm(x)
+
+
 def layer_norm_net1(x):
     return paddle.nn.functional.layer_norm(x, x.shape[1:])
 
@@ -365,5 +394,57 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimGroupNorm1(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net1
+        self.necessary_ops = "pd_op.group_norm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm2(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net2
+        self.necessary_ops = "pd_op.group_norm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm3(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net3
+        self.necessary_ops = "pd_op.group_norm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm4(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net4
+        self.necessary_ops = "pd_op.group_norm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()
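
Note (not part of the patch): the sketch below approximates how these nets are expected to reach the new dynamic-shape branch of group_norm_decomp — the function is traced with an InputSpec whose batch and spatial dims are None, so has_dynamic_shape(x.shape()) is true and the tensor-shape path is taken. The real harness (prim switches, the check against self.necessary_ops) lives in the shared base of this test file and is only approximated here; core._set_prim_all_enabled is Paddle's internal switch used by its own prim tests.

# Sketch only, assuming the usual prim-test setup; not part of the patch.
import numpy as np
import paddle
from paddle.framework import core
from paddle.static import InputSpec


def group_norm_net1(x):
    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
    return group_norm(x)


core._set_prim_all_enabled(True)  # decompose composite ops such as group_norm

x = paddle.to_tensor(np.random.random([50, 640, 10, 20]).astype("float32"))
# None dims mirror init_x_shape = [None, 640, None, None] in the tests above,
# which is what forces the has_dynamic_shape branch in group_norm_decomp.
spec = InputSpec(shape=[None, 640, None, None], dtype="float32")
static_net = paddle.jit.to_static(
    group_norm_net1, input_spec=[spec], full_graph=True
)

out = static_net(x)
print(out.shape)  # [50, 640, 10, 20]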