diff --git a/src/pke/include/constants.h b/src/pke/include/constants.h index 05f65c38b..3573320fe 100644 --- a/src/pke/include/constants.h +++ b/src/pke/include/constants.h @@ -152,7 +152,7 @@ enum PlaintextEncodings { std::ostream& operator<<(std::ostream& s, PlaintextEncodings p); enum LargeScalingFactorConstants { - MAX_BITS_IN_WORD = 62, + MAX_BITS_IN_WORD = 61, MAX_LOG_STEP = 60, }; diff --git a/src/pke/lib/encoding/ckkspackedencoding.cpp b/src/pke/lib/encoding/ckkspackedencoding.cpp index 5d6f26500..dc5b0b771 100644 --- a/src/pke/lib/encoding/ckkspackedencoding.cpp +++ b/src/pke/lib/encoding/ckkspackedencoding.cpp @@ -264,7 +264,7 @@ bool CKKSPackedEncoding::Encode() { double powP = scalingFactor; // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer. - int32_t MAX_BITS_IN_WORD = 61; + int32_t MAX_BITS_IN_WORD = LargeScalingFactorConstants::MAX_BITS_IN_WORD; int32_t logc = 0; for (size_t i = 0; i < slots; ++i) { diff --git a/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp b/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp index 124236365..5a195c85c 100644 --- a/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp +++ b/src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp @@ -224,10 +224,10 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons } // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer. - int32_t logSF = static_cast(ceil(log2(fabs(scFactor)))); - int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ? - logSF : - LargeScalingFactorConstants::MAX_BITS_IN_WORD; + int32_t logSF = static_cast(ceil(log2(fabs(constant * scFactor)))); + int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ? + logSF : + LargeScalingFactorConstants::MAX_BITS_IN_WORD; int32_t logApprox = logSF - logValid; double approxFactor = pow(2, logApprox); @@ -236,17 +236,17 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons // Scale back up by approxFactor within the CRT multiplications. if (logApprox > 0) { - int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? - logApprox : - LargeScalingFactorConstants::MAX_LOG_STEP; + int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox : + LargeScalingFactorConstants::MAX_LOG_STEP; DCRTPoly::Integer intStep = uint64_t(1) << logStep; std::vector crtApprox(sizeQl, intStep); logApprox -= logStep; while (logApprox > 0) { - int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? - logApprox : - LargeScalingFactorConstants::MAX_LOG_STEP; + int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox : + LargeScalingFactorConstants::MAX_LOG_STEP; DCRTPoly::Integer intStep = uint64_t(1) << logStep; std::vector crtSF(sizeQl, intStep); crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli); @@ -335,14 +335,14 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalMult(ConstCip #if defined(HAVE_INT128) typedef int128_t DoubleInteger; - int32_t MAX_BITS_IN_WORD = 126; + int32_t MAX_BITS_IN_WORD = 125; #else typedef int64_t DoubleInteger; int32_t MAX_BITS_IN_WORD = LargeScalingFactorConstants::MAX_BITS_IN_WORD; #endif // Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer. - int32_t logSF = static_cast(ceil(log2(fabs(scFactor)))); + int32_t logSF = static_cast(ceil(log2(fabs(constant * scFactor)))); int32_t logValid = (logSF <= MAX_BITS_IN_WORD) ? logSF : MAX_BITS_IN_WORD; int32_t logApprox = logSF - logValid; double approxFactor = pow(2, logApprox); @@ -372,17 +372,17 @@ std::vector LeveledSHECKKSRNS::GetElementForEvalMult(ConstCip // Scale back up by approxFactor within the CRT multiplications. if (logApprox > 0) { - int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? - logApprox : - LargeScalingFactorConstants::MAX_LOG_STEP; + int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox : + LargeScalingFactorConstants::MAX_LOG_STEP; DCRTPoly::Integer intStep = uint64_t(1) << logStep; std::vector crtApprox(numTowers, intStep); logApprox -= logStep; while (logApprox > 0) { - int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? - logApprox : - LargeScalingFactorConstants::MAX_LOG_STEP; + int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ? + logApprox : + LargeScalingFactorConstants::MAX_LOG_STEP; DCRTPoly::Integer intStep = uint64_t(1) << logStep; std::vector crtSF(numTowers, intStep); crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli); diff --git a/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp b/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp index 3be524923..6258ad153 100644 --- a/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp +++ b/src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp @@ -590,9 +590,13 @@ class UTCKKSRNS : public ::testing::TestWithParam { const double epsHigh = 0.00001; + const double factor = 1 << 25; + const std::vector> vectorOfInts0_7{0, 1, 2, 3, 4, 5, 6, 7}; const std::vector> vectorOfInts0_7_Neg{0, -1, -2, -3, -4, -5, -6, -7}; const std::vector> vectorOfInts0_7_Add{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5}; + const std::vector> vectorOfInts0_7_AddLargeScalar{ + 0 + factor, 1 + factor, 2 + factor, 3 + factor, 4 + factor, 5 + factor, 6 + factor, 7 + factor}; const std::vector> vectorOfInts0_7_Sub{-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5}; const std::vector> vectorOfInts0_7neg{0, -1, -2, -3, -4, -5, -6, -7}; const std::vector> vectorOfInts7_0{7, 6, 5, 4, 3, 2, 1, 0}; @@ -627,6 +631,8 @@ class UTCKKSRNS : public ::testing::TestWithParam { Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7, 1, 0, nullptr, testData.slots); Plaintext plaintext1AddScalar = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Add, 1, 0, nullptr, testData.slots); + Plaintext plaintext1AddLargeScalar = + cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_AddLargeScalar, 1, 0, nullptr, testData.slots); Plaintext plaintext1SubScalar = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Sub, 1, 0, nullptr, testData.slots); Plaintext negatives1 = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7neg, 1, 0, nullptr, testData.slots); @@ -734,7 +740,7 @@ class UTCKKSRNS : public ::testing::TestWithParam { cc->Decrypt(kp.secretKey, cResult, &results); results->SetLength(plaintext1AddScalar->GetLength()); checkEquality(plaintext1AddScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), eps, - failmsg + " EvalAdd Ct and Double fails"); + failmsg + " EvalAdd Ct and double fails"); approximationErrors.emplace_back(CalculateApproximationError(plaintext1AddScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue())); @@ -743,7 +749,7 @@ class UTCKKSRNS : public ::testing::TestWithParam { cc->Decrypt(kp.secretKey, cResult, &results); results->SetLength(plaintext1SubScalar->GetLength()); checkEquality(plaintext1SubScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), eps, - failmsg + " EvalSub Ct and Double fails"); + failmsg + " EvalSub Ct and double fails"); approximationErrors.emplace_back(CalculateApproximationError(plaintext1SubScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue())); @@ -765,6 +771,13 @@ class UTCKKSRNS : public ::testing::TestWithParam { approximationErrors.emplace_back(CalculateApproximationError(plaintext1AddScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue())); + // Testing EvalAdd ciphertext + large double + cResult = cc->EvalAdd(ciphertext1, factor); + cc->Decrypt(kp.secretKey, cResult, &results); + results->SetLength(plaintext1AddLargeScalar->GetLength()); + checkEquality(plaintext1AddLargeScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), factor * eps, + failmsg + " EvalAdd Ct and large double fails"); + // Testing EvalNegate cResult = cc->EvalNegate(ciphertext1); cc->Decrypt(kp.secretKey, cResult, &results); @@ -831,7 +844,6 @@ class UTCKKSRNS : public ::testing::TestWithParam { Plaintext plaintextNeg = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Neg, 1, 0, nullptr, testData.slots); Plaintext plaintextMult = cc->MakeCKKSPackedPlaintext( std::vector>({0, 6, 10, 12, 12, 10, 6, 0}), 1, 0, nullptr, testData.slots); - double factor = 1 << 25; Plaintext plaintextLarge = cc->MakeCKKSPackedPlaintext( std::vector>({factor, factor, 0, 0, 0, 0, 0, 0}), 1, 0, nullptr, testData.slots); Plaintext plaintextLargeMult = cc->MakeCKKSPackedPlaintext(