diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp index 6f76d408d..8028a595e 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp @@ -248,14 +248,14 @@ Ciphertext FHECKKSRNS::EvalBootstrap(ConstCiphertext ciphert double timeDecode(0.0); #endif - auto cc = ciphertext->GetCryptoContext(); - uint32_t M = cc->GetCyclotomicOrder(); - uint32_t L0 = cryptoParams->GetElementParams()->GetParams().size(); + auto cc = ciphertext->GetCryptoContext(); + uint32_t M = cc->GetCyclotomicOrder(); + uint32_t L0 = cryptoParams->GetElementParams()->GetParams().size(); + auto initSizeQ = ciphertext->GetElements()[0].GetNumOfElements(); if (numIterations > 1) { // Step 1: Get the input. uint32_t powerOfTwoModulus = 1 << precision; - auto initSizeQ = ciphertext->GetElements()[0].GetNumOfElements(); // Step 2: Scale up by powerOfTwoModulus, and extend the modulus to powerOfTwoModulus * q. // Note that we extend the modulus implicitly without any code calls because the value always stays 0. @@ -275,10 +275,10 @@ Ciphertext FHECKKSRNS::EvalBootstrap(ConstCiphertext ciphert // We mod down, and leave the last CRT value to be 0 because it's divisible by powerOfTwoModulus. auto ctBootstrappedScaledDown = ctInitialBootstrap->Clone(); auto bootstrappingSizeQ = ctBootstrappedScaledDown->GetElements()[0].GetNumOfElements(); + + // If we start with more towers, than we obtain from bootstrapping, return the original ciphertext. if (bootstrappingSizeQ <= initSizeQ) { - OPENFHE_THROW(config_error, "Bootstrapping number of RNS moduli " + std::to_string(bootstrappingSizeQ) + - " must be greater than initial number of RNS moduli " + - std::to_string(initSizeQ)); + return ciphertext->Clone(); } for (auto& cv : ctBootstrappedScaledDown->GetElements()) { cv.DropLastElements(bootstrappingSizeQ - initSizeQ); @@ -612,6 +612,13 @@ Ciphertext FHECKKSRNS::EvalBootstrap(ConstCiphertext ciphert std::cout << "Decoding time: " << timeDecode / 1000.0 << " s" << std::endl; #endif + auto bootstrappingNumTowers = ctxtDec->GetElements()[0].GetNumOfElements(); + + // If we start with more towers, than we obtain from bootstrapping, return the original ciphertext. + if (bootstrappingNumTowers <= initSizeQ) { + return ciphertext->Clone(); + } + return ctxtDec; } diff --git a/src/pke/unittest/utckksrns/UnitTestBootstrap.cpp b/src/pke/unittest/utckksrns/UnitTestBootstrap.cpp index b48cb96fc..3cd9cd2bc 100644 --- a/src/pke/unittest/utckksrns/UnitTestBootstrap.cpp +++ b/src/pke/unittest/utckksrns/UnitTestBootstrap.cpp @@ -54,6 +54,7 @@ enum TEST_CASE_TYPE { BOOTSTRAP_SPARSE, BOOTSTRAP_KEY_SWITCH, BOOTSTRAP_ITERATIVE, + BOOTSTRAP_NUM_TOWERS, }; static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { @@ -74,6 +75,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { case BOOTSTRAP_ITERATIVE: typeName = "BOOTSTRAP_ITERATIVE"; break; + case BOOTSTRAP_NUM_TOWERS: + typeName = "BOOTSTRAP_NUM_TOWERS"; + break; default: typeName = "UNKNOWN"; break; @@ -259,6 +263,29 @@ static std::vector testCases = { { BOOTSTRAP_ITERATIVE, "14", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, { BOOTSTRAP_ITERATIVE, "15", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, { BOOTSTRAP_ITERATIVE, "16", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, +#endif + // ========================================== + // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode, LvlBudget, Dim1, Slots + { BOOTSTRAP_NUM_TOWERS, "01", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "02", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "03", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "04", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, +#if NATIVEINT != 128 + { BOOTSTRAP_NUM_TOWERS, "05", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "06", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "07", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, + { BOOTSTRAP_NUM_TOWERS, "08", {CKKSRNS_SCHEME, 2048, MULT_DEPTH, SMODSIZE, DFLT, 8, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, 8}, +#endif + // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode, LvlBudget, Dim1, Slots + { BOOTSTRAP_NUM_TOWERS, "09", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "10", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "11", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "12", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FIXEDMANUAL, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, +#if NATIVEINT != 128 + { BOOTSTRAP_NUM_TOWERS, "13", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "14", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTO, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "15", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, SPARSE_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, + { BOOTSTRAP_NUM_TOWERS, "16", {CKKSRNS_SCHEME, RDIM, MULT_DEPTH, SMODSIZE, DFLT, DFLT, UNIFORM_TERNARY, DFLT, FMODSIZE, HEStd_NotSet, HYBRID, FLEXIBLEAUTOEXT, NUM_LRG_DIGS, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, { 3, 2 }, { 0, 0 }, RDIM/2}, #endif // ========================================== }; @@ -507,6 +534,70 @@ class UTCKKSRNS_BOOT : public ::testing::TestWithParam std::string name("EMSCRIPTEN_UNKNOWN"); #else std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name())); +#endif + std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl; + // make it fail + EXPECT_TRUE(0 == 1) << failmsg; + } + } + + void UnitTest_Bootstrap_NumTowers(const TEST_CASE_UTCKKSRNS_BOOT& testData, + const std::string& failmsg = std::string()) { + // This test checks to make sure that we return the original ciphertext if we + // start with more towers than the number of towers we would end up with by + // bootstrapping. + try { + CryptoContext cc(UnitTestGenerateContext(testData.params)); + + cc->EvalBootstrapSetup(testData.levelBudget, testData.dim1, testData.slots); + + auto keyPair = cc->KeyGen(); + cc->EvalBootstrapKeyGen(keyPair.secretKey, testData.slots); + cc->EvalAtIndexKeyGen(keyPair.secretKey, {6}); + cc->EvalMultKeyGen(keyPair.secretKey); + + std::vector> input( + Fill({0.111111, 0.222222, 0.333333, 0.444444, 0.555555, 0.666666, 0.777777, 0.888888}, testData.slots)); + size_t encodedLength = input.size(); + + // We start with a ciphertext with 0 levels consumed. + Plaintext plaintext = cc->MakeCKKSPackedPlaintext(input); + auto ciphertext = cc->Encrypt(keyPair.publicKey, plaintext); + auto ciphertextAfter = cc->EvalBootstrap(ciphertext); + + auto initNumTowers = ciphertext->GetElements()[0].GetNumOfElements(); + auto bootstrappingNumTowers = ciphertextAfter->GetElements()[0].GetNumOfElements(); + // Check to make sure we don't lose any towers. + EXPECT_EQ(initNumTowers, bootstrappingNumTowers); + + Plaintext result; + cc->Decrypt(keyPair.secretKey, ciphertextAfter, &result); + result->SetLength(encodedLength); + auto actualResult = result->GetCKKSPackedValue(); + checkEquality(actualResult, plaintext->GetCKKSPackedValue(), eps, failmsg + " Bootstrapping failed"); + + auto ciphertextTwoIterations = cc->EvalBootstrap(ciphertext); + auto bootstrappingNumTowersTwoIterations = ciphertextTwoIterations->GetElements()[0].GetNumOfElements(); + // Check to make sure we don't lose any towers with double-iteration bootstrapping. + EXPECT_EQ(initNumTowers, bootstrappingNumTowersTwoIterations); + + Plaintext result2; + cc->Decrypt(keyPair.secretKey, ciphertextTwoIterations, &result2); + result->SetLength(encodedLength); + auto actualResult2 = result2->GetCKKSPackedValue(); + checkEquality(actualResult2, plaintext->GetCKKSPackedValue(), eps, + failmsg + " Bootstrapping with two iterations failed"); + } + catch (std::exception& e) { + std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl; + // make it fail + EXPECT_TRUE(0 == 1) << failmsg; + } + catch (...) { +#if defined EMSCRIPTEN + std::string name("EMSCRIPTEN_UNKNOWN"); +#else + std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name())); #endif std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl; // make it fail @@ -532,6 +623,9 @@ TEST_P(UTCKKSRNS_BOOT, CKKSRNS) { case BOOTSTRAP_ITERATIVE: UnitTest_Bootstrap_Iterative(test, test.buildTestName()); break; + case BOOTSTRAP_NUM_TOWERS: + UnitTest_Bootstrap_NumTowers(test, test.buildTestName()); + break; default: break; }