Skip to content

Commit

Permalink
fix(acc_op): use static_cast
Browse files Browse the repository at this point in the history
  • Loading branch information
JieRen98 committed Nov 6, 2022
1 parent 2a0d9a4 commit d6b733c
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions src/adam_op/adam_op_impl_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ void adamForwardInplaceCPUKernel(const other_t b1,
scalar_t *__restrict__ updates_ptr,
scalar_t *__restrict__ mu_ptr,
scalar_t *__restrict__ nu_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
Expand Down Expand Up @@ -93,8 +94,9 @@ void adamForwardMuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b1,
const size_t n,
scalar_t *__restrict__ mu_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t mu = mu_ptr[tid];
Expand Down Expand Up @@ -126,8 +128,9 @@ void adamForwardNuCPUKernel(const scalar_t *__restrict__ updates_ptr,
const other_t b2,
const size_t n,
scalar_t *__restrict__ nu_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t updates = updates_ptr[tid];
const scalar_t nu = nu_ptr[tid];
Expand Down Expand Up @@ -163,8 +166,9 @@ void adamForwardUpdatesCPUKernel(const scalar_t *__restrict__ new_mu_ptr,
const other_t eps_root,
const size_t n,
scalar_t *__restrict__ updates_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t new_mu = new_mu_ptr[tid];
const scalar_t new_nu = new_nu_ptr[tid];
Expand Down Expand Up @@ -211,8 +215,9 @@ void adamBackwardMuCPUKernel(const scalar_t *__restrict__ dmu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dmu_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dmu = dmu_ptr[tid];

Expand Down Expand Up @@ -247,8 +252,9 @@ void adamBackwardNuCPUKernel(const scalar_t *__restrict__ dnu_ptr,
const size_t n,
scalar_t *__restrict__ dupdates_out_ptr,
scalar_t *__restrict__ dnu_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dnu = dnu_ptr[tid];
const scalar_t updates = updates_ptr[tid];
Expand Down Expand Up @@ -287,8 +293,9 @@ void adamBackwardUpdatesCPUKernel(const scalar_t *__restrict__ dupdates_ptr,
const size_t n,
scalar_t *__restrict__ dnew_mu_out_ptr,
scalar_t *__restrict__ dnew_nu_out_ptr) {
#pragma omp parallel for num_threads(std::min( \
n / MIN_NUMEL_USE_OMP, (size_t)omp_get_num_procs())) if (n > MIN_NUMEL_USE_OMP) // NOLINT
#pragma omp parallel for num_threads( \
std::min(n / MIN_NUMEL_USE_OMP, \
static_cast <size_t>(omp_get_num_procs()))) if (n > MIN_NUMEL_USE_OMP) // NOLINT
for (size_t tid = 0; tid < n; ++tid) {
const scalar_t dupdates = dupdates_ptr[tid];
const scalar_t updates = updates_ptr[tid];
Expand Down

0 comments on commit d6b733c

Please # to comment.