Skip to content

Commit

Permalink
Return original ciphertext in bootstrapping if we already start with …
Browse files Browse the repository at this point in the history
…enough towers. (#314)

* Return original ciphertext if we already start with enough towers.

* Add comment to explain changes.
  • Loading branch information
sarojaduality authored Mar 2, 2023
1 parent 3bf9353 commit 7e0a7e1
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,14 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
double timeDecode(0.0);
#endif

auto cc = ciphertext->GetCryptoContext();
uint32_t M = cc->GetCyclotomicOrder();
uint32_t L0 = cryptoParams->GetElementParams()->GetParams().size();
auto cc = ciphertext->GetCryptoContext();
uint32_t M = cc->GetCyclotomicOrder();
uint32_t L0 = cryptoParams->GetElementParams()->GetParams().size();
auto initSizeQ = ciphertext->GetElements()[0].GetNumOfElements();

if (numIterations > 1) {
// Step 1: Get the input.
uint32_t powerOfTwoModulus = 1 << precision;
auto initSizeQ = ciphertext->GetElements()[0].GetNumOfElements();

// Step 2: Scale up by powerOfTwoModulus, and extend the modulus to powerOfTwoModulus * q.
// Note that we extend the modulus implicitly without any code calls because the value always stays 0.
Expand All @@ -275,10 +275,10 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
// We mod down, and leave the last CRT value to be 0 because it's divisible by powerOfTwoModulus.
auto ctBootstrappedScaledDown = ctInitialBootstrap->Clone();
auto bootstrappingSizeQ = ctBootstrappedScaledDown->GetElements()[0].GetNumOfElements();

// If we start with more towers, than we obtain from bootstrapping, return the original ciphertext.
if (bootstrappingSizeQ <= initSizeQ) {
OPENFHE_THROW(config_error, "Bootstrapping number of RNS moduli " + std::to_string(bootstrappingSizeQ) +
" must be greater than initial number of RNS moduli " +
std::to_string(initSizeQ));
return ciphertext->Clone();
}
for (auto& cv : ctBootstrappedScaledDown->GetElements()) {
cv.DropLastElements(bootstrappingSizeQ - initSizeQ);
Expand Down Expand Up @@ -612,6 +612,13 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
std::cout << "Decoding time: " << timeDecode / 1000.0 << " s" << std::endl;
#endif

auto bootstrappingNumTowers = ctxtDec->GetElements()[0].GetNumOfElements();

// If we start with more towers, than we obtain from bootstrapping, return the original ciphertext.
if (bootstrappingNumTowers <= initSizeQ) {
return ciphertext->Clone();
}

return ctxtDec;
}

Expand Down
94 changes: 94 additions & 0 deletions src/pke/unittest/utckksrns/UnitTestBootstrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ enum TEST_CASE_TYPE {
BOOTSTRAP_SPARSE,
BOOTSTRAP_KEY_SWITCH,
BOOTSTRAP_ITERATIVE,
BOOTSTRAP_NUM_TOWERS,
};

static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
Expand All @@ -74,6 +75,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) {
case BOOTSTRAP_ITERATIVE:
typeName = "BOOTSTRAP_ITERATIVE";
break;
case BOOTSTRAP_NUM_TOWERS:
typeName = "BOOTSTRAP_NUM_TOWERS";
break;
default:
typeName = "UNKNOWN";
break;
Expand Down Expand Up @@ -259,6 +263,29 @@ static std::vector<TEST_CASE_UTCKKSRNS_BOOT> testCases = {
{ BOOTSTRAP_ITERATIVE, "14", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_ITERATIVE, "15", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_ITERATIVE, "16", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
#endif
// ==========================================
// TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode, LvlBudget, Dim1, Slots
{ BOOTSTRAP_NUM_TOWERS, "01", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "02", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "03", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "04", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
#if NATIVEINT != 128
{ BOOTSTRAP_NUM_TOWERS, "05", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "06", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "07", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
{ BOOTSTRAP_NUM_TOWERS, "08", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8},
#endif
// TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode, LvlBudget, Dim1, Slots
{ BOOTSTRAP_NUM_TOWERS, "09", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "10", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "11", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "12", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
#if NATIVEINT != 128
{ BOOTSTRAP_NUM_TOWERS, "13", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "14", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "15", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
{ BOOTSTRAP_NUM_TOWERS, "16", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2},
#endif
// ==========================================
};
Expand Down Expand Up @@ -507,6 +534,70 @@ class UTCKKSRNS_BOOT : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS_BOOT>
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
}

void UnitTest_Bootstrap_NumTowers(const TEST_CASE_UTCKKSRNS_BOOT& testData,
const std::string& failmsg = std::string()) {
// This test checks to make sure that we return the original ciphertext if we
// start with more towers than the number of towers we would end up with by
// bootstrapping.
try {
CryptoContext<Element> cc(UnitTestGenerateContext(testData.params));

cc->EvalBootstrapSetup(testData.levelBudget, testData.dim1, testData.slots);

auto keyPair = cc->KeyGen();
cc->EvalBootstrapKeyGen(keyPair.secretKey, testData.slots);
cc->EvalAtIndexKeyGen(keyPair.secretKey, {6});
cc->EvalMultKeyGen(keyPair.secretKey);

std::vector<std::complex<double>> input(
Fill({0.111111, 0.222222, 0.333333, 0.444444, 0.555555, 0.666666, 0.777777, 0.888888}, testData.slots));
size_t encodedLength = input.size();

// We start with a ciphertext with 0 levels consumed.
Plaintext plaintext = cc->MakeCKKSPackedPlaintext(input);
auto ciphertext = cc->Encrypt(keyPair.publicKey, plaintext);
auto ciphertextAfter = cc->EvalBootstrap(ciphertext);

auto initNumTowers = ciphertext->GetElements()[0].GetNumOfElements();
auto bootstrappingNumTowers = ciphertextAfter->GetElements()[0].GetNumOfElements();
// Check to make sure we don't lose any towers.
EXPECT_EQ(initNumTowers, bootstrappingNumTowers);

Plaintext result;
cc->Decrypt(keyPair.secretKey, ciphertextAfter, &result);
result->SetLength(encodedLength);
auto actualResult = result->GetCKKSPackedValue();
checkEquality(actualResult, plaintext->GetCKKSPackedValue(), eps, failmsg + " Bootstrapping failed");

auto ciphertextTwoIterations = cc->EvalBootstrap(ciphertext);
auto bootstrappingNumTowersTwoIterations = ciphertextTwoIterations->GetElements()[0].GetNumOfElements();
// Check to make sure we don't lose any towers with double-iteration bootstrapping.
EXPECT_EQ(initNumTowers, bootstrappingNumTowersTwoIterations);

Plaintext result2;
cc->Decrypt(keyPair.secretKey, ciphertextTwoIterations, &result2);
result->SetLength(encodedLength);
auto actualResult2 = result2->GetCKKSPackedValue();
checkEquality(actualResult2, plaintext->GetCKKSPackedValue(), eps,
failmsg + " Bootstrapping with two iterations failed");
}
catch (std::exception& e) {
std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl;
// make it fail
EXPECT_TRUE(0 == 1) << failmsg;
}
catch (...) {
#if defined EMSCRIPTEN
std::string name("EMSCRIPTEN_UNKNOWN");
#else
std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name()));
#endif
std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl;
// make it fail
Expand All @@ -532,6 +623,9 @@ TEST_P(UTCKKSRNS_BOOT, CKKSRNS) {
case BOOTSTRAP_ITERATIVE:
UnitTest_Bootstrap_Iterative(test, test.buildTestName());
break;
case BOOTSTRAP_NUM_TOWERS:
UnitTest_Bootstrap_NumTowers(test, test.buildTestName());
break;
default:
break;
}
Expand Down

0 comments on commit 7e0a7e1

Please # to comment.