Skip to content

Commit

Permalink
repo-sync-2025-02-26T11:37:57+0800 (#1009)
Browse files Browse the repository at this point in the history
  • Loading branch information
w-gc authored Feb 26, 2025
1 parent 720240d commit a72a806
Show file tree
Hide file tree
Showing 18 changed files with 1,555 additions and 48 deletions.
2 changes: 2 additions & 0 deletions .licenserc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ header: # <1>
- 'pyproject.toml'
- 'src/libspu/core/half.h' # MIT
- '.bazelignore'
- 'src/libspu/mpc/utils/waksman_net.h' # MIT
- 'src/libspu/mpc/utils/waksman_net.cc' # MIT

comment: never # <9>

Expand Down
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

module(
name = "spu",
version = "0.9.4.dev20250225",
version = "0.9.4.dev20250226",
compatibility_level = 1,
)

Expand Down
2 changes: 1 addition & 1 deletion src/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

module(
name = "spulib",
version = "0.9.4.dev20250225",
version = "0.9.4.dev20250226",
compatibility_level = 1,
)

Expand Down
13 changes: 11 additions & 2 deletions src/libspu/kernel/hal/permute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ hal::CompFn _get_cmp_func(SPUContext *ctx, int64_t num_keys,
return comp_fn;
}

bool _has_efficient_shuffle(SPUContext *ctx) {
const auto prot = ctx->config().protocol;

// semi2k and aby3 have highly efficient constant round implementation.
return prot == ProtocolKind::SEMI2K || prot == ProtocolKind::ABY3;
}

bool _check_method_require(SPUContext *ctx, RuntimeConfig::SortMethod method) {
bool pass = false;
switch (method) {
Expand Down Expand Up @@ -1610,9 +1617,11 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
}

// if use default sort method, trying to find the most best method
// currently, radix sort -> quick sort -> sorting network
// currently, radix sort (has efficient `shuffle`) -> quick sort -> sorting
// network
if (sort_method == RuntimeConfig::SORT_DEFAULT) {
if (internal::_check_method_require(ctx, RuntimeConfig::SORT_RADIX)) {
if (internal::_check_method_require(ctx, RuntimeConfig::SORT_RADIX) &&
internal::_has_efficient_shuffle(ctx)) {
ret = internal::radix_sort(ctx, inputs, direction, num_keys, valid_bits);
} else if (internal::_check_method_require(ctx,
RuntimeConfig::SORT_QUICK)) {
Expand Down
13 changes: 13 additions & 0 deletions src/libspu/kernel/hlo/permute.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@

namespace spu::kernel::hlo {

/// Some conventions about permutation semantic:
/// 1. The `perm` in MPC layer is so-called shared with "composition"
/// semantic, i.e. let \pi is a perm, then P0 holds \pi0, P1 holds \pi1, and
/// \pi = \pi1 o \pi0 (`o` means composition of two permutations).
/// 2. There exists another semantic of `perm` (We call it "Additive"
/// semantic). The `perm` is a vector of destinations, and each element of the
/// vector is secret shared across the parties. e.g. let \pi = \pi0 + \pi1 if
/// we have 2 Parties. We implement the ops under this semantic of `perm` in
/// HAL layer based on the "composition" semantic of perm.
/// (REF: https://eprint.iacr.org/2019/695.pdf)
///
/// Note: In HLO layer, we always assume the perm is "Additive" semantic.

// Inverse permute vector `inputs` over permutation `perm`
// Let [n] = {0,1,2,...,n-1}, then perm: [n] -> [n] should be an invertible
// permutation, we denote prem^{-1} as its inversion.
Expand Down
59 changes: 36 additions & 23 deletions src/libspu/kernel/hlo/permute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ xt::xarray<T> evalSinglePermuteOp(SPUContext* ctx, VisType x_vis,
PtBufferView perm,
const PermuteFunc& perm_func,
int64_t perm_dim = 0) {
const auto prot = ctx->config().protocol;

auto x_v = makeTestValue(ctx, x, x_vis);
auto perm_v = makeTestValue(ctx, perm, perm_vis);

Expand All @@ -136,7 +138,10 @@ xt::xarray<T> evalSinglePermuteOp(SPUContext* ctx, VisType x_vis,
if (checkCommFree(x_vis, perm_vis)) {
EXPECT_EQ(send_round, 0);
}
if (ctx->hasKernel("inv_perm_av") && checkSpPass(x_vis, perm_vis)) {

// costs of cheetah is highly dependant of OT kind, so we skip it.
if (prot != CHEETAH && ctx->hasKernel("inv_perm_av") &&
checkSpPass(x_vis, perm_vis)) {
auto n_repeat = x_v.shape().numel() / x_v.shape().dim(perm_dim);
// For ss version, at least 3 rounds.
EXPECT_LE(std::min(send_round, recv_round), 2 * n_repeat);
Expand Down Expand Up @@ -182,16 +187,33 @@ std::vector<xt::xarray<T>> evalMultiplePermuteOp(

} // namespace

class PermuteTest : public ::testing::TestWithParam<
std::tuple<VisType, VisType, ProtocolKind, size_t>> {};
using PermuteParams = std::tuple<VisType, VisType, ProtocolKind, size_t>;

std::vector<PermuteParams> GetValidParamsCombinations() {
std::vector<PermuteParams> valid_combinations;

for (const auto& vis_x : kVisTypes) {
for (const auto& vis_perm : kVisTypes) {
for (const auto& protocol : {CHEETAH, SEMI2K, ABY3}) {
for (const auto& npc : {2, 3}) {
// npc=2/3 is not valid in ABY3/CHEETAH
if ((protocol == ABY3 && npc == 2) ||
(protocol == CHEETAH && npc == 3)) {
continue; // Skip invalid combinations
}
valid_combinations.emplace_back(vis_x, vis_perm, protocol, npc);
}
}
}
}
return valid_combinations;
}

class PermuteTest : public ::testing::TestWithParam<PermuteParams> {};

INSTANTIATE_TEST_SUITE_P(
GeneralPermute, PermuteTest,
testing::Combine(testing::ValuesIn(kVisTypes), // vis of x
testing::ValuesIn(kVisTypes), // vis of perm
testing::Values(SEMI2K, ABY3), // underlying protocol
testing::Values(2, 3) // npc=2 is not valid in ABY3
),
testing::ValuesIn(GetValidParamsCombinations()),
[](const testing::TestParamInfo<PermuteTest::ParamType>& p) {
return fmt::format("{}x{}x{}x{}", get_vis_str(std::get<0>(p.param)),
get_vis_str(std::get<1>(p.param)),
Expand All @@ -204,10 +226,6 @@ TEST_P(PermuteTest, SinglePermuteWork) {
const ProtocolKind protocol = std::get<2>(GetParam());
const size_t npc = std::get<3>(GetParam());

if (protocol == ABY3 && npc == 2) {
return;
}

xt::xarray<int64_t> x = {10, 0, 2, 3, 9, 1, 5, 6};
xt::xarray<int64_t> perm = {2, 7, 1, 6, 0, 4, 3, 5};

Expand Down Expand Up @@ -240,10 +258,6 @@ TEST_P(PermuteTest, PermDimWork) {
const ProtocolKind protocol = std::get<2>(GetParam());
const size_t npc = std::get<3>(GetParam());

if (protocol == ABY3 && npc == 2) {
return;
}

xt::xarray<int64_t> x = {{10, 0, 2, 3, 9, 1, 5, 6},
{-10, 0, -2, -3, -9, -1, -5, -6}};
xt::xarray<int64_t> perm = {2, 7, 1, 6, 0, 4, 3, 5};
Expand Down Expand Up @@ -279,10 +293,6 @@ TEST_P(PermuteTest, MultiplePermuteWork) {
const ProtocolKind protocol = std::get<2>(GetParam());
const size_t npc = std::get<3>(GetParam());

if (protocol == ABY3 && npc == 2) {
return;
}

xt::xarray<int64_t> x = {10, 0, 2, 3, 9, 1, 5, 6};
xt::xarray<int64_t> perm = {2, 7, 1, 6, 0, 4, 3, 5};

Expand Down Expand Up @@ -317,17 +327,20 @@ TEST_P(PermuteTest, MultiplePermuteWork) {
class PermuteEmptyTest : public ::testing::TestWithParam<ProtocolKind> {};

INSTANTIATE_TEST_SUITE_P(
PermuteEmpty, PermuteEmptyTest,
testing::Values(ProtocolKind::SEMI2K, ProtocolKind::ABY3),
PermuteEmpty, PermuteEmptyTest, testing::Values(CHEETAH, SEMI2K, ABY3),
[](const testing::TestParamInfo<PermuteEmptyTest::ParamType>& p) {
return fmt::format("{}", p.param);
});

TEST_P(PermuteEmptyTest, Empty) {
ProtocolKind prot = GetParam();
size_t npc = 3;
if (prot == CHEETAH) {
npc = 2;
}

mpc::utils::simulate(
3, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
npc, [&](const std::shared_ptr<yacl::link::Context>& lctx) {
SPUContext sctx = test::makeSPUContext(prot, kField, lctx);

auto empty_x =
Expand Down
18 changes: 18 additions & 0 deletions src/libspu/mpc/cheetah/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ spu_cc_library(
":arithmetic",
":boolean",
":conversion",
":permute",
":state",
"//libspu/mpc/common:prg_state",
"//libspu/mpc/common:pv2k",
Expand Down Expand Up @@ -145,3 +146,20 @@ spu_cc_library(
"//libspu/mpc/common:pv2k",
],
)

spu_cc_library(
name = "permute",
srcs = ["permute.cc"],
hdrs = ["permute.h"],
deps = [
":type",
"//libspu/mpc:ab_api",
"//libspu/mpc:kernel",
"//libspu/mpc/common:communicator",
"//libspu/mpc/common:prg_state",
"//libspu/mpc/common:pv2k",
"//libspu/mpc/utils:permute",
"//libspu/mpc/utils:ring_ops",
"//libspu/mpc/utils:waksman_net",
],
)
Loading

0 comments on commit a72a806

Please # to comment.