Skip to content

Commit

Permalink
【Fix PIR Unittest No.38 BUAA】fix TruncatedNormalInitializer according…
Browse files Browse the repository at this point in the history
… to NormalInitializer and fix test_cuda_random_seed (PaddlePaddle#66413)

* fix TruncatedNormalInitializer according to NormalInitializer

* change the assertion

* fix the code style
  • Loading branch information
BHmingyang authored and lixcli committed Aug 5, 2024
1 parent 9098290 commit be576a6
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions python/paddle/nn/initializer/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,10 @@ def forward(
The initialization op
"""
block = self._check_block(block)

assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
assert isinstance(
var, (framework.Variable, paddle.pir.core.ParameterMeta)
)
assert isinstance(block, (framework.Block, pir.Block))

if self._seed == 0:
self._seed = block.program.random_seed
Expand Down Expand Up @@ -279,6 +280,25 @@ def forward(
out_var._share_underline_tensor_to(var)
return None

elif in_pir_mode():
out_var = _C_ops.truncated_gaussian_random(
var.shape,
self._mean,
self._std_dev,
self._seed,
self._a,
self._b,
out_dtype,
_current_expected_place(),
)
if var.dtype in [
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.BF16,
]:
var_tmp = _C_ops.cast(out_var, var.dtype)
var_tmp._share_underline_tensor_to(var)
return out_var

else:
op = block.append_op(
type="truncated_gaussian_random",
Expand Down

0 comments on commit be576a6

Please # to comment.