Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

483 implement compiler specific optimizations for ntt and ringmult #508

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions benchmark/src/Lattice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@

using namespace lbcrypto;

namespace lbcrypto {

template <typename E>
static E makeElement(std::shared_ptr<lbcrypto::ILParamsImpl<typename E::Integer>> params) {
typename E::Vector vec = makeVector<typename E::Vector>(params->GetRingDimension(), params->GetModulus());
Expand Down Expand Up @@ -114,8 +112,6 @@ static void GeneratePolys(std::map<usint, std::shared_ptr<P>>& parmArray, std::m
}
}

} // namespace lbcrypto

std::map<usint, std::shared_ptr<ILNativeParams>> Nativeparms;
std::map<usint, std::vector<NativePoly>> Nativepolys;

Expand Down
5 changes: 0 additions & 5 deletions benchmark/src/poly-benchmark-16k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@

using namespace lbcrypto;

namespace lbcrypto {

static std::vector<usint> tow_args({1, 2, 4, 8});

static const usint DCRTBITS = MAX_MODULUS_SIZE;
Expand Down Expand Up @@ -136,8 +134,6 @@ static void GenerateDCRTPolys(std::map<usint, std::shared_ptr<ILDCRTParams<BigIn
}
}

} // namespace lbcrypto

std::shared_ptr<ILNativeParams> Nativeparms;
std::map<usint, std::shared_ptr<ILDCRTParams<BigInteger>>> DCRTparms;

Expand Down Expand Up @@ -209,7 +205,6 @@ static void Native_mul(benchmark::State& state) {
c = a->Times(*b);
}
}

BENCHMARK(Native_mul)->Unit(benchmark::kMicrosecond);

static void DCRT_mul(benchmark::State& state) {
Expand Down
4 changes: 0 additions & 4 deletions benchmark/src/poly-benchmark-1k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@

using namespace lbcrypto;

namespace lbcrypto {

static std::vector<usint> tow_args({1, 2, 4, 8});

static const usint DCRTBITS = MAX_MODULUS_SIZE;
Expand Down Expand Up @@ -136,8 +134,6 @@ static void GenerateDCRTPolys(std::map<usint, std::shared_ptr<ILDCRTParams<BigIn
}
}

} // namespace lbcrypto

std::shared_ptr<ILNativeParams> Nativeparms;
std::map<usint, std::shared_ptr<ILDCRTParams<BigInteger>>> DCRTparms;

Expand Down
4 changes: 0 additions & 4 deletions benchmark/src/poly-benchmark-4k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@

using namespace lbcrypto;

namespace lbcrypto {

static std::vector<usint> tow_args({1, 2, 4, 8});

static const usint DCRTBITS = MAX_MODULUS_SIZE;
Expand Down Expand Up @@ -136,8 +134,6 @@ static void GenerateDCRTPolys(std::map<usint, std::shared_ptr<ILDCRTParams<BigIn
}
}

} // namespace lbcrypto

std::shared_ptr<ILNativeParams> Nativeparms;
std::map<usint, std::shared_ptr<ILDCRTParams<BigInteger>>> DCRTparms;

Expand Down
4 changes: 0 additions & 4 deletions benchmark/src/poly-benchmark-64k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@

using namespace lbcrypto;

namespace lbcrypto {

static std::vector<usint> tow_args({1, 2, 4, 8});

static const usint DCRTBITS = MAX_MODULUS_SIZE;
Expand Down Expand Up @@ -137,8 +135,6 @@ static void GenerateDCRTPolys(std::map<usint, std::shared_ptr<ILDCRTParams<BigIn
}
}

} // namespace lbcrypto

std::shared_ptr<ILNativeParams> Nativeparms;
std::map<usint, std::shared_ptr<ILDCRTParams<BigInteger>>> DCRTparms;

Expand Down
12 changes: 1 addition & 11 deletions benchmark/src/vechelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,7 @@ using namespace lbcrypto;

template <typename V>
inline V makeVector(usint ringdim, const typename V::Integer& mod) {
DiscreteUniformGeneratorImpl<V> dug;
dug.SetModulus(mod);

return dug.GenerateVector(ringdim);
}

inline NativeVector makeNativeVector(usint ringdim, const NativeInteger& mod) {
DiscreteUniformGeneratorImpl<NativeVector> dug;
dug.SetModulus(mod);

return dug.GenerateVector(ringdim);
return DiscreteUniformGeneratorImpl<V>().GenerateVector(ringdim, mod);
}

#endif
1 change: 1 addition & 0 deletions src/binfhe/lib/lwe-pke.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ LWESwitchingKey LWEEncryptionScheme::KeySwitchGen(const std::shared_ptr<LWECrypt
// TODO (cpascoe/dsuponit): this pragma needs to be revised as it may have to be removed completely
// #if !defined(__MINGW32__) && !defined(__MINGW64__)
// #pragma omp parallel for num_threads(N)
// #pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(N))
// #endif
for (size_t i = 0; i < N; ++i) {
std::vector<std::vector<NativeVector>> vector1A;
Expand Down
7 changes: 4 additions & 3 deletions src/binfhe/lib/rgsw-acc-cggi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ RingGSWACCKey RingGSWAccumulatorCGGI::KeyGenAcc(const std::shared_ptr<RingGSWCry
auto& ek00 = (*ek)[0][0];
auto& ek01 = (*ek)[0][1];

// handles ternary secrets using signed mod 3 arithmetic; 0 -> {0,0}, 1 ->
// {1,0}, -1 -> {0,1}
#pragma omp parallel for
// handles ternary secrets using signed mod 3 arithmetic
// 0 -> {0,0}, 1 -> {1,0}, -1 -> {0,1}
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
for (uint32_t i = 0; i < n; ++i) {
auto s = sv[i].ConvertToInt();
ek00[i] = KeyGenCGGI(params, skNTT, s == 1 ? 1 : 0);
Expand Down Expand Up @@ -112,6 +112,7 @@ void RingGSWAccumulatorCGGI::AddToAccCGGI(const std::shared_ptr<RingGSWCryptoPar

SignedDigitDecompose(params, ct, dct);

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(digitsG2))
for (uint32_t i = 0; i < digitsG2; ++i)
dct[i].SetFormat(Format::EVALUATION);

Expand Down
33 changes: 15 additions & 18 deletions src/binfhe/lib/rgsw-acc-dm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,21 @@ namespace lbcrypto {
// Key generation as described in Section 4 of https://eprint.iacr.org/2014/816
RingGSWACCKey RingGSWAccumulatorDM::KeyGenAcc(const std::shared_ptr<RingGSWCryptoParams>& params,
const NativePoly& skNTT, ConstLWEPrivateKey& LWEsk) const {
auto sv = LWEsk->GetElement();
int32_t mod = sv.GetModulus().ConvertToInt();

int32_t modHalf = mod >> 1;

uint32_t baseR = params->GetBaseR();
const std::vector<NativeInteger>& digitsR = params->GetDigitsR();
uint32_t n = sv.GetLength();
RingGSWACCKey ek = std::make_shared<RingGSWACCKeyImpl>(n, baseR, digitsR.size());

#pragma omp parallel for
for (size_t i = 0; i < n; ++i) {
for (size_t j = 1; j < baseR; ++j) {
auto sv{LWEsk->GetElement()};
auto mod{sv.GetModulus().ConvertToInt<int32_t>()};
auto modHalf{mod >> 1};
uint32_t n(sv.GetLength());
int32_t baseR(params->GetBaseR());
const auto& digitsR = params->GetDigitsR();
RingGSWACCKey ek = std::make_shared<RingGSWACCKeyImpl>(n, baseR, digitsR.size());

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
for (uint32_t i = 0; i < n; ++i) {
for (int32_t j = 1; j < baseR; ++j) {
for (size_t k = 0; k < digitsR.size(); ++k) {
int32_t s = (int32_t)sv[i].ConvertToInt();
if (s > modHalf) {
s -= mod;
}
(*ek)[i][j][k] = KeyGenDM(params, skNTT, s * j * (int32_t)digitsR[k].ConvertToInt());
auto s{sv[i].ConvertToInt<int32_t>()};
(*ek)[i][j][k] =
KeyGenDM(params, skNTT, (s > modHalf ? s - mod : s) * j * digitsR[k].ConvertToInt<int32_t>());
}
}
}
Expand Down Expand Up @@ -132,6 +128,7 @@ void RingGSWAccumulatorDM::AddToAccDM(const std::shared_ptr<RingGSWCryptoParams>

SignedDigitDecompose(params, ct, dct);

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(digitsG2))
for (uint32_t j = 0; j < digitsG2; ++j)
dct[j].SetFormat(Format::EVALUATION);

Expand Down
30 changes: 13 additions & 17 deletions src/binfhe/lib/rgsw-acc-lmkcdey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,41 +35,35 @@

namespace lbcrypto {

// TODO: optimize this

// Key generation as described in https://eprint.iacr.org/2022/198
RingGSWACCKey RingGSWAccumulatorLMKCDEY::KeyGenAcc(const std::shared_ptr<RingGSWCryptoParams>& params,
const NativePoly& skNTT, ConstLWEPrivateKey& LWEsk) const {
auto sv = LWEsk->GetElement();
int32_t mod = sv.GetModulus().ConvertToInt();
int32_t modHalf = mod >> 1;
uint32_t N = params->GetN();
size_t n = sv.GetLength();
uint32_t numAutoKeys = params->GetNumAutoKeys();
auto sv{LWEsk->GetElement()};
auto mod{sv.GetModulus().ConvertToInt<int32_t>()};
auto modHalf{mod >> 1};
uint32_t N{params->GetN()};
size_t n{sv.GetLength()};
uint32_t numAutoKeys{params->GetNumAutoKeys()};

// dim2, 0: for RGSW(X^si), 1: for automorphism keys
// only w automorphism keys required
// allocates (n - w) more memory for pointer (not critical for performance)
RingGSWACCKey ek = std::make_shared<RingGSWACCKeyImpl>(1, 2, n);

#pragma omp parallel for
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(n))
for (size_t i = 0; i < n; ++i) {
int32_t s = (int32_t)sv[i].ConvertToInt();
if (s > modHalf) {
s -= mod;
}

(*ek)[0][0][i] = KeyGenLMKCDEY(params, skNTT, s);
auto s{sv[i].ConvertToInt<int32_t>()};
(*ek)[0][0][i] = KeyGenLMKCDEY(params, skNTT, s > modHalf ? s - mod : s);
}

NativeInteger gen = NativeInteger(5);

(*ek)[0][1][0] = KeyGenAuto(params, skNTT, 2 * N - gen.ConvertToInt());

// m_window: window size, consider parameterization in the future
#pragma omp parallel for
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(numAutoKeys))
for (uint32_t i = 1; i <= numAutoKeys; ++i)
(*ek)[0][1][i] = KeyGenAuto(params, skNTT, gen.ModExp(i, 2 * N).ConvertToInt());
(*ek)[0][1][i] = KeyGenAuto(params, skNTT, gen.ModExp(i, 2 * N).ConvertToInt<LWEPlaintext>());
return ek;
}

Expand Down Expand Up @@ -247,6 +241,7 @@ void RingGSWAccumulatorLMKCDEY::AddToAccLMKCDEY(const std::shared_ptr<RingGSWCry
SignedDigitDecompose(params, ct, dct);

// calls digitsG2 NTTs
#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(digitsG2))
for (uint32_t d = 0; d < digitsG2; ++d)
dct[d].SetFormat(Format::EVALUATION);

Expand Down Expand Up @@ -281,6 +276,7 @@ void RingGSWAccumulatorLMKCDEY::Automorphism(const std::shared_ptr<RingGSWCrypto

SignedDigitDecompose(params, cta, dcta);

#pragma omp parallel for num_threads(OpenFHEParallelControls.GetThreadLimit(digitsG))
for (uint32_t d = 0; d < digitsG; ++d)
dcta[d].SetFormat(Format::EVALUATION);

Expand Down
9 changes: 7 additions & 2 deletions src/core/include/lattice/hal/dcrtpoly-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -1390,13 +1390,18 @@ class DCRTPolyInterface : public ILElement<DerivedType, BigVecType> {
const std::vector<NativeInteger>& BModq, const std::vector<NativeInteger>& BModqPrecon) = 0;

/**
* @brief Convert from Coefficient to CRT or vice versa; calls FFT and inverse
* FFT.
* @brief Convert from Coefficient to CRT or vice versa; calls FFT and inverse FFT.
*
* @warning use @see SetFormat(format) instead
*/
void SwitchFormat() override = 0;

/**
* @brief Sets format to value without calling FFT. Only use if you know what you're doing.
*
*/
virtual void OverrideFormat(const Format f) = 0;

/**
* @brief Switch modulus and adjust the values
*
Expand Down
Loading