[Prim][PIR] group_norm decomposite rule support dynamic shape #62793
Merged

Changes from 6 of 7 commits:
cec228a  support dynamic shape for group_norm but it need to support dynamic s…  (zeroRains)
feaa27e  fix code style  (zeroRains)
233b730  remove todo  (zeroRains)
845eaee  fix conflict  (zeroRains)
9fa990d  modify the test  (zeroRains)
451eb16  remote debug tag  (zeroRains)
c0eb649  fix a typo  (zeroRains)
C++ changes to the group_norm_decomp decomposition rule:

@@ -890,21 +890,38 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
   if (need_cast) {
     x_cast = cast<T>(x, DataType::FLOAT32);
   }
-  auto x_dim = x.shape();
-  std::vector<int64_t> one_axis(1, 1);
-
-  std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
-  x_cast = reshape<T>(x_cast, x_shape);
-  auto mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
-  auto var_tmp_ =
-      mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_;
-  auto var_ =
-      maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
-  auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
-  auto res = (x_cast - mean_) * var_inv;
-  auto out = reshape<T>(res, x_dim);
+  Tensor out, mean_, var_;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    std::vector<int64_t> one_axis(1, 1);
+    Tensor x_shape = get_slice<T>(x_dim, 0) * groups;
+    Tensor dim_1 = full<T>({1}, -1, x_dim.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    x_cast = backend::reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    Tensor var_tmp_ =
+        mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+        mean_ * mean_;
+    var_ = maximum<T>(
+        var_tmp_,
+        backend::full_with_tensor<T>(shape<T>(var_tmp_), 0, var_tmp_.dtype()));
+    Tensor var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    Tensor res = (x_cast - mean_) * var_inv;
+    out = backend::reshape<T>(res, x_dim);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> one_axis(1, 1);
+
+    std::vector<int64_t> x_shape{x_dim[0] * groups, -1};
+    x_cast = reshape<T>(x_cast, x_shape);
+    mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);
+    auto var_tmp_ = mean_decomp<T>(x_cast * x_cast, IntArray(one_axis), true) -
+                    mean_ * mean_;
+    var_ = maximum<T>(var_tmp_, full<T>(var_tmp_.shape(), 0, var_tmp_.dtype()));
+    auto var_inv = 1 / sqrt_decomp<T>(var_ + epsilon);
+    auto res = (x_cast - mean_) * var_inv;
+    out = reshape<T>(res, x_dim);
+  }
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();

Review comment on the added line "mean_ = mean_decomp<T>(x_cast, IntArray(one_axis), true);": remove IntArray.
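For reference, the math that both branches of the hunk above implement is the standard group-norm decomposition: fold each sample's channels into groups, compute per-group mean and variance (as E[x^2] - E[x]^2, clipped at zero), normalize, and restore the original shape. The NumPy sketch below is only an illustration of that decomposition; the function name and defaults are chosen here and are not part of the PR.

import numpy as np

def group_norm_reference(x, groups, epsilon=1e-5):
    # x: [N, C, H, W]; fold the channels of each group together with the spatial dims
    n = x.shape[0]
    x_g = x.reshape(n * groups, -1)
    mean = x_g.mean(axis=1, keepdims=True)
    # variance as E[x^2] - E[x]^2, clipped at 0, mirroring maximum(var_tmp_, 0)
    var = np.maximum((x_g * x_g).mean(axis=1, keepdims=True) - mean * mean, 0.0)
    out = (x_g - mean) / np.sqrt(var + epsilon)
    return out.reshape(x.shape), mean.reshape(n, groups), var.reshape(n, groups)

x = np.random.rand(2, 8, 4, 4).astype("float32")
y, m, v = group_norm_reference(x, groups=4)  # y: (2, 8, 4, 4), m and v: (2, 4)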
@@ -933,11 +950,20 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
     }
     out = out + bias_cast;
   }
-  std::vector<int64_t> res_shape{x_dim[0], groups};
-  auto mean_out = reshape<T>(mean_, res_shape);
-  auto var_out = reshape<T>(var_, res_shape);
+  Tensor mean_out, var_out;
+  if (has_dynamic_shape(x.shape())) {
+    Tensor x_dim = shape<T>(x);
+    Tensor x_shape = get_slice<T>(x_dim, 0);
+    Tensor dim_1 = full<T>({1}, groups, x_shape.type());
+    x_shape = concat<T>({x_shape, dim_1});
+    mean_out = backend::reshape<T>(mean_, x_shape);
+    var_out = backend::reshape<T>(var_, x_shape);
+  } else {
+    auto x_dim = x.shape();
+    std::vector<int64_t> res_shape{x_dim[0], groups};
+    mean_out = reshape<T>(mean_, res_shape);
+    var_out = reshape<T>(var_, res_shape);
+  }
   if (need_cast) {
     out = cast<T>(out, org_dtype);
   }
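The common thread in the dynamic-shape branches of both hunks is that the reshape target is built as a runtime tensor (a slice of shape(x) concatenated with a constant) instead of a host-side std::vector, so batch and spatial dims can stay unknown while the program is built. A rough Python analogue using public paddle ops follows; it is an assumed illustration, not the backend:: primitives used in the C++ code.

import paddle

def dynamic_group_reshape(x, groups):
    # Build the target shape [N * groups, -1] as a tensor computed at runtime.
    ng = paddle.shape(x)[0:1] * groups
    minus_one = paddle.full([1], -1, dtype=ng.dtype)
    target = paddle.concat([ng, minus_one])
    return paddle.reshape(x, target)

x = paddle.randn([2, 8, 4, 4])
y = dynamic_group_reshape(x, groups=4)  # shape [8, 32]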
Python test changes:

@@ -92,6 +92,35 @@ def swiglu_net2(x):
     return paddle.incubate.nn.functional.swiglu(x)
 
 
+def group_norm_net1(x):
+    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
+    return group_norm(x)
+
+
+def group_norm_net2(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, weight_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net3(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1], num_groups=32, bias_attr=False
+    )
+    return group_norm(x)
+
+
+def group_norm_net4(x):
+    group_norm = paddle.nn.GroupNorm(
+        num_channels=x.shape[1],
+        num_groups=32,
+        weight_attr=False,
+        bias_attr=False,
+    )
+    return group_norm(x)
+
+
 def layer_norm_net1(x):
     return paddle.nn.functional.layer_norm(x, x.shape[1:])
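Each test class below pairs one of these nets with init_x_shape = [None, 640, None, None], i.e. the batch and spatial dims are left symbolic when the net is lowered to a static program. The shared test utilities are not visible on this page; as a hedged sketch only, tracing such a net with dynamic dims might look like the following, assuming paddle.jit.to_static and paddle.static.InputSpec rather than the actual harness.

import paddle
from paddle.static import InputSpec

def group_norm_net1(x):
    group_norm = paddle.nn.GroupNorm(num_channels=x.shape[1], num_groups=32)
    return group_norm(x)

# Only the channel dim is fixed; batch and spatial dims stay unknown at trace time.
x_spec = InputSpec(shape=[None, 640, None, None], dtype="float32")
static_net = paddle.jit.to_static(group_norm_net1, input_spec=[x_spec])
out = static_net(paddle.randn([4, 640, 8, 8]))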
@@ -365,5 +394,57 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimGroupNorm1(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net1
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm2(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net2
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm3(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net3
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimGroupNorm4(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [50, 640, 10, 20]
+        self.init_x_shape = [None, 640, None, None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = group_norm_net4
+        self.necessary_ops = "pd_op.flatten"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 if __name__ == "__main__":
     unittest.main()

Review comment on the added line self.necessary_ops = "pd_op.flatten" in TestPrimGroupNorm1: group_norm. Reply: Done.