diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 0e758b29f..b55858ecf 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -73,9 +73,11 @@ TensorArray<3> adamForwardInplaceCPU(const torch::Tensor &updates, nu.mul_(scalar_t(b2)).addcmul_(updates, updates.conj(), 1 - scalar_t(b2)); - updates.copy_( - mu.mul(scalar_t(inv_one_minus_pow_b1)) - .div_(nu.mul(inv_one_minus_pow_b2).add_(scalar_t(eps_root)).sqrt_().add_(scalar_t(eps)))); + updates.copy_(mu.mul(scalar_t(inv_one_minus_pow_b1)) + .div_(nu.mul(inv_one_minus_pow_b2) + .add_(scalar_t(eps_root)) + .sqrt_() + .add_(scalar_t(eps)))); })); return TensorArray<3>{updates, mu, nu}; }