From d7c285931e3bace405a44744910016d6e81dacac Mon Sep 17 00:00:00 2001 From: yspolyakov <89226542+yspolyakov@users.noreply.github.com> Date: Mon, 12 Jun 2023 10:54:29 -0400 Subject: [PATCH] Issue 382 - Fix BFV security check (#432) * fixed a bug with BFV parameter generation * updated documentation * Added a test for ring dimension failure on security requirement --------- Co-authored-by: Yuriy Polyakov Co-authored-by: Dmitriy Suponitskiy --- .../bfvrns/bfvrns-parametergeneration.cpp | 141 ++++++++---------- src/pke/unittest/UnitTestSHE.cpp | 44 +++++- 2 files changed, 102 insertions(+), 83 deletions(-) diff --git a/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp b/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp index cd08ad2e4..f3b460f46 100644 --- a/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp +++ b/src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp @@ -151,11 +151,6 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr n) && (nCustom > 0)) - OPENFHE_THROW(config_error, - "Ring dimension n specified by the user does not meet the " - "security requirement. Please increase it."); - while (nRLWE(logq) > n) { n = 2 * n; logq = logqBFV(n); @@ -163,7 +158,6 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); double logqCeil = k * dcrtBits * log(2); @@ -187,47 +181,38 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr n) && (nCustom > 0)) - OPENFHE_THROW(config_error, - "Ring dimension n specified by the user does not meet the " - "security requirement. Please increase it."); - - // this "while" condition is needed in case the iterative solution for q - // changes the requirement for n, which is rare but still theoretically - // possible + // increases n up to the level where desired security level is achieved while (nRLWE(logq) > n) { - while (nRLWE(logq) > n) { - n = 2 * n; - logq = logqBFV(n, logqPrev); - logqPrev = logq; - } - - logq = logqBFV(n, logqPrev); - - while (fabs(logq - logqPrev) > log(1.001)) { - logqPrev = logq; - logq = logqBFV(n, logqPrev); - } - - // this code updates n and q to account for the discrete size of CRT - // moduli = dcrtBits - - int32_t k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); - - double logqCeil = k * dcrtBits * log(2); - logqPrev = logqCeil; - - while (nRLWE(logqCeil) > n) { - n = 2 * n; - logq = logqBFV(n, logqPrev); - k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); - logqCeil = k * dcrtBits * log(2); - logqPrev = logqCeil; - } + n = 2 * n; + logq = logqBFV(n, logqPrev); + logqPrev = logq; + } + + logq = logqBFV(n, logqPrev); + + // let logq converge with prescribed accuracy + while (fabs(logq - logqPrev) > log(1.001)) { + logqPrev = logq; + logq = logqBFV(n, logqPrev); + } + + // this code updates n and q to account for the discrete size of CRT + // moduli = dcrtBits + int32_t k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); + + double logqCeil = k * dcrtBits * log(2); + logqPrev = logqCeil; + + while (nRLWE(logqCeil) > n) { + n = 2 * n; + logq = logqBFV(n, logqPrev); + k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); + logqCeil = k * dcrtBits * log(2); + logqPrev = logqCeil; } } else if ((evalAddCount == 0) && (multiplicativeDepth > 0) && (keySwitchCount == 0)) { @@ -259,43 +244,35 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr n) && (nCustom > 0)) - OPENFHE_THROW(config_error, - "Ring dimension n specified by the user does not meet the " - "security requirement. Please increase it."); - - // this "while" condition is needed in case the iterative solution for q - // changes the requirement for n, which is rare but still theoretically - // possible + // increases n up to the level where desired security level is achieved while (nRLWE(logq) > n) { - while (nRLWE(logq) > n) { - n = 2 * n; - logq = logqBFV(n, logqPrev); - logqPrev = logq; - } - - logq = logqBFV(n, logqPrev); - - while (fabs(logq - logqPrev) > log(1.001)) { - logqPrev = logq; - logq = logqBFV(n, logqPrev); - } - - // this code updates n and q to account for the discrete size of CRT - // moduli = dcrtBits - - int32_t k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); - - double logqCeil = k * dcrtBits * log(2); - logqPrev = logqCeil; - - while (nRLWE(logqCeil) > n) { - n = 2 * n; - logq = logqBFV(n, logqPrev); - k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); - logqCeil = k * dcrtBits * log(2); - logqPrev = logqCeil; - } + n = 2 * n; + logq = logqBFV(n, logqPrev); + logqPrev = logq; + } + + logq = logqBFV(n, logqPrev); + + // let logq converge with prescribed accuracy + while (fabs(logq - logqPrev) > log(1.001)) { + logqPrev = logq; + logq = logqBFV(n, logqPrev); + } + + // this code updates n and q to account for the discrete size of CRT + // moduli = dcrtBits + + int32_t k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); + + double logqCeil = k * dcrtBits * log(2); + logqPrev = logqCeil; + + while (nRLWE(logqCeil) > n) { + n = 2 * n; + logq = logqBFV(n, logqPrev); + k = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); + logqCeil = k * dcrtBits * log(2); + logqPrev = logqCeil; } } else if ((multiplicativeDepth && (evalAddCount || keySwitchCount)) || (evalAddCount && keySwitchCount)) { @@ -308,6 +285,12 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr nCustom) && (nCustom > 0)) + OPENFHE_THROW(config_error, "Ring dimension " + std::to_string(nCustom) + + " specified by the user does not meet the " + "security requirement. Please increase it to " + + std::to_string(n) + "."); + const size_t numInitialModuli = static_cast(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits)); if (numInitialModuli < 1) OPENFHE_THROW(config_error, "numInitialModuli must be greater than 0."); diff --git a/src/pke/unittest/UnitTestSHE.cpp b/src/pke/unittest/UnitTestSHE.cpp index 08338569d..4fad376e3 100644 --- a/src/pke/unittest/UnitTestSHE.cpp +++ b/src/pke/unittest/UnitTestSHE.cpp @@ -58,7 +58,8 @@ enum TEST_CASE_TYPE { EVALSUM_ALL, KS_SINGLE_CRT, KS_MOD_REDUCE_DCRT, - EVALSQUARE + EVALSQUARE, + RING_DIM_ERROR_HANDLING }; static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { @@ -97,6 +98,9 @@ static std::ostream& operator<<(std::ostream& os, const TEST_CASE_TYPE& type) { case EVALSQUARE: typeName = "EVALSQUARE"; break; + case RING_DIM_ERROR_HANDLING: + typeName = "RING_DIM_ERROR_HANDLING"; + break; default: typeName = "UNKNOWN"; break; @@ -383,8 +387,8 @@ static std::vector testCases = { { METADATA, "24", {BFVRNS_SCHEME, DFLT, DFLT, DFLT, 20, BATCH, GAUSSIAN, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, HPSPOVERQLEVELED, EXTENDED, DFLT}, }, // ========================================== // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode - { EVALSUM_ALL, "01", {BFVRNS_SCHEME, BATCH_LRG, DFLT, DFLT, 20, BATCH_LRG, DFLT, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, DFLT, STANDARD, DFLT}, }, - { EVALSUM_ALL, "02", {BFVRNS_SCHEME, BATCH_LRG, DFLT, DFLT, 20, BATCH_LRG, DFLT, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, DFLT, EXTENDED, DFLT}, }, + { EVALSUM_ALL, "01", {BFVRNS_SCHEME, BATCH_LRG, 0, DFLT, 20, BATCH_LRG, DFLT, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, 12, DFLT, STANDARD, DFLT}, }, + { EVALSUM_ALL, "02", {BFVRNS_SCHEME, BATCH_LRG, 0, DFLT, 20, BATCH_LRG, DFLT, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, 12, DFLT, EXTENDED, DFLT}, }, // ========================================== // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode { KS_SINGLE_CRT, "01", {BGVRNS_SCHEME, 1<<13, 1, DFLT, 1, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, 256, 4, DFLT, DFLT, DFLT, STANDARD, DFLT}, }, @@ -397,7 +401,7 @@ static std::vector testCases = { // Calling ModReduce in the AUTO modes doesn't do anything because we automatically mod reduce before multiplication, // so we don't need unit tests for KS_MOD_REDUCE_DCRT in the AUTO modes. // ========================================== - // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode + // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode { EVALSQUARE, "01", {BGVRNS_SCHEME, DFLT, 3, DFLT, 20, BATCH, UNIFORM_TERNARY, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, HPS, STANDARD, DFLT}, }, { EVALSQUARE, "02", {BGVRNS_SCHEME, 256, 3, DFLT, BV_DSIZE, BATCH, UNIFORM_TERNARY, 1, 60, HEStd_NotSet, BV, FIXEDAUTO, DFLT, PTM_LRG, DFLT, DFLT, DFLT, DFLT, STANDARD, DFLT}, }, { EVALSQUARE, "03", {BGVRNS_SCHEME, 256, 3, DFLT, BV_DSIZE, BATCH, UNIFORM_TERNARY, 1, 60, HEStd_NotSet, BV, FLEXIBLEAUTO, DFLT, PTM_LRG, DFLT, DFLT, DFLT, DFLT, STANDARD, DFLT}, }, @@ -422,6 +426,9 @@ static std::vector testCases = { { EVALSQUARE, "22", {BFVRNS_SCHEME, DFLT, 3, DFLT, 20, BATCH, GAUSSIAN, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, HPSPOVERQ, EXTENDED, DFLT}, }, { EVALSQUARE, "23", {BFVRNS_SCHEME, DFLT, 3, DFLT, 20, BATCH, UNIFORM_TERNARY, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, HPSPOVERQLEVELED, EXTENDED, DFLT}, }, { EVALSQUARE, "24", {BFVRNS_SCHEME, DFLT, 3, DFLT, 20, BATCH, GAUSSIAN, DFLT, DFLT, DFLT, DFLT, FIXEDMANUAL, DFLT, PTM_LRG, DFLT, DFLT, DFLT, HPSPOVERQLEVELED, EXTENDED, DFLT}, }, + // ========================================== + // TestType, Descr, Scheme, RDim, MultDepth, SModSize, DSize, BatchSz, SecKeyDist, MaxRelinSkDeg, FModSize, SecLvl, KSTech, ScalTech, LDigits, PtMod, StdDev, EvalAddCt, KSCt, MultTech, EncTech, PREMode + { RING_DIM_ERROR_HANDLING, "01", {BFVRNS_SCHEME, 1<<13, 3, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT, 4293918721, DFLT, DFLT, DFLT, DFLT, DFLT, DFLT}, }, }; // clang-format on //=========================================================================================================== @@ -1199,6 +1206,31 @@ class UTGENERAL_SHE : public ::testing::TestWithParam { EXPECT_TRUE(0 == 1) << failmsg; } } + + void UnitTest_BFV_Ringdimension_Security_Check(const TEST_CASE_UTGENERAL_SHE& testData, + const std::string& failmsg = std::string()) { + try { + CryptoContext cc(UnitTestGenerateContext(testData.params)); + + // make it fail if there is no exception thrown + EXPECT_EQ(0, 1); + } + catch (std::exception& e) { + // we expect to catch an exception for this test as ring dimension should not meet the security requirement + // std::cerr << "Exception thrown from " << __func__ << "(): " << e.what() << std::endl; + EXPECT_EQ(1, 1); + } + catch (...) { +#if defined EMSCRIPTEN + std::string name("EMSCRIPTEN_UNKNOWN"); +#else + std::string name(demangle(__cxxabiv1::__cxa_current_exception_type()->name())); +#endif + std::cerr << "Unknown exception of type \"" << name << "\" thrown from " << __func__ << "()" << std::endl; + // make it fail + EXPECT_EQ(0, 1) << failmsg; + } + } }; //=========================================================================================================== TEST_P(UTGENERAL_SHE, SHE) { @@ -1238,6 +1270,10 @@ TEST_P(UTGENERAL_SHE, SHE) { break; case EVALSQUARE: UnitTest_EvalSquare(test, test.buildTestName()); + break; + case RING_DIM_ERROR_HANDLING: + UnitTest_BFV_Ringdimension_Security_Check(test, test.buildTestName()); + break; default: break; }