From be576a6bd1d6301dd742adc488ba8f984c6a96f6 Mon Sep 17 00:00:00 2001 From: BHmingyang <124899631+BHmingyang@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:26:05 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Fix=20PIR=20Unittest=20No.38=20BUAA?= =?UTF-8?q?=E3=80=91fix=20TruncatedNormalInitializer=20according=20to=20No?= =?UTF-8?q?rmalInitializer=20and=20fix=20test=5Fcuda=5Frandom=5Fseed=20(#6?= =?UTF-8?q?6413)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix TruncatedNormalInitializer according to NormalInitializer * change the assertion * fix the code style --- python/paddle/nn/initializer/normal.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/initializer/normal.py b/python/paddle/nn/initializer/normal.py index 5cd372784ce0ba..47eab6023636a1 100644 --- a/python/paddle/nn/initializer/normal.py +++ b/python/paddle/nn/initializer/normal.py @@ -235,9 +235,10 @@ def forward( The initialization op """ block = self._check_block(block) - - assert isinstance(var, framework.Variable) - assert isinstance(block, framework.Block) + assert isinstance( + var, (framework.Variable, paddle.pir.core.ParameterMeta) + ) + assert isinstance(block, (framework.Block, pir.Block)) if self._seed == 0: self._seed = block.program.random_seed @@ -279,6 +280,25 @@ def forward( out_var._share_underline_tensor_to(var) return None + elif in_pir_mode(): + out_var = _C_ops.truncated_gaussian_random( + var.shape, + self._mean, + self._std_dev, + self._seed, + self._a, + self._b, + out_dtype, + _current_expected_place(), + ) + if var.dtype in [ + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.BF16, + ]: + var_tmp = _C_ops.cast(out_var, var.dtype) + var_tmp._share_underline_tensor_to(var) + return out_var + else: op = block.append_op( type="truncated_gaussian_random",