Skip to content

Commit

Permalink
feat(cryptids): clean up copying of secret data in x25519
Browse files Browse the repository at this point in the history
  • Loading branch information
spmadden committed Feb 18, 2025
1 parent 9b8969e commit f3e28cc
Showing 1 changed file with 148 additions and 65 deletions.
213 changes: 148 additions & 65 deletions libraries/cryptids/src/x25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
#![allow(clippy::indexing_slicing)]
#![allow(clippy::manual_memcpy)]

use core::ops::{Add, Index, IndexMut, Mul, Not, Sub};
use core::ops::{AddAssign, MulAssign, SubAssign};
use core::ops::{Index, IndexMut};

macro_rules! zeroize {
($name:ident,$ty:ty) => {
fn $name(v: &mut [$ty]) {
v.iter_mut().for_each(|x| *x = 0);
}
};
}
zeroize!(zeroize_u8, u8);
zeroize!(zeroize_i64, i64);

const fn basepoint() -> [u8; 32] {
let mut out = [0u8; 32];
Expand Down Expand Up @@ -45,14 +56,24 @@ impl SecretKey {
SharedCurve25519Secret(scalarmult(&self.0, &pubkey.0))
}
}
impl Drop for SecretKey {
fn drop(&mut self) {
zeroize_u8(&mut self.0);
}
}
impl Curve25519PublicKey {
///
/// Generates a shared key using the original Curve25519 method of multiplying owned secret key with
/// a Curve25519PublicKey. This is probably not what you want.
/// a Curve25519PublicKey. This is probably not what you want.
pub fn generate_shared_secret(&self, secret_key: &SecretKey) -> SharedCurve25519Secret {
SharedCurve25519Secret(scalarmult(&secret_key.0, &self.0))
}
}
impl Drop for SharedCurve25519Secret {
fn drop(&mut self) {
zeroize_u8(&mut self.0);
}
}
impl AsRef<[u8]> for SharedCurve25519Secret {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
Expand All @@ -71,47 +92,52 @@ pub fn scalarmult(scalar: &[u8; 32], point: &[u8; 32]) -> [u8; 32] {
clamp[31] |= 0x40;

let x = FieldElement::unpack(point);
let mut b = x;
let mut b = x.clone();
let mut a = FieldElement([0i64; 16]);
let mut d = a;
let mut c = a;
let mut e: FieldElement;
let mut f: FieldElement;
let mut d = FieldElement([0i64; 16]);
let mut c = FieldElement([0i64; 16]);
let mut e = FieldElement([0i64; 16]);
let mut f = FieldElement([0i64; 16]);
a[0] = 1;
d[0] = 1;
for i in 0..=254 {
let i = 254 - i;
let bit = ((clamp[i >> 3] >> (i & 0x7)) & 1) as i64;
a.swap(&mut b, bit);
c.swap(&mut d, bit);
e = a + c;
a = a - c;
c = b + d;
b = b - d;
d = e * e;
f = a * a;
a = c * a;
c = b * e;
e = a + c;
a = a - c;
b = a * a;
c = d - f;
a = c * DB41;
a = a + d;
c = c * a;
a = d * f;
d = b * x;
b = e * e;
add(&mut e, &a, &c);
a -= &c;
add(&mut c, &b, &d);
b -= &d;
d.square_assign(&e);
f.square_assign(&a);
a.mul_rassign(&c);
mul(&mut c, &b, &e);
add(&mut e, &a, &c);
a -= &c;
b.square_assign(&a);
sub(&mut c, &d, &f);
mul(&mut a, &c, &DB41);
a += &d;
c *= &a;
mul(&mut a, &d, &f);
mul(&mut d, &b, &x);
b.square_assign(&e);
a.swap(&mut b, bit);
c.swap(&mut d, bit);
}
c = c.not();
a = a * c;
invert(&mut c);
a *= &c;
a.pack()
}

#[derive(Clone, Copy)]
#[derive(Clone)]
struct FieldElement([i64; 16]);
impl Drop for FieldElement {
fn drop(&mut self) {
zeroize_i64(&mut self.0);
}
}
const DB41: FieldElement = FieldElement([0xDB41, 0x1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
impl FieldElement {
fn carry(&mut self) {
Expand Down Expand Up @@ -168,6 +194,54 @@ impl FieldElement {
out[15] &= 0x7FFF;
out
}
fn square(&mut self) {
let mut t = [0i64; 32];
for i in 0..16 {
for j in 0..16 {
t[i + j] += self[i] * self[j];
}
}
for i in 0..15 {
t[i] += 38 * t[i + 16];
}
for i in 0..16 {
self[i] = t[i];
}
self.carry();
self.carry();
}
fn square_assign(&mut self, a: &FieldElement) {
let mut t = [0i64; 32];
for i in 0..16 {
for j in 0..16 {
t[i + j] += a[i] * a[j];
}
}
for i in 0..15 {
t[i] += 38 * t[i + 16];
}
for i in 0..16 {
self[i] = t[i];
}
self.carry();
self.carry();
}
fn mul_rassign(&mut self, rhs: &FieldElement) {
let mut t = [0i64; 32];
for i in 0..16 {
for j in 0..16 {
t[i + j] += rhs[i] * self[j];
}
}
for i in 0..15 {
t[i] += 38 * t[i + 16];
}
for i in 0..16 {
self[i] = t[i];
}
self.carry();
self.carry();
}
}
impl Index<usize> for FieldElement {
type Output = i64;
Expand All @@ -181,33 +255,49 @@ impl IndexMut<usize> for FieldElement {
&mut self.0[index]
}
}
impl Add for FieldElement {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let mut out = [0i64; 16];
fn add(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) {
for i in 0..16 {
out[i] = a[i] + b[i];
}
}
impl AddAssign<&FieldElement> for FieldElement {
fn add_assign(&mut self, rhs: &FieldElement) {
for i in 0..16 {
out[i] = self[i] + rhs[i];
self[i] += rhs[i];
}
FieldElement(out)
}
}
impl Sub for FieldElement {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
let mut out = [0i64; 16];
fn sub(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) {
for i in 0..16 {
out[i] = a[i] - b[i];
}
}
impl SubAssign<&FieldElement> for FieldElement {
fn sub_assign(&mut self, rhs: &FieldElement) {
for i in 0..16 {
out[i] = self[i] - rhs[i];
self[i] -= rhs[i];
}
FieldElement(out)
}
}
impl Mul for FieldElement {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
let mut t = [0i64; 31];
fn mul(out: &mut FieldElement, a: &FieldElement, b: &FieldElement) {
let mut t = [0i64; 32];
for i in 0..16 {
for j in 0..16 {
t[i + j] += a[i] * b[j];
}
}
for i in 0..15 {
t[i] += 38 * t[i + 16];
}
for i in 0..16 {
out[i] = t[i];
}
out.carry();
out.carry();
}
impl MulAssign<&FieldElement> for FieldElement {
fn mul_assign(&mut self, rhs: &FieldElement) {
let mut t = [0i64; 32];
for i in 0..16 {
for j in 0..16 {
t[i + j] += self[i] * rhs[j];
Expand All @@ -216,30 +306,23 @@ impl Mul for FieldElement {
for i in 0..15 {
t[i] += 38 * t[i + 16];
}
let mut out = [0i64; 16];
for i in 0..16 {
out[i] = t[i];
self[i] = t[i];
}
let mut out = FieldElement(out);
out.carry();
out.carry();
out
self.carry();
self.carry();
}
}
impl Not for FieldElement {
type Output = Self;

fn not(self) -> Self::Output {
let mut out = self;
for i in 0..=253 {
let i = 253 - i;
out = out * out;
if i != 2 && i != 4 {
out = out * self;
}
fn invert(v: &mut FieldElement) {
let mut out = v.clone();
for i in 0..=253 {
let i = 253 - i;
out.square();
if i != 2 && i != 4 {
out *= v;
}
out
}
*v = out;
}

#[cfg(test)]
Expand Down

0 comments on commit f3e28cc

Please # to comment.