From 7c52fa8545ca94f20d394f86f5685b17aae6b238 Mon Sep 17 00:00:00 2001 From: range3 Date: Sat, 16 Jul 2022 06:54:44 +0000 Subject: [PATCH] Fix wrong RAII impls --- src/rdbench/main.cpp | 38 ++++++-------------------- src/rdbench/raii_types.hpp | 56 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 29 deletions(-) create mode 100644 src/rdbench/raii_types.hpp diff --git a/src/rdbench/main.cpp b/src/rdbench/main.cpp index c4a1f8c..c3cae89 100644 --- a/src/rdbench/main.cpp +++ b/src/rdbench/main.cpp @@ -22,6 +22,7 @@ #include #include +#include "raii_types.hpp" #include "stopwatch.hpp" using json = nlohmann::json; @@ -34,28 +35,7 @@ const double Dv = 0.1; typedef std::vector vd; -struct mpi_datatype_deleter { - template void operator()(T *p) const { - MPI_Datatype type = p; - if (type != MPI_DATATYPE_NULL) { - MPI_Type_free(&type); - } - } -}; - -struct mpi_comm_deleter { - template void operator()(T *p) const { - MPI_Comm comm = p; - if (comm != MPI_COMM_NULL) { - MPI_Comm_free(&comm); - } - } -}; - struct RdbenchInfo { - using unique_mpi_datatype - = std::unique_ptr::type, mpi_datatype_deleter>; - using unique_mpi_comm = std::unique_ptr::type, mpi_comm_deleter>; int rank; int nprocs; int xnp; @@ -74,10 +54,10 @@ struct RdbenchInfo { bool view = false; bool sync = true; bool validate = true; - unique_mpi_datatype filetype; - unique_mpi_datatype memtype; - unique_mpi_datatype vertical_halo_type; - unique_mpi_comm comm_2d; + Datatype filetype; + Datatype memtype; + Datatype vertical_halo_type; + Comm comm_2d; size_t total_steps; size_t interval; bool fixed_x = false; @@ -110,7 +90,7 @@ struct RdbenchInfo { MPI_Comm comm_2d; int periods[] = {info.fixed_y ? 0 : 1, info.fixed_x ? 0 : 1}; MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &comm_2d); - info.comm_2d = unique_mpi_comm{comm_2d}; + info.comm_2d = Comm{comm_2d}; int coords[2]; MPI_Cart_coords(info.comm_2d.get(), info.rank, 2, coords); MPI_Cart_shift(info.comm_2d.get(), 0, 1, &info.rank_up, &info.rank_down); @@ -138,7 +118,7 @@ struct RdbenchInfo { MPI_Type_create_subarray(2, array_shape, chunk_shape, chunk_start, MPI_ORDER_C, MPI_DOUBLE, &t); MPI_Type_commit(&t); - info.filetype = unique_mpi_datatype{t}; + info.filetype = Datatype{t}; array_shape[0] = info.chunk_size_y + 2; array_shape[1] = info.chunk_size_x + 2; @@ -148,11 +128,11 @@ struct RdbenchInfo { MPI_Type_create_subarray(2, array_shape, chunk_shape, chunk_start, MPI_ORDER_C, MPI_DOUBLE, &t); MPI_Type_commit(&t); - info.memtype = unique_mpi_datatype{t}; + info.memtype = Datatype{t}; } MPI_Type_vector(info.chunk_size_y, 1, info.chunk_size_x + 2, MPI_DOUBLE, &t); MPI_Type_commit(&t); - info.vertical_halo_type = unique_mpi_datatype{t}; + info.vertical_halo_type = Datatype{t}; return info; } diff --git a/src/rdbench/raii_types.hpp b/src/rdbench/raii_types.hpp new file mode 100644 index 0000000..406011d --- /dev/null +++ b/src/rdbench/raii_types.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2022 range3 ( https://github.com/range3 ) +#pragma once + +#include + +#include +#include + +struct DatatypeWrap { + DatatypeWrap() = default; + DatatypeWrap(std::nullptr_t) {} + explicit DatatypeWrap(MPI_Datatype dt) : dt(dt) {} + explicit operator bool() const { return dt != MPI_DATATYPE_NULL; } + friend bool operator==(DatatypeWrap l, DatatypeWrap r) { return l.dt == r.dt; } + friend bool operator!=(DatatypeWrap l, DatatypeWrap r) { return !(l == r); } + MPI_Datatype dt = MPI_DATATYPE_NULL; +}; + +class Datatype { +public: + Datatype() = default; + Datatype(MPI_Datatype dt) { dt_ = decltype(dt_)(DatatypeWrap(dt), Deleter()); } + MPI_Datatype get() { return dt_.get().dt; } + +private: + struct Deleter { + using pointer = DatatypeWrap; + void operator()(DatatypeWrap dt) { MPI_Type_free(&dt.dt); } + }; + std::unique_ptr dt_; +}; + +struct CommWrap { + CommWrap() = default; + CommWrap(std::nullptr_t) {} + explicit CommWrap(MPI_Comm comm) : comm(comm) {} + explicit operator bool() const { return comm != MPI_COMM_NULL; } + friend bool operator==(CommWrap l, CommWrap r) { return l.comm == r.comm; } + friend bool operator!=(CommWrap l, CommWrap r) { return !(l == r); } + MPI_Comm comm = MPI_COMM_NULL; +}; + +struct Comm { +public: + Comm() = default; + Comm(MPI_Comm comm) { comm_ = decltype(comm_)(CommWrap(comm), Deleter()); } + MPI_Comm get() { return comm_.get().comm; } + +private: + struct Deleter { + using pointer = CommWrap; + void operator()(CommWrap comm) { MPI_Comm_free(&comm.comm); } + }; + std::unique_ptr comm_; +};