diff --git a/CHANGELOG.md b/CHANGELOG.md index a01f6751..d0747f6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add if condition of number of threads for CPU OPs by [@JieRen98](https://github.com/JieRen98) in [#105](https://github.com/metaopt/torchopt/pull/105). - Add implicit MAML omniglot few-shot classification example with OOP APIs by [@XuehaiPan](https://github.com/XuehaiPan) in [#107](https://github.com/metaopt/torchopt/pull/107). - Add implicit MAML omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48). - Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101). diff --git a/src/adam_op/adam_op_impl_cpu.cpp b/src/adam_op/adam_op_impl_cpu.cpp index 82accd8c..c2ae4ae0 100644 --- a/src/adam_op/adam_op_impl_cpu.cpp +++ b/src/adam_op/adam_op_impl_cpu.cpp @@ -27,6 +27,8 @@ using std::size_t; namespace adam_op { +constexpr size_t MIN_NUMEL_USE_OMP = 1000; + template void adamForwardInplaceCPUKernel(const other_t b1, const other_t inv_one_minus_pow_b1, @@ -38,7 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1, scalar_t *__restrict__ updates_ptr, scalar_t *__restrict__ mu_ptr, scalar_t *__restrict__ nu_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -90,7 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b1, const size_t n, scalar_t *__restrict__ mu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t mu = mu_ptr[tid]; @@ -122,7 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr, const other_t b2, const size_t n, scalar_t *__restrict__ nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t updates = updates_ptr[tid]; const scalar_t nu = nu_ptr[tid]; @@ -158,7 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr, const other_t eps_root, const size_t n, scalar_t *__restrict__ updates_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t new_mu = new_mu_ptr[tid]; const scalar_t new_nu = new_nu_ptr[tid]; @@ -205,7 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dmu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dmu = dmu_ptr[tid]; @@ -240,7 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr, const size_t n, scalar_t *__restrict__ dupdates_out_ptr, scalar_t *__restrict__ dnu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dnu = dnu_ptr[tid]; const scalar_t updates = updates_ptr[tid]; @@ -279,7 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr, const size_t n, scalar_t *__restrict__ dnew_mu_out_ptr, scalar_t *__restrict__ dnew_nu_out_ptr) { -#pragma omp parallel for num_threads(omp_get_num_procs()) +#pragma omp parallel for num_threads( \ + std::min(n / MIN_NUMEL_USE_OMP, \ + static_cast (omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT for (size_t tid = 0; tid < n; ++tid) { const scalar_t dupdates = dupdates_ptr[tid]; const scalar_t updates = updates_ptr[tid];