Skip to content

Commit

Permalink
fixed bugs with addition and multiplication of large doubles in CKKS (#…
Browse files Browse the repository at this point in the history
…436)

Co-authored-by: Yuriy Polyakov <ypolyakod@dualitytech.com>
  • Loading branch information
2 people authored and dsuponitskiy-duality committed Jun 13, 2023
1 parent d7c2859 commit 8694693
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/pke/include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ enum PlaintextEncodings {
std::ostream& operator<<(std::ostream& s, PlaintextEncodings p);

enum LargeScalingFactorConstants {
MAX_BITS_IN_WORD = 62,
MAX_BITS_IN_WORD = 61,
MAX_LOG_STEP = 60,
};

Expand Down
2 changes: 1 addition & 1 deletion src/pke/lib/encoding/ckkspackedencoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ bool CKKSPackedEncoding::Encode() {
double powP = scalingFactor;

// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
int32_t MAX_BITS_IN_WORD = 61;
int32_t MAX_BITS_IN_WORD = LargeScalingFactorConstants::MAX_BITS_IN_WORD;

int32_t logc = 0;
for (size_t i = 0; i < slots; ++i) {
Expand Down
36 changes: 18 additions & 18 deletions src/pke/lib/scheme/ckksrns/ckksrns-leveledshe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons
}

// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
int32_t logSF = static_cast<int32_t>(ceil(log2(fabs(scFactor))));
int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
logSF :
LargeScalingFactorConstants::MAX_BITS_IN_WORD;
int32_t logSF = static_cast<int32_t>(ceil(log2(fabs(constant * scFactor))));
int32_t logValid = (logSF <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
logSF :
LargeScalingFactorConstants::MAX_BITS_IN_WORD;
int32_t logApprox = logSF - logValid;
double approxFactor = pow(2, logApprox);

Expand All @@ -236,17 +236,17 @@ std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalAddOrSub(Cons

// Scale back up by approxFactor within the CRT multiplications.
if (logApprox > 0) {
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
DCRTPoly::Integer intStep = uint64_t(1) << logStep;
std::vector<DCRTPoly::Integer> crtApprox(sizeQl, intStep);
logApprox -= logStep;

while (logApprox > 0) {
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
DCRTPoly::Integer intStep = uint64_t(1) << logStep;
std::vector<DCRTPoly::Integer> crtSF(sizeQl, intStep);
crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
Expand Down Expand Up @@ -335,14 +335,14 @@ std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalMult(ConstCip

#if defined(HAVE_INT128)
typedef int128_t DoubleInteger;
int32_t MAX_BITS_IN_WORD = 126;
int32_t MAX_BITS_IN_WORD = 125;
#else
typedef int64_t DoubleInteger;
int32_t MAX_BITS_IN_WORD = LargeScalingFactorConstants::MAX_BITS_IN_WORD;
#endif

// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
int32_t logSF = static_cast<int32_t>(ceil(log2(fabs(scFactor))));
int32_t logSF = static_cast<int32_t>(ceil(log2(fabs(constant * scFactor))));
int32_t logValid = (logSF <= MAX_BITS_IN_WORD) ? logSF : MAX_BITS_IN_WORD;
int32_t logApprox = logSF - logValid;
double approxFactor = pow(2, logApprox);
Expand Down Expand Up @@ -372,17 +372,17 @@ std::vector<DCRTPoly::Integer> LeveledSHECKKSRNS::GetElementForEvalMult(ConstCip

// Scale back up by approxFactor within the CRT multiplications.
if (logApprox > 0) {
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
DCRTPoly::Integer intStep = uint64_t(1) << logStep;
std::vector<DCRTPoly::Integer> crtApprox(numTowers, intStep);
logApprox -= logStep;

while (logApprox > 0) {
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox :
LargeScalingFactorConstants::MAX_LOG_STEP;
DCRTPoly::Integer intStep = uint64_t(1) << logStep;
std::vector<DCRTPoly::Integer> crtSF(numTowers, intStep);
crtApprox = CKKSPackedEncoding::CRTMult(crtApprox, crtSF, moduli);
Expand Down
18 changes: 15 additions & 3 deletions src/pke/unittest/utckksrns/UnitTestCKKSrns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,9 +590,13 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {

const double epsHigh = 0.00001;

const double factor = 1 << 25;

const std::vector<std::complex<double>> vectorOfInts0_7{0, 1, 2, 3, 4, 5, 6, 7};
const std::vector<std::complex<double>> vectorOfInts0_7_Neg{0, -1, -2, -3, -4, -5, -6, -7};
const std::vector<std::complex<double>> vectorOfInts0_7_Add{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
const std::vector<std::complex<double>> vectorOfInts0_7_AddLargeScalar{
0 + factor, 1 + factor, 2 + factor, 3 + factor, 4 + factor, 5 + factor, 6 + factor, 7 + factor};
const std::vector<std::complex<double>> vectorOfInts0_7_Sub{-0.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5};
const std::vector<std::complex<double>> vectorOfInts0_7neg{0, -1, -2, -3, -4, -5, -6, -7};
const std::vector<std::complex<double>> vectorOfInts7_0{7, 6, 5, 4, 3, 2, 1, 0};
Expand Down Expand Up @@ -627,6 +631,8 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7, 1, 0, nullptr, testData.slots);
Plaintext plaintext1AddScalar =
cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Add, 1, 0, nullptr, testData.slots);
Plaintext plaintext1AddLargeScalar =
cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_AddLargeScalar, 1, 0, nullptr, testData.slots);
Plaintext plaintext1SubScalar =
cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Sub, 1, 0, nullptr, testData.slots);
Plaintext negatives1 = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7neg, 1, 0, nullptr, testData.slots);
Expand Down Expand Up @@ -734,7 +740,7 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
cc->Decrypt(kp.secretKey, cResult, &results);
results->SetLength(plaintext1AddScalar->GetLength());
checkEquality(plaintext1AddScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), eps,
failmsg + " EvalAdd Ct and Double fails");
failmsg + " EvalAdd Ct and double fails");
approximationErrors.emplace_back(CalculateApproximationError<T>(plaintext1AddScalar->GetCKKSPackedValue(),
results->GetCKKSPackedValue()));

Expand All @@ -743,7 +749,7 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
cc->Decrypt(kp.secretKey, cResult, &results);
results->SetLength(plaintext1SubScalar->GetLength());
checkEquality(plaintext1SubScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), eps,
failmsg + " EvalSub Ct and Double fails");
failmsg + " EvalSub Ct and double fails");
approximationErrors.emplace_back(CalculateApproximationError<T>(plaintext1SubScalar->GetCKKSPackedValue(),
results->GetCKKSPackedValue()));

Expand All @@ -765,6 +771,13 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
approximationErrors.emplace_back(CalculateApproximationError<T>(plaintext1AddScalar->GetCKKSPackedValue(),
results->GetCKKSPackedValue()));

// Testing EvalAdd ciphertext + large double
cResult = cc->EvalAdd(ciphertext1, factor);
cc->Decrypt(kp.secretKey, cResult, &results);
results->SetLength(plaintext1AddLargeScalar->GetLength());
checkEquality(plaintext1AddLargeScalar->GetCKKSPackedValue(), results->GetCKKSPackedValue(), factor * eps,
failmsg + " EvalAdd Ct and large double fails");

// Testing EvalNegate
cResult = cc->EvalNegate(ciphertext1);
cc->Decrypt(kp.secretKey, cResult, &results);
Expand Down Expand Up @@ -831,7 +844,6 @@ class UTCKKSRNS : public ::testing::TestWithParam<TEST_CASE_UTCKKSRNS> {
Plaintext plaintextNeg = cc->MakeCKKSPackedPlaintext(vectorOfInts0_7_Neg, 1, 0, nullptr, testData.slots);
Plaintext plaintextMult = cc->MakeCKKSPackedPlaintext(
std::vector<std::complex<double>>({0, 6, 10, 12, 12, 10, 6, 0}), 1, 0, nullptr, testData.slots);
double factor = 1 << 25;
Plaintext plaintextLarge = cc->MakeCKKSPackedPlaintext(
std::vector<std::complex<double>>({factor, factor, 0, 0, 0, 0, 0, 0}), 1, 0, nullptr, testData.slots);
Plaintext plaintextLargeMult = cc->MakeCKKSPackedPlaintext(
Expand Down

0 comments on commit 8694693

Please # to comment.