From a97c37a8ae0605bfb41bb9b67aa02f3b117730f1 Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Tue, 29 Oct 2024 12:19:35 +1100 Subject: [PATCH] Implement GetKey for KeyMap Create a `KeyMay` type to replace the current `BTreeMap` alias and implement `bitcoin::psbt::GetKey` for it. Close: #709 --- src/descriptor/key_map.rs | 322 ++++++++++++++++++++++++++++++++++++++ src/descriptor/mod.rs | 42 ++--- src/lib.rs | 6 +- 3 files changed, 342 insertions(+), 28 deletions(-) create mode 100644 src/descriptor/key_map.rs diff --git a/src/descriptor/key_map.rs b/src/descriptor/key_map.rs new file mode 100644 index 000000000..02a9afcd5 --- /dev/null +++ b/src/descriptor/key_map.rs @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: CC0-1.0 + +//! A map of public key to secret key. + +use core::iter; + +use bitcoin::psbt::{GetKey, GetKeyError, KeyRequest}; +use bitcoin::secp256k1::{Secp256k1, Signing}; + +#[cfg(doc)] +use super::Descriptor; +use super::{DescriptorKeyParseError, DescriptorPublicKey, DescriptorSecretKey, SinglePubKey}; +use crate::prelude::{btree_map, BTreeMap}; + +/// Alias type for a map of public key to secret key. +/// +/// This map is returned whenever a descriptor that contains secrets is parsed using +/// [`Descriptor::parse_descriptor`], since the descriptor will always only contain +/// public keys. This map allows looking up the corresponding secret key given a +/// public key from the descriptor. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct KeyMap { + map: BTreeMap, +} + +impl KeyMap { + /// Creates a new empty `KeyMap`. + #[inline] + pub fn new() -> Self { Self { map: BTreeMap::new() } } + + /// Inserts secret key into key map returning the associated public key. + #[inline] + pub fn insert( + &mut self, + secp: &Secp256k1, + sk: DescriptorSecretKey, + ) -> Result { + let pk = sk.to_public(secp)?; + if self.map.get(&pk).is_none() { + self.map.insert(pk.clone(), sk); + } + Ok(pk) + } + + /// Gets the secret key associated with `pk` if `pk` is in the map. + #[inline] + pub fn get(&self, pk: &DescriptorPublicKey) -> Option<&DescriptorSecretKey> { self.map.get(pk) } + + /// Returns the number of items in this map. + #[inline] + pub fn len(&self) -> usize { self.map.len() } + + /// Returns true if the map is empty. + #[inline] + pub fn is_empty(&self) -> bool { self.map.is_empty() } +} + +impl Default for KeyMap { + fn default() -> Self { Self::new() } +} + +impl IntoIterator for KeyMap { + type Item = (DescriptorPublicKey, DescriptorSecretKey); + type IntoIter = btree_map::IntoIter; + + #[inline] + fn into_iter(self) -> Self::IntoIter { self.map.into_iter() } +} + +impl iter::Extend<(DescriptorPublicKey, DescriptorSecretKey)> for KeyMap { + #[inline] + fn extend(&mut self, iter: T) + where + T: IntoIterator, + { + self.map.extend(iter) + } +} + +impl GetKey for KeyMap { + type Error = GetKeyError; + + fn get_key( + &self, + key_request: KeyRequest, + secp: &Secp256k1, + ) -> Result, Self::Error> { + Ok(self.map.iter().find_map(|(k, v)| { + match k { + DescriptorPublicKey::Single(ref pk) => match key_request { + KeyRequest::Pubkey(ref request) => match pk.key { + SinglePubKey::FullKey(ref pk) => { + if pk == request { + match v { + DescriptorSecretKey::Single(ref sk) => Some(sk.key), + _ => unreachable!("Single maps to Single"), + } + } else { + None + } + } + SinglePubKey::XOnly(_) => None, + }, + _ => None, + }, + // Performance: Might be faster to check the origin and then if it matches return + // the key directly instead of calling `get_key` on the xpriv. + DescriptorPublicKey::XPub(ref xpub) => { + let pk = xpub.xkey.public_key; + match key_request { + KeyRequest::Pubkey(ref request) => { + if pk == request.inner { + match v { + DescriptorSecretKey::XPrv(xpriv) => { + let xkey = xpriv.xkey; + if let Ok(child) = + xkey.derive_priv(secp, &xpriv.derivation_path) + { + Some(bitcoin::PrivateKey::new( + child.private_key, + xkey.network, + )) + } else { + None + } + } + _ => unreachable!("XPrv maps to XPrv"), + } + } else { + None + } + } + KeyRequest::Bip32(..) => match v { + DescriptorSecretKey::XPrv(xpriv) => { + // This clone goes away in next release of rust-bitcoin. + if let Ok(Some(sk)) = xpriv.xkey.get_key(key_request.clone(), secp) + { + Some(sk) + } else { + None + } + } + _ => unreachable!("XPrv maps to XPrv"), + }, + _ => unreachable!("rust-bitcoin v0.32"), + } + } + DescriptorPublicKey::MultiXPub(ref xpub) => { + let pk = xpub.xkey.public_key; + match key_request { + KeyRequest::Pubkey(ref request) => { + if pk == request.inner { + match v { + DescriptorSecretKey::MultiXPrv(xpriv) => { + Some(xpriv.xkey.to_priv()) + } + _ => unreachable!("MultiXPrv maps to MultiXPrv"), + } + } else { + None + } + } + KeyRequest::Bip32(..) => match v { + DescriptorSecretKey::MultiXPrv(xpriv) => { + // These clones goes away in next release of rust-bitcoin. + if let Ok(Some(sk)) = xpriv.xkey.get_key(key_request.clone(), secp) + { + Some(sk) + } else { + None + } + } + _ => unreachable!("MultiXPrv maps to MultiXPrv"), + }, + _ => unreachable!("rust-bitcoin v0.32"), + } + } + } + })) + } +} + +#[cfg(test)] +mod tests { + // use bitcoin::NetworkKind; + use bitcoin::bip32::{ChildNumber, IntoDerivationPath, Xpriv}; + + use super::*; + use crate::Descriptor; + + #[test] + fn get_key_single_key() { + let secp = Secp256k1::new(); + + let descriptor_sk_s = + "[90b6a706/44'/0'/0'/0/0]cMk8gWmj1KpjdYnAWwsEDekodMYhbyYBhG8gMtCCxucJ98JzcNij"; + + let single = match descriptor_sk_s.parse::().unwrap() { + DescriptorSecretKey::Single(single) => single, + _ => panic!("unexpected DescriptorSecretKey variant"), + }; + + let want_sk = single.key; + let descriptor_s = format!("wpkh({})", descriptor_sk_s); + let (_, keymap) = Descriptor::parse_descriptor(&secp, &descriptor_s).unwrap(); + + let pk = want_sk.public_key(&secp); + let request = KeyRequest::Pubkey(pk); + let got_sk = keymap + .get_key(request, &secp) + .expect("get_key call errored") + .expect("failed to find the key"); + assert_eq!(got_sk, want_sk) + } + + #[test] + fn get_key_xpriv_single_key_xpriv() { + let secp = Secp256k1::new(); + + let s = "xprv9s21ZrQH143K3QTDL4LXw2F7HEK3wJUD2nW2nRk4stbPy6cq3jPPqjiChkVvvNKmPGJxWUtg6LnF5kejMRNNU3TGtRBeJgk33yuGBxrMPHi"; + + let xpriv = s.parse::().unwrap(); + let xpriv_fingerprint = xpriv.fingerprint(&secp); + + // Sanity check. + { + let descriptor_sk_s = format!("[{}]{}", xpriv_fingerprint, xpriv); + let descriptor_sk = descriptor_sk_s.parse::().unwrap(); + let got = match descriptor_sk { + DescriptorSecretKey::XPrv(x) => x.xkey, + _ => panic!("unexpected DescriptorSecretKey variant"), + }; + assert_eq!(got, xpriv); + } + + let want_sk = xpriv.to_priv(); + let descriptor_s = format!("wpkh([{}]{})", xpriv_fingerprint, xpriv); + let (_, keymap) = Descriptor::parse_descriptor(&secp, &descriptor_s).unwrap(); + + let pk = want_sk.public_key(&secp); + let request = KeyRequest::Pubkey(pk); + let got_sk = keymap + .get_key(request, &secp) + .expect("get_key call errored") + .expect("failed to find the key"); + assert_eq!(got_sk, want_sk) + } + + #[test] + fn get_key_xpriv_child_depth_one() { + let secp = Secp256k1::new(); + + let s = "xprv9s21ZrQH143K3QTDL4LXw2F7HEK3wJUD2nW2nRk4stbPy6cq3jPPqjiChkVvvNKmPGJxWUtg6LnF5kejMRNNU3TGtRBeJgk33yuGBxrMPHi"; + let master = s.parse::().unwrap(); + let master_fingerprint = master.fingerprint(&secp); + + let child_number = ChildNumber::from_hardened_idx(44).unwrap(); + let child = master.derive_priv(&secp, &[child_number]).unwrap(); + + // Sanity check. + { + let descriptor_sk_s = format!("[{}/44']{}", master_fingerprint, child); + let descriptor_sk = descriptor_sk_s.parse::().unwrap(); + let got = match descriptor_sk { + DescriptorSecretKey::XPrv(ref x) => x.xkey, + _ => panic!("unexpected DescriptorSecretKey variant"), + }; + assert_eq!(got, child); + } + + let want_sk = child.to_priv(); + let descriptor_s = format!("wpkh({}/44')", s); + let (_, keymap) = Descriptor::parse_descriptor(&secp, &descriptor_s).unwrap(); + + let pk = want_sk.public_key(&secp); + let request = KeyRequest::Pubkey(pk); + let got_sk = keymap + .get_key(request, &secp) + .expect("get_key call errored") + .expect("failed to find the key"); + assert_eq!(got_sk, want_sk) + } + + #[test] + fn get_key_xpriv_with_path() { + let secp = Secp256k1::new(); + + let s = "xprv9s21ZrQH143K3QTDL4LXw2F7HEK3wJUD2nW2nRk4stbPy6cq3jPPqjiChkVvvNKmPGJxWUtg6LnF5kejMRNNU3TGtRBeJgk33yuGBxrMPHi"; + let master = s.parse::().unwrap(); + let master_fingerprint = master.fingerprint(&secp); + + let first_external_child = "44'/0'/0'/0/0"; + let derivation_path = first_external_child.into_derivation_path().unwrap(); + + let child = master.derive_priv(&secp, &derivation_path).unwrap(); + + // Sanity check. + { + let descriptor_sk_s = + format!("[{}/{}]{}", master_fingerprint, first_external_child, child); + let descriptor_sk = descriptor_sk_s.parse::().unwrap(); + let got = match descriptor_sk { + DescriptorSecretKey::XPrv(ref x) => x.xkey, + _ => panic!("unexpected DescriptorSecretKey variant"), + }; + assert_eq!(got, child); + } + + let want_sk = child.to_priv(); + let descriptor_s = format!("wpkh({}/44'/0'/0'/0/*)", s); + let (_, keymap) = Descriptor::parse_descriptor(&secp, &descriptor_s).unwrap(); + + let key_source = (master_fingerprint, derivation_path); + let request = KeyRequest::Bip32(key_source); + let got_sk = keymap + .get_key(request, &secp) + .expect("get_key call errored") + .expect("failed to find the key"); + + assert_eq!(got_sk, want_sk) + } +} diff --git a/src/descriptor/mod.rs b/src/descriptor/mod.rs index ce7654958..6afb7d3f1 100644 --- a/src/descriptor/mod.rs +++ b/src/descriptor/mod.rs @@ -46,20 +46,14 @@ pub use self::tr::{TapTree, Tr}; pub mod checksum; mod key; +mod key_map; pub use self::key::{ ConversionError, DefiniteDescriptorKey, DerivPaths, DescriptorKeyParseError, DescriptorMultiXKey, DescriptorPublicKey, DescriptorSecretKey, DescriptorXKey, InnerXKey, SinglePriv, SinglePub, SinglePubKey, Wildcard, }; - -/// Alias type for a map of public key to secret key -/// -/// This map is returned whenever a descriptor that contains secrets is parsed using -/// [`Descriptor::parse_descriptor`], since the descriptor will always only contain -/// public keys. This map allows looking up the corresponding secret key given a -/// public key from the descriptor. -pub type KeyMap = BTreeMap; +pub use self::key_map::KeyMap; /// Script descriptor #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -708,28 +702,24 @@ impl Descriptor { key_map: &mut KeyMap, secp: &secp256k1::Secp256k1, ) -> Result { - let (public_key, secret_key) = match DescriptorSecretKey::from_str(s) { - Ok(sk) => ( - sk.to_public(secp) - .map_err(|e| Error::Unexpected(e.to_string()))?, - Some(sk), - ), - Err(_) => ( + match DescriptorSecretKey::from_str(s) { + Ok(sk) => { + let pk = key_map + .insert(secp, sk) + .map_err(|e| Error::Unexpected(e.to_string()))?; + Ok(pk) + } + Err(_) => { // try to parse as a public key if parsing as a secret key failed - s.parse() - .map_err(|e| Error::Parse(ParseError::box_from_str(e)))?, - None, - ), - }; - - if let Some(secret_key) = secret_key { - key_map.insert(public_key.clone(), secret_key); + let pk = s + .parse() + .map_err(|e| Error::Parse(ParseError::box_from_str(e)))?; + Ok(pk) + } } - - Ok(public_key) } - let mut keymap_pk = KeyMapWrapper(BTreeMap::new(), secp); + let mut keymap_pk = KeyMapWrapper(KeyMap::new(), secp); struct KeyMapWrapper<'a, C: secp256k1::Signing>(KeyMap, &'a secp256k1::Secp256k1); diff --git a/src/lib.rs b/src/lib.rs index 8115afe13..359e37bfa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -752,7 +752,7 @@ mod prelude { pub use alloc::{ borrow::{Borrow, Cow, ToOwned}, boxed::Box, - collections::{vec_deque::VecDeque, BTreeMap, BTreeSet, BinaryHeap}, + collections::{btree_map, vec_deque::VecDeque, BTreeMap, BTreeSet, BinaryHeap}, rc, slice, string::{String, ToString}, sync, @@ -762,7 +762,9 @@ mod prelude { pub use std::{ borrow::{Borrow, Cow, ToOwned}, boxed::Box, - collections::{vec_deque::VecDeque, BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet}, + collections::{ + btree_map, vec_deque::VecDeque, BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, + }, rc, slice, string::{String, ToString}, sync,