From ebe1c75579170072dc59b8dee2b55ce31663178f Mon Sep 17 00:00:00 2001 From: Peter Dettman Date: Wed, 20 Mar 2024 13:57:09 +0700 Subject: [PATCH] EdDSA: Explicit guard against infinite looping --- .../bouncycastle/math/ec/rfc8032/Ed25519.java | 14 ++++++++-- .../bouncycastle/math/ec/rfc8032/Ed448.java | 14 ++++++++-- .../math/ec/rfc8032/Scalar25519.java | 10 ++++++- .../math/ec/rfc8032/Scalar448.java | 10 ++++++- .../math/ec/rfc8032/ScalarUtil.java | 28 +++++++++---------- 5 files changed, 56 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed25519.java b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed25519.java index a9248176b1..7aedfb90ef 100644 --- a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed25519.java +++ b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed25519.java @@ -566,7 +566,12 @@ private static boolean implVerify(byte[] sig, int sigOff, byte[] pk, int pkOff, int[] v0 = new int[4]; int[] v1 = new int[4]; - Scalar25519.reduceBasisVar(nA, v0, v1); + + if (!Scalar25519.reduceBasisVar(nA, v0, v1)) + { + throw new IllegalStateException(); + } + Scalar25519.multiply128Var(nS, v1, nS); PointAccum pZ = new PointAccum(); @@ -628,7 +633,12 @@ private static boolean implVerify(byte[] sig, int sigOff, PublicPoint publicPoin int[] v0 = new int[4]; int[] v1 = new int[4]; - Scalar25519.reduceBasisVar(nA, v0, v1); + + if (!Scalar25519.reduceBasisVar(nA, v0, v1)) + { + throw new IllegalStateException(); + } + Scalar25519.multiply128Var(nS, v1, nS); PointAccum pZ = new PointAccum(); diff --git a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed448.java b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed448.java index 9fc9bed9b7..2d2053e649 100644 --- a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed448.java +++ b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Ed448.java @@ -510,7 +510,12 @@ private static boolean implVerify(byte[] sig, int sigOff, byte[] pk, int pkOff, int[] v0 = new int[8]; int[] v1 = new int[8]; - Scalar448.reduceBasisVar(nA, v0, v1); + + if (!Scalar448.reduceBasisVar(nA, v0, v1)) + { + throw new IllegalStateException(); + } + Scalar448.multiply225Var(nS, v1, nS); PointProjective pZ = new PointProjective(); @@ -569,7 +574,12 @@ private static boolean implVerify(byte[] sig, int sigOff, PublicPoint publicPoin int[] v0 = new int[8]; int[] v1 = new int[8]; - Scalar448.reduceBasisVar(nA, v0, v1); + + if (!Scalar448.reduceBasisVar(nA, v0, v1)) + { + throw new IllegalStateException(); + } + Scalar448.multiply225Var(nS, v1, nS); PointProjective pZ = new PointProjective(); diff --git a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar25519.java b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar25519.java index a760625798..175513fba4 100644 --- a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar25519.java +++ b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar25519.java @@ -295,7 +295,7 @@ static byte[] reduce512(byte[] n) return r; } - static void reduceBasisVar(int[] k, int[] z0, int[] z1) + static boolean reduceBasisVar(int[] k, int[] z0, int[] z1) { /* * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L. @@ -312,11 +312,18 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1) int[] v0 = new int[4]; System.arraycopy(k, 0, v0, 0, 4); int[] v1 = new int[4]; v1[0] = 1; + // Conservative upper bound on the number of loop iterations needed + int iterations = TARGET_LENGTH * 4; int last = 15; int len_Nv = ScalarUtil.getBitLengthPositive(last, Nv); while (len_Nv > TARGET_LENGTH) { + if (--iterations < 0) + { + return false; + } + int len_p = ScalarUtil.getBitLength(last, p); int s = len_p - len_Nv; s &= ~(s >> 31); @@ -346,6 +353,7 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1) // v1 * k == v0 mod L System.arraycopy(v0, 0, z0, 0, 4); System.arraycopy(v1, 0, z1, 0, 4); + return true; } static void toSignedDigits(int bits, int[] z) diff --git a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar448.java b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar448.java index edac5c02da..f6ba495f97 100644 --- a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar448.java +++ b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/Scalar448.java @@ -560,7 +560,7 @@ static byte[] reduce912(byte[] n) return r; } - static void reduceBasisVar(int[] k, int[] z0, int[] z1) + static boolean reduceBasisVar(int[] k, int[] z0, int[] z1) { /* * Split scalar k into two half-size scalars z0 and z1, such that z1 * k == z0 mod L. @@ -577,11 +577,18 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1) int[] v0 = new int[8]; System.arraycopy(k, 0, v0, 0, 8); int[] v1 = new int[8]; v1[0] = 1; + // Conservative upper bound on the number of loop iterations needed + int iterations = TARGET_LENGTH * 4; int last = 27; int len_Nv = ScalarUtil.getBitLengthPositive(last, Nv); while (len_Nv > TARGET_LENGTH) { + if (--iterations < 0) + { + return false; + } + int len_p = ScalarUtil.getBitLength(last, p); int s = len_p - len_Nv; s &= ~(s >> 31); @@ -614,6 +621,7 @@ static void reduceBasisVar(int[] k, int[] z0, int[] z1) // v1 * k == v0 mod L System.arraycopy(v0, 0, z0, 0, 8); System.arraycopy(v1, 0, z1, 0, 8); + return true; } static void toSignedDigits(int bits, int[] x, int[] z) diff --git a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/ScalarUtil.java b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/ScalarUtil.java index 100fd485e7..483bf477d9 100644 --- a/core/src/main/java/org/bouncycastle/math/ec/rfc8032/ScalarUtil.java +++ b/core/src/main/java/org/bouncycastle/math/ec/rfc8032/ScalarUtil.java @@ -22,11 +22,11 @@ static void addShifted_NP(int last, int s, int[] Nu, int[] Nv, int[] p, int[] t) cc_p += p_i & M; cc_p += Nv[i] & M; - p_i = (int)cc_p; cc_p >>= 32; - p[i] = p_i; + p_i = (int)cc_p; cc_p >>>= 32; + p[i] = p_i; cc_Nu += p_i & M; - Nu[i] = (int)cc_Nu; cc_Nu >>= 32; + Nu[i] = (int)cc_Nu; cc_Nu >>>= 32; } } else if (s < 32) @@ -50,20 +50,20 @@ else if (s < 32) cc_p += p_i & M; cc_p += v_s & M; - p_i = (int)cc_p; cc_p >>= 32; + p_i = (int)cc_p; cc_p >>>= 32; p[i] = p_i; int q_s = (p_i << s) | (prev_q >>> -s); - prev_q =p_i; + prev_q = p_i; cc_Nu += q_s & M; - Nu[i] = (int)cc_Nu; cc_Nu >>= 32; + Nu[i] = (int)cc_Nu; cc_Nu >>>= 32; } } else { - // Keep the original value of p in t. - System.arraycopy(p, 0, t, 0, p.length); + // Copy the low limbs of the original p + System.arraycopy(p, 0, t, 0, last); int sWords = s >>> 5; int sBits = s & 31; if (sBits == 0) @@ -75,10 +75,10 @@ else if (s < 32) cc_p += p[i] & M; cc_p += Nv[i - sWords] & M; - p[i] = (int)cc_p; cc_p >>= 32; + p[i] = (int)cc_p; cc_p >>>= 32; cc_Nu += p[i - sWords] & M; - Nu[i] = (int)cc_Nu; cc_Nu >>= 32; + Nu[i] = (int)cc_Nu; cc_Nu >>>= 32; } } else @@ -102,14 +102,14 @@ else if (s < 32) cc_p += p[i] & M; cc_p += v_s & M; - p[i] = (int)cc_p; cc_p >>= 32; + p[i] = (int)cc_p; cc_p >>>= 32; int next_q = p[i - sWords]; int q_s = (next_q << sBits) | (prev_q >>> -sBits); prev_q = next_q; cc_Nu += q_s & M; - Nu[i] = (int)cc_Nu; cc_Nu >>= 32; + Nu[i] = (int)cc_Nu; cc_Nu >>>= 32; } } } @@ -251,8 +251,8 @@ else if (s < 32) } else { - // Keep the original value of p in t. - System.arraycopy(p, 0, t, 0, p.length); + // Copy the low limbs of the original p + System.arraycopy(p, 0, t, 0, last); int sWords = s >>> 5; int sBits = s & 31; if (sBits == 0)