From f3e28cc8c3268a9634b67f6989dbeb1030dbd385 Mon Sep 17 00:00:00 2001 From: Sean Madden Date: Mon, 17 Feb 2025 22:09:58 -0500 Subject: [PATCH] feat(cryptids): clean up copying of secret data in x25519 --- libraries/cryptids/src/x25519.rs | 213 +++++++++++++++++++++---------- 1 file changed, 148 insertions(+), 65 deletions(-) diff --git a/libraries/cryptids/src/x25519.rs b/libraries/cryptids/src/x25519.rs index 2b333e3..77e2ec2 100644 --- a/libraries/cryptids/src/x25519.rs +++ b/libraries/cryptids/src/x25519.rs @@ -5,7 +5,18 @@ #![allow(clippy::indexing_slicing)] #![allow(clippy::manual_memcpy)] -use core::ops::{Add, Index, IndexMut, Mul, Not, Sub}; +use core::ops::{AddAssign, MulAssign, SubAssign}; +use core::ops::{Index, IndexMut}; + +macro_rules! zeroize { + ($name:ident,$ty:ty) => { + fn $name(v: &mut [$ty]) { + v.iter_mut().for_each(|x| *x = 0); + } + }; +} +zeroize!(zeroize_u8, u8); +zeroize!(zeroize_i64, i64); const fn basepoint() -> [u8; 32] { let mut out = [0u8; 32]; @@ -45,14 +56,24 @@ impl SecretKey { SharedCurve25519Secret(scalarmult(&self.0, &pubkey.0)) } } +impl Drop for SecretKey { + fn drop(&mut self) { + zeroize_u8(&mut self.0); + } +} impl Curve25519PublicKey { /// /// Generates a shared key using the original Curve25519 method of multiplying owned secret key with - /// a Curve25519PublicKey. This is probably not what you want. + /// a Curve25519PublicKey. This is probably not what you want. pub fn generate_shared_secret(&self, secret_key: &SecretKey) -> SharedCurve25519Secret { SharedCurve25519Secret(scalarmult(&secret_key.0, &self.0)) } } +impl Drop for SharedCurve25519Secret { + fn drop(&mut self) { + zeroize_u8(&mut self.0); + } +} impl AsRef<[u8]> for SharedCurve25519Secret { fn as_ref(&self) -> &[u8] { self.0.as_ref() @@ -71,12 +92,12 @@ pub fn scalarmult(scalar: &[u8; 32], point: &[u8; 32]) -> [u8; 32] { clamp[31] |= 0x40; let x = FieldElement::unpack(point); - let mut b = x; + let mut b = x.clone(); let mut a = FieldElement([0i64; 16]); - let mut d = a; - let mut c = a; - let mut e: FieldElement; - let mut f: FieldElement; + let mut d = FieldElement([0i64; 16]); + let mut c = FieldElement([0i64; 16]); + let mut e = FieldElement([0i64; 16]); + let mut f = FieldElement([0i64; 16]); a[0] = 1; d[0] = 1; for i in 0..=254 { @@ -84,34 +105,39 @@ pub fn scalarmult(scalar: &[u8; 32], point: &[u8; 32]) -> [u8; 32] { let bit = ((clamp[i >> 3] >> (i & 0x7)) & 1) as i64; a.swap(&mut b, bit); c.swap(&mut d, bit); - e = a + c; - a = a - c; - c = b + d; - b = b - d; - d = e * e; - f = a * a; - a = c * a; - c = b * e; - e = a + c; - a = a - c; - b = a * a; - c = d - f; - a = c * DB41; - a = a + d; - c = c * a; - a = d * f; - d = b * x; - b = e * e; + add(&mut e, &a, &c); + a -= &c; + add(&mut c, &b, &d); + b -= &d; + d.square_assign(&e); + f.square_assign(&a); + a.mul_rassign(&c); + mul(&mut c, &b, &e); + add(&mut e, &a, &c); + a -= &c; + b.square_assign(&a); + sub(&mut c, &d, &f); + mul(&mut a, &c, &DB41); + a += &d; + c *= &a; + mul(&mut a, &d, &f); + mul(&mut d, &b, &x); + b.square_assign(&e); a.swap(&mut b, bit); c.swap(&mut d, bit); } - c = c.not(); - a = a * c; + invert(&mut c); + a *= &c; a.pack() } -#[derive(Clone, Copy)] +#[derive(Clone)] struct FieldElement([i64; 16]); +impl Drop for FieldElement { + fn drop(&mut self) { + zeroize_i64(&mut self.0); + } +} const DB41: FieldElement = FieldElement([0xDB41, 0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); impl FieldElement { fn carry(&mut self) { @@ -168,6 +194,54 @@ impl FieldElement { out[15] &= 0x7FFF; out } + fn square(&mut self) { + let mut t = [0i64; 32]; + for i in 0..16 { + for j in 0..16 { + t[i + j] += self[i] * self[j]; + } + } + for i in 0..15 { + t[i] += 38 * t[i + 16]; + } + for i in 0..16 { + self[i] = t[i]; + } + self.carry(); + self.carry(); + } + fn square_assign(&mut self, a: &FieldElement) { + let mut t = [0i64; 32]; + for i in 0..16 { + for j in 0..16 { + t[i + j] += a[i] * a[j]; + } + } + for i in 0..15 { + t[i] += 38 * t[i + 16]; + } + for i in 0..16 { + self[i] = t[i]; + } + self.carry(); + self.carry(); + } + fn mul_rassign(&mut self, rhs: &FieldElement) { + let mut t = [0i64; 32]; + for i in 0..16 { + for j in 0..16 { + t[i + j] += rhs[i] * self[j]; + } + } + for i in 0..15 { + t[i] += 38 * t[i + 16]; + } + for i in 0..16 { + self[i] = t[i]; + } + self.carry(); + self.carry(); + } } impl Index for FieldElement { type Output = i64; @@ -181,33 +255,49 @@ impl IndexMut for FieldElement { &mut self.0[index] } } -impl Add for FieldElement { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - let mut out = [0i64; 16]; +fn add(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) { + for i in 0..16 { + out[i] = a[i] + b[i]; + } +} +impl AddAssign<&FieldElement> for FieldElement { + fn add_assign(&mut self, rhs: &FieldElement) { for i in 0..16 { - out[i] = self[i] + rhs[i]; + self[i] += rhs[i]; } - FieldElement(out) } } -impl Sub for FieldElement { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - let mut out = [0i64; 16]; +fn sub(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) { + for i in 0..16 { + out[i] = a[i] - b[i]; + } +} +impl SubAssign<&FieldElement> for FieldElement { + fn sub_assign(&mut self, rhs: &FieldElement) { for i in 0..16 { - out[i] = self[i] - rhs[i]; + self[i] -= rhs[i]; } - FieldElement(out) } } -impl Mul for FieldElement { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - let mut t = [0i64; 31]; +fn mul(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) { + let mut t = [0i64; 32]; + for i in 0..16 { + for j in 0..16 { + t[i + j] += a[i] * b[j]; + } + } + for i in 0..15 { + t[i] += 38 * t[i + 16]; + } + for i in 0..16 { + out[i] = t[i]; + } + out.carry(); + out.carry(); +} +impl MulAssign<&FieldElement> for FieldElement { + fn mul_assign(&mut self, rhs: &FieldElement) { + let mut t = [0i64; 32]; for i in 0..16 { for j in 0..16 { t[i + j] += self[i] * rhs[j]; @@ -216,30 +306,23 @@ impl Mul for FieldElement { for i in 0..15 { t[i] += 38 * t[i + 16]; } - let mut out = [0i64; 16]; for i in 0..16 { - out[i] = t[i]; + self[i] = t[i]; } - let mut out = FieldElement(out); - out.carry(); - out.carry(); - out + self.carry(); + self.carry(); } } -impl Not for FieldElement { - type Output = Self; - - fn not(self) -> Self::Output { - let mut out = self; - for i in 0..=253 { - let i = 253 - i; - out = out * out; - if i != 2 && i != 4 { - out = out * self; - } +fn invert(v: &mut FieldElement) { + let mut out = v.clone(); + for i in 0..=253 { + let i = 253 - i; + out.square(); + if i != 2 && i != 4 { + out *= v; } - out } + *v = out; } #[cfg(test)]