From d6b733c7131832cf4547d608aa6140d73919a75b Mon Sep 17 00:00:00 2001 From: Jie Ren Date: Sun, 6 Nov 2022 21:58:58 +0800 Subject: [PATCH] fix(acc_op): use static_cast --- src/adam_op/adam_op_impl_cpu.cpp | 35 +++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index ed05c6f0..c2ae4ae0 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -40,8 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -93,8 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -126,8 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -163,8 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -211,8 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -247,8 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -287,8 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(std::min( \ - n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid];