Skip to content

Commit

Permalink
Fix wrong RAII impls
Browse files Browse the repository at this point in the history
  • Loading branch information
range3 committed Jul 16, 2022
1 parent b17d799 commit 7c52fa8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 29 deletions.
38 changes: 9 additions & 29 deletions src/rdbench/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <type_traits>
#include <vector>

#include "raii_types.hpp"
#include "stopwatch.hpp"

using json = nlohmann::json;
Expand All @@ -34,28 +35,7 @@ const double Dv = 0.1;

typedef std::vector<double> vd;

struct mpi_datatype_deleter {
template <typename T> void operator()(T *p) const {
MPI_Datatype type = p;
if (type != MPI_DATATYPE_NULL) {
MPI_Type_free(&type);
}
}
};

struct mpi_comm_deleter {
template <typename T> void operator()(T *p) const {
MPI_Comm comm = p;
if (comm != MPI_COMM_NULL) {
MPI_Comm_free(&comm);
}
}
};

struct RdbenchInfo {
using unique_mpi_datatype
= std::unique_ptr<std::remove_pointer<MPI_Datatype>::type, mpi_datatype_deleter>;
using unique_mpi_comm = std::unique_ptr<std::remove_pointer<MPI_Comm>::type, mpi_comm_deleter>;
int rank;
int nprocs;
int xnp;
Expand All @@ -74,10 +54,10 @@ struct RdbenchInfo {
bool view = false;
bool sync = true;
bool validate = true;
unique_mpi_datatype filetype;
unique_mpi_datatype memtype;
unique_mpi_datatype vertical_halo_type;
unique_mpi_comm comm_2d;
Datatype filetype;
Datatype memtype;
Datatype vertical_halo_type;
Comm comm_2d;
size_t total_steps;
size_t interval;
bool fixed_x = false;
Expand Down Expand Up @@ -110,7 +90,7 @@ struct RdbenchInfo {
MPI_Comm comm_2d;
int periods[] = {info.fixed_y ? 0 : 1, info.fixed_x ? 0 : 1};
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &comm_2d);
info.comm_2d = unique_mpi_comm{comm_2d};
info.comm_2d = Comm{comm_2d};
int coords[2];
MPI_Cart_coords(info.comm_2d.get(), info.rank, 2, coords);
MPI_Cart_shift(info.comm_2d.get(), 0, 1, &info.rank_up, &info.rank_down);
Expand Down Expand Up @@ -138,7 +118,7 @@ struct RdbenchInfo {
MPI_Type_create_subarray(2, array_shape, chunk_shape, chunk_start, MPI_ORDER_C, MPI_DOUBLE,
&t);
MPI_Type_commit(&t);
info.filetype = unique_mpi_datatype{t};
info.filetype = Datatype{t};

array_shape[0] = info.chunk_size_y + 2;
array_shape[1] = info.chunk_size_x + 2;
Expand All @@ -148,11 +128,11 @@ struct RdbenchInfo {
MPI_Type_create_subarray(2, array_shape, chunk_shape, chunk_start, MPI_ORDER_C, MPI_DOUBLE,
&t);
MPI_Type_commit(&t);
info.memtype = unique_mpi_datatype{t};
info.memtype = Datatype{t};
}
MPI_Type_vector(info.chunk_size_y, 1, info.chunk_size_x + 2, MPI_DOUBLE, &t);
MPI_Type_commit(&t);
info.vertical_halo_type = unique_mpi_datatype{t};
info.vertical_halo_type = Datatype{t};

return info;
}
Expand Down
56 changes: 56 additions & 0 deletions src/rdbench/raii_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright 2022 range3 ( https://github.com/range3 )
#pragma once

#include <mpi.h>

#include <cstddef>
#include <memory>

struct DatatypeWrap {
DatatypeWrap() = default;
DatatypeWrap(std::nullptr_t) {}
explicit DatatypeWrap(MPI_Datatype dt) : dt(dt) {}
explicit operator bool() const { return dt != MPI_DATATYPE_NULL; }
friend bool operator==(DatatypeWrap l, DatatypeWrap r) { return l.dt == r.dt; }
friend bool operator!=(DatatypeWrap l, DatatypeWrap r) { return !(l == r); }
MPI_Datatype dt = MPI_DATATYPE_NULL;
};

class Datatype {
public:
Datatype() = default;
Datatype(MPI_Datatype dt) { dt_ = decltype(dt_)(DatatypeWrap(dt), Deleter()); }
MPI_Datatype get() { return dt_.get().dt; }

private:
struct Deleter {
using pointer = DatatypeWrap;
void operator()(DatatypeWrap dt) { MPI_Type_free(&dt.dt); }
};
std::unique_ptr<void, Deleter> dt_;
};

struct CommWrap {
CommWrap() = default;
CommWrap(std::nullptr_t) {}
explicit CommWrap(MPI_Comm comm) : comm(comm) {}
explicit operator bool() const { return comm != MPI_COMM_NULL; }
friend bool operator==(CommWrap l, CommWrap r) { return l.comm == r.comm; }
friend bool operator!=(CommWrap l, CommWrap r) { return !(l == r); }
MPI_Comm comm = MPI_COMM_NULL;
};

struct Comm {
public:
Comm() = default;
Comm(MPI_Comm comm) { comm_ = decltype(comm_)(CommWrap(comm), Deleter()); }
MPI_Comm get() { return comm_.get().comm; }

private:
struct Deleter {
using pointer = CommWrap;
void operator()(CommWrap comm) { MPI_Comm_free(&comm.comm); }
};
std::unique_ptr<void, Deleter> comm_;
};

0 comments on commit 7c52fa8

Please # to comment.