diff --git a/lightning/src/offers/merkle.rs b/lightning/src/offers/merkle.rs index 1a38fe5441f..0de1e2170f2 100644 --- a/lightning/src/offers/merkle.rs +++ b/lightning/src/offers/merkle.rs @@ -280,6 +280,497 @@ impl<'a> Writeable for TlvRecord<'a> { } } +// ============================================================================ +// Selective Disclosure for Payer Proofs (BOLT 12 extension) +// ============================================================================ + +use alloc::collections::BTreeSet; + +/// Error during selective disclosure operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SelectiveDisclosureError { + /// The omitted_tlvs markers are not in strict ascending order. + InvalidOmittedTlvsOrder, + /// The omitted_tlvs contains an invalid marker (0 or signature type). + InvalidOmittedTlvsMarker, + /// The leaf_hashes count doesn't match included TLVs. + LeafHashCountMismatch, + /// Insufficient missing_hashes to reconstruct the tree. + InsufficientMissingHashes, + /// Excess missing_hashes after reconstruction. + ExcessMissingHashes, + /// The TLV stream is empty. + EmptyTlvStream, +} + +/// Data needed to reconstruct a merkle root with selective disclosure. +/// +/// This is used in payer proofs to allow verification of an invoice signature +/// without revealing all invoice fields. +#[derive(Clone, Debug, PartialEq)] +pub struct SelectiveDisclosure { + /// Nonce hashes for included TLVs (in TLV type order). + pub leaf_hashes: Vec, + /// Marker numbers for omitted TLVs (excluding implicit TLV0). + pub omitted_tlvs: Vec, + /// Minimal merkle hashes for omitted subtrees. + pub missing_hashes: Vec, + /// The complete merkle root. + pub merkle_root: sha256::Hash, +} + +/// Internal data for each TLV during tree construction. +struct TlvMerkleData { + tlv_type: u64, + per_tlv_hash: sha256::Hash, + nonce_hash: sha256::Hash, + is_included: bool, +} + +/// Compute selective disclosure data from a TLV stream. +/// +/// This builds the full merkle tree and extracts the data needed for a payer proof: +/// - `leaf_hashes`: nonce hashes for included TLVs +/// - `omitted_tlvs`: marker numbers for omitted TLVs +/// - `missing_hashes`: minimal merkle hashes for omitted subtrees +/// +/// # Arguments +/// * `tlv_bytes` - Complete TLV stream (e.g., invoice bytes without signature) +/// * `included_types` - Set of TLV types to include in the disclosure +pub(super) fn compute_selective_disclosure( + tlv_bytes: &[u8], included_types: &BTreeSet, +) -> Result { + let mut tlv_stream = TlvStream::new(tlv_bytes).peekable(); + let first_record = tlv_stream.peek().ok_or(SelectiveDisclosureError::EmptyTlvStream)?; + let nonce_tag_hash = sha256::Hash::from_engine({ + let mut engine = sha256::Hash::engine(); + engine.input("LnNonce".as_bytes()); + engine.input(first_record.record_bytes); + engine + }); + + let leaf_tag = tagged_hash_engine(sha256::Hash::hash("LnLeaf".as_bytes())); + let nonce_tag = tagged_hash_engine(nonce_tag_hash); + let branch_tag = tagged_hash_engine(sha256::Hash::hash("LnBranch".as_bytes())); + + let mut tlv_data: Vec = Vec::new(); + for record in tlv_stream.filter(|r| !SIGNATURE_TYPES.contains(&r.r#type)) { + let leaf_hash = tagged_hash_from_engine(leaf_tag.clone(), record.record_bytes); + let nonce_hash = tagged_hash_from_engine(nonce_tag.clone(), record.type_bytes); + let per_tlv_hash = + tagged_branch_hash_from_engine(branch_tag.clone(), leaf_hash, nonce_hash); + + let is_included = included_types.contains(&record.r#type); + tlv_data.push(TlvMerkleData { + tlv_type: record.r#type, + per_tlv_hash, + nonce_hash, + is_included, + }); + } + + if tlv_data.is_empty() { + return Err(SelectiveDisclosureError::EmptyTlvStream); + } + + let leaf_hashes: Vec<_> = + tlv_data.iter().filter(|d| d.is_included).map(|d| d.nonce_hash).collect(); + let omitted_tlvs = compute_omitted_markers(&tlv_data); + let (merkle_root, missing_hashes) = build_tree_with_disclosure(&tlv_data, &branch_tag); + + Ok(SelectiveDisclosure { leaf_hashes, omitted_tlvs, missing_hashes, merkle_root }) +} + +/// Compute omitted_tlvs marker numbers per BOLT 12 payer proof spec. +fn compute_omitted_markers(tlv_data: &[TlvMerkleData]) -> Vec { + let mut markers = Vec::new(); + let mut prev_included_type: Option = None; + let mut prev_marker: Option = None; + + for data in tlv_data { + if data.tlv_type == 0 { + continue; + } + + if !data.is_included { + let marker = if let Some(prev_type) = prev_included_type { + prev_type + 1 + } else if let Some(last_marker) = prev_marker { + last_marker + 1 + } else { + 1 + }; + + markers.push(marker); + prev_marker = Some(marker); + prev_included_type = None; + } else { + prev_included_type = Some(data.tlv_type); + prev_marker = None; + } + } + + markers +} + +/// Build merkle tree and collect missing_hashes for omitted subtrees. +/// +/// Returns hashes sorted by ascending TLV type as required by the spec. For internal +/// nodes, the type used for ordering is the minimum TLV type in that subtree. +fn build_tree_with_disclosure( + tlv_data: &[TlvMerkleData], branch_tag: &sha256::HashEngine, +) -> (sha256::Hash, Vec) { + let n = tlv_data.len(); + debug_assert!(n > 0, "TLV stream must contain at least one record"); + + let num_leaves = n * 2; + let mut hashes: Vec> = vec![None; num_leaves]; + let mut is_included: Vec = vec![false; num_leaves]; + let mut min_types: Vec = vec![u64::MAX; num_leaves]; + + for (i, data) in tlv_data.iter().enumerate() { + let pos = i * 2; + hashes[pos] = Some(data.per_tlv_hash); + is_included[pos] = data.is_included; + min_types[pos] = data.tlv_type; + } + + let mut missing_with_types: Vec<(u64, sha256::Hash)> = Vec::new(); + + for level in 1.. { + let step = 2 << level; + let offset = step / 2; + if offset >= num_leaves { + break; + } + + let left_positions: Vec<_> = (0..num_leaves).step_by(step).collect(); + let right_positions: Vec<_> = (offset..num_leaves).step_by(step).collect(); + + for (&left_pos, &right_pos) in left_positions.iter().zip(right_positions.iter()) { + let left_hash = hashes[left_pos]; + let right_hash = hashes[right_pos]; + let left_incl = is_included[left_pos]; + let right_incl = is_included[right_pos]; + let left_min_type = min_types[left_pos]; + let right_min_type = min_types[right_pos]; + + match (left_hash, right_hash, left_incl, right_incl) { + (Some(l), Some(r), true, false) => { + missing_with_types.push((right_min_type, r)); + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, r)); + is_included[left_pos] = true; + min_types[left_pos] = core::cmp::min(left_min_type, right_min_type); + }, + (Some(l), Some(r), false, true) => { + missing_with_types.push((left_min_type, l)); + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, r)); + is_included[left_pos] = true; + min_types[left_pos] = core::cmp::min(left_min_type, right_min_type); + }, + (Some(l), Some(r), true, true) => { + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, r)); + is_included[left_pos] = true; + min_types[left_pos] = core::cmp::min(left_min_type, right_min_type); + }, + (Some(l), Some(r), false, false) => { + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, r)); + is_included[left_pos] = false; + min_types[left_pos] = core::cmp::min(left_min_type, right_min_type); + }, + (Some(l), None, incl, _) => { + hashes[left_pos] = Some(l); + is_included[left_pos] = incl; + }, + _ => unreachable!("Invalid state in merkle tree construction"), + } + } + } + + missing_with_types.sort_by_key(|(min_type, _)| *min_type); + let missing_hashes: Vec = + missing_with_types.into_iter().map(|(_, h)| h).collect(); + + (hashes[0].expect("Tree should have a root"), missing_hashes) +} + +/// Reconstruct merkle root from selective disclosure data. +/// +/// The `missing_hashes` must be in ascending type order per spec. +pub(super) fn reconstruct_merkle_root<'a>( + included_records: &[(u64, &'a [u8])], leaf_hashes: &[sha256::Hash], omitted_tlvs: &[u64], + missing_hashes: &[sha256::Hash], +) -> Result { + validate_omitted_tlvs(omitted_tlvs)?; + + if included_records.len() != leaf_hashes.len() { + return Err(SelectiveDisclosureError::LeafHashCountMismatch); + } + + let included_types: Vec = included_records.iter().map(|(t, _)| *t).collect(); + let positions = reconstruct_positions(&included_types, omitted_tlvs); + + let total_tlvs = positions.len(); + let num_leaves = total_tlvs * 2; + + let leaf_tag = tagged_hash_engine(sha256::Hash::hash("LnLeaf".as_bytes())); + let branch_tag = tagged_hash_engine(sha256::Hash::hash("LnBranch".as_bytes())); + + let mut hashes: Vec> = vec![None; num_leaves]; + let mut is_included: Vec = vec![false; num_leaves]; + let mut min_positions: Vec = (0..num_leaves).collect(); + + let mut leaf_hash_idx = 0; + for (i, &incl) in positions.iter().enumerate() { + let pos = i * 2; + is_included[pos] = incl; + min_positions[pos] = i; + + if incl { + let (_, record_bytes) = included_records[leaf_hash_idx]; + let leaf_hash = tagged_hash_from_engine(leaf_tag.clone(), record_bytes); + let nonce_hash = leaf_hashes[leaf_hash_idx]; + let per_tlv = tagged_branch_hash_from_engine(branch_tag.clone(), leaf_hash, nonce_hash); + hashes[pos] = Some(per_tlv); + leaf_hash_idx += 1; + } + } + + // First pass: identify positions needing missing hashes + let mut needs_hash: Vec<(usize, usize)> = Vec::new(); + { + let mut temp_hashes: Vec> = vec![None; num_leaves]; + let mut temp_included: Vec = is_included.clone(); + let mut temp_min_pos: Vec = min_positions.clone(); + + for (i, &incl) in positions.iter().enumerate() { + let pos = i * 2; + if incl { + temp_hashes[pos] = Some(()); + } + } + + for level in 1.. { + let step = 2 << level; + let offset = step / 2; + if offset >= num_leaves { + break; + } + + for left_pos in (0..num_leaves).step_by(step) { + let right_pos = left_pos + offset; + if right_pos >= num_leaves { + continue; + } + + let left_hash = temp_hashes[left_pos]; + let right_hash = temp_hashes[right_pos]; + let left_incl = temp_included[left_pos]; + let right_incl = temp_included[right_pos]; + let left_min = temp_min_pos[left_pos]; + let right_min = temp_min_pos[right_pos]; + + match (left_hash, right_hash, left_incl, right_incl) { + (Some(_), None, true, false) => { + // Need hash for right subtree, keyed by right's min_position + needs_hash.push((right_min, right_pos)); + temp_hashes[left_pos] = Some(()); + temp_included[left_pos] = true; + temp_min_pos[left_pos] = core::cmp::min(left_min, right_min); + }, + (None, Some(_), false, true) => { + // Need hash for left subtree, keyed by left's min_position + needs_hash.push((left_min, left_pos)); + temp_hashes[left_pos] = Some(()); + temp_included[left_pos] = true; + temp_min_pos[left_pos] = core::cmp::min(left_min, right_min); + }, + (Some(_), Some(_), _, _) => { + temp_hashes[left_pos] = Some(()); + temp_included[left_pos] = true; + temp_min_pos[left_pos] = core::cmp::min(left_min, right_min); + }, + (Some(_), None, false, false) => { + // Odd node propagation + }, + (None, None, false, false) => { + temp_min_pos[left_pos] = core::cmp::min(left_min, right_min); + }, + _ => {}, + } + } + } + } + + needs_hash.sort_by_key(|(min_pos, _)| *min_pos); + + if needs_hash.len() != missing_hashes.len() { + return Err(SelectiveDisclosureError::InsufficientMissingHashes); + } + + let mut hash_map: alloc::collections::BTreeMap = + alloc::collections::BTreeMap::new(); + for (i, (_, tree_pos)) in needs_hash.iter().enumerate() { + hash_map.insert(*tree_pos, missing_hashes[i]); + } + + // Second pass: reconstruction + for level in 1.. { + let step = 2 << level; + let offset = step / 2; + if offset >= num_leaves { + break; + } + + for left_pos in (0..num_leaves).step_by(step) { + let right_pos = left_pos + offset; + if right_pos >= num_leaves { + continue; + } + + let left_hash = hashes[left_pos]; + let right_hash = hashes[right_pos]; + let left_incl = is_included[left_pos]; + let right_incl = is_included[right_pos]; + let left_min = min_positions[left_pos]; + let right_min = min_positions[right_pos]; + + match (left_hash, right_hash, left_incl, right_incl) { + (Some(l), None, true, false) => { + let r = hash_map + .get(&right_pos) + .ok_or(SelectiveDisclosureError::InsufficientMissingHashes)?; + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, *r)); + is_included[left_pos] = true; + min_positions[left_pos] = core::cmp::min(left_min, right_min); + }, + (None, Some(r), false, true) => { + let l = hash_map + .get(&left_pos) + .ok_or(SelectiveDisclosureError::InsufficientMissingHashes)?; + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), *l, r)); + is_included[left_pos] = true; + min_positions[left_pos] = core::cmp::min(left_min, right_min); + }, + (Some(l), Some(r), _, _) => { + hashes[left_pos] = + Some(tagged_branch_hash_from_engine(branch_tag.clone(), l, r)); + is_included[left_pos] = true; + min_positions[left_pos] = core::cmp::min(left_min, right_min); + }, + (Some(l), None, false, false) => { + hashes[left_pos] = Some(l); + }, + (None, None, false, false) => { + min_positions[left_pos] = core::cmp::min(left_min, right_min); + }, + _ => { + return Err(SelectiveDisclosureError::InsufficientMissingHashes); + }, + }; + } + } + + hashes[0].ok_or(SelectiveDisclosureError::InsufficientMissingHashes) +} + +fn validate_omitted_tlvs(markers: &[u64]) -> Result<(), SelectiveDisclosureError> { + let mut prev = 0u64; + for &marker in markers { + if marker == 0 { + return Err(SelectiveDisclosureError::InvalidOmittedTlvsMarker); + } + if SIGNATURE_TYPES.contains(&marker) { + return Err(SelectiveDisclosureError::InvalidOmittedTlvsMarker); + } + if marker <= prev { + return Err(SelectiveDisclosureError::InvalidOmittedTlvsOrder); + } + prev = marker; + } + Ok(()) +} + +/// Reconstruct position inclusion map from included types and omitted markers. +/// +/// This reverses the marker encoding algorithm from `compute_omitted_markers`: +/// - Markers form "runs" of consecutive values (e.g., [11, 12] is a run) +/// - A "jump" in markers (e.g., 12 → 41) indicates an included TLV came between +/// - After included type X, the next marker in that run equals X + 1 +/// +/// The algorithm tracks `prev_marker` to detect continuations vs jumps: +/// - If `marker == prev_marker + 1`: continuation → omitted position +/// - Otherwise: jump → included position comes first, then process marker as continuation +/// +/// Example: included=[10, 40], markers=[11, 12, 41, 42] +/// - Position 0: TLV0 (always omitted) +/// - marker=11, prev=0: 11 != 1, jump! Insert included (10), prev=10 +/// - marker=11, prev=10: 11 == 11, continuation → omitted, prev=11 +/// - marker=12, prev=11: 12 == 12, continuation → omitted, prev=12 +/// - marker=41, prev=12: 41 != 13, jump! Insert included (40), prev=40 +/// - marker=41, prev=40: 41 == 41, continuation → omitted, prev=41 +/// - marker=42, prev=41: 42 == 42, continuation → omitted, prev=42 +/// Result: [O, I, O, O, I, O, O] +fn reconstruct_positions(included_types: &[u64], omitted_markers: &[u64]) -> Vec { + let total = 1 + included_types.len() + omitted_markers.len(); + let mut positions = Vec::with_capacity(total); + positions.push(false); // TLV0 is always omitted + + let mut inc_idx = 0; + let mut mrk_idx = 0; + // After TLV0 (implicit marker 0), next continuation would be marker 1 + let mut prev_marker: u64 = 0; + + while inc_idx < included_types.len() || mrk_idx < omitted_markers.len() { + if mrk_idx >= omitted_markers.len() { + // No more markers, remaining positions are included + positions.push(true); + inc_idx += 1; + } else if inc_idx >= included_types.len() { + // No more included types, remaining positions are omitted + positions.push(false); + prev_marker = omitted_markers[mrk_idx]; + mrk_idx += 1; + } else { + let marker = omitted_markers[mrk_idx]; + let inc_type = included_types[inc_idx]; + + if marker == prev_marker + 1 { + // Continuation of current run → this position is omitted + positions.push(false); + prev_marker = marker; + mrk_idx += 1; + } else { + // Jump detected! An included TLV comes before this marker. + // After the included type, prev_marker resets to that type, + // so the marker will be processed as a continuation next iteration. + positions.push(true); + prev_marker = inc_type; + inc_idx += 1; + // Don't advance mrk_idx - same marker will be continuation next + } + } + } + + positions +} + +/// Creates a TaggedHash directly from a merkle root (for payer proof verification). +impl TaggedHash { + /// Creates a tagged hash from a pre-computed merkle root. + pub(super) fn from_merkle_root(tag: &'static str, merkle_root: sha256::Hash) -> Self { + let tag_hash = sha256::Hash::hash(tag.as_bytes()); + let digest = Message::from_digest(tagged_hash(tag_hash, merkle_root).to_byte_array()); + Self { tag, merkle_root, digest } + } +} + #[cfg(test)] mod tests { use super::{TlvStream, SIGNATURE_TYPES}; @@ -497,4 +988,205 @@ mod tests { self.fmt_bech32_str(f) } } + + // ============================================================================ + // Tests for selective disclosure / payer proof reconstruction + // ============================================================================ + + /// Test reconstruct_positions with the BOLT 12 payer proof spec example. + /// + /// TLVs: 0(omit), 10(incl), 20(omit), 30(omit), 40(incl), 50(omit), 60(omit) + /// Markers: [11, 12, 41, 42] + /// Expected positions: [O, I, O, O, I, O, O] + #[test] + fn test_reconstruct_positions_spec_example() { + let included_types = vec![10, 40]; + let markers = vec![11, 12, 41, 42]; + let positions = super::reconstruct_positions(&included_types, &markers); + assert_eq!(positions, vec![false, true, false, false, true, false, false]); + } + + /// Test reconstruct_positions when there are omitted TLVs before the first included. + /// + /// TLVs: 0(omit), 5(omit), 10(incl), 20(omit) + /// Markers: [1, 11] (1 is first omitted after TLV0, 11 is after included 10) + /// Expected positions: [O, O, I, O] + #[test] + fn test_reconstruct_positions_omitted_before_included() { + let included_types = vec![10]; + let markers = vec![1, 11]; + let positions = super::reconstruct_positions(&included_types, &markers); + assert_eq!(positions, vec![false, false, true, false]); + } + + /// Test reconstruct_positions with only included TLVs (no omitted except TLV0). + /// + /// TLVs: 0(omit), 10(incl), 20(incl) + /// Markers: [] (no omitted TLVs after TLV0) + /// Expected positions: [O, I, I] + #[test] + fn test_reconstruct_positions_no_omitted() { + let included_types = vec![10, 20]; + let markers = vec![]; + let positions = super::reconstruct_positions(&included_types, &markers); + assert_eq!(positions, vec![false, true, true]); + } + + /// Test reconstruct_positions with only omitted TLVs (no included). + /// + /// TLVs: 0(omit), 5(omit), 10(omit) + /// Markers: [1, 2] (consecutive omitted after TLV0) + /// Expected positions: [O, O, O] + #[test] + fn test_reconstruct_positions_no_included() { + let included_types = vec![]; + let markers = vec![1, 2]; + let positions = super::reconstruct_positions(&included_types, &markers); + assert_eq!(positions, vec![false, false, false]); + } + + /// Test round-trip: compute selective disclosure then reconstruct merkle root. + #[test] + fn test_selective_disclosure_round_trip() { + use alloc::collections::BTreeSet; + + // Build TLV stream matching spec example structure + // TLVs: 0, 10, 20, 30, 40, 50, 60 + let mut tlv_bytes = Vec::new(); + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // TLV 0 + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); // TLV 10 + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); // TLV 20 + tlv_bytes.extend_from_slice(&[0x1e, 0x02, 0x00, 0x00]); // TLV 30 + tlv_bytes.extend_from_slice(&[0x28, 0x02, 0x00, 0x00]); // TLV 40 + tlv_bytes.extend_from_slice(&[0x32, 0x02, 0x00, 0x00]); // TLV 50 + tlv_bytes.extend_from_slice(&[0x3c, 0x02, 0x00, 0x00]); // TLV 60 + + // Include types 10 and 40 + let mut included = BTreeSet::new(); + included.insert(10); + included.insert(40); + + // Compute selective disclosure + let disclosure = super::compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + // Verify markers match spec example + assert_eq!(disclosure.omitted_tlvs, vec![11, 12, 41, 42]); + + // Verify leaf_hashes count matches included TLVs + assert_eq!(disclosure.leaf_hashes.len(), 2); + + // Collect included records for reconstruction + let included_records: Vec<(u64, &[u8])> = TlvStream::new(&tlv_bytes) + .filter(|r| included.contains(&r.r#type)) + .map(|r| (r.r#type, r.record_bytes)) + .collect(); + + // Reconstruct merkle root + let reconstructed = super::reconstruct_merkle_root( + &included_records, + &disclosure.leaf_hashes, + &disclosure.omitted_tlvs, + &disclosure.missing_hashes, + ) + .unwrap(); + + // Must match original + assert_eq!(reconstructed, disclosure.merkle_root); + } + + /// Test that missing_hashes are in ascending type order per spec. + /// + /// Per spec: "MUST include the minimal set of merkle hashes of missing merkle + /// leaves or nodes in `missing_hashes`, in ascending type order." + /// + /// For the spec example with TLVs [0(o), 10(I), 20(o), 30(o), 40(I), 50(o), 60(o)]: + /// - hash(0) covers type 0 + /// - hash(B(20,30)) covers types 20-30 (min=20) + /// - hash(50) covers type 50 + /// - hash(60) covers type 60 + /// + /// Expected order: [type 0, type 20, type 50, type 60] + /// This means 4 missing_hashes in this order. + #[test] + fn test_missing_hashes_ascending_type_order() { + use alloc::collections::BTreeSet; + + // Build TLV stream: 0, 10, 20, 30, 40, 50, 60 + let mut tlv_bytes = Vec::new(); + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // TLV 0 + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); // TLV 10 + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); // TLV 20 + tlv_bytes.extend_from_slice(&[0x1e, 0x02, 0x00, 0x00]); // TLV 30 + tlv_bytes.extend_from_slice(&[0x28, 0x02, 0x00, 0x00]); // TLV 40 + tlv_bytes.extend_from_slice(&[0x32, 0x02, 0x00, 0x00]); // TLV 50 + tlv_bytes.extend_from_slice(&[0x3c, 0x02, 0x00, 0x00]); // TLV 60 + + // Include types 10 and 40 (same as spec example) + let mut included = BTreeSet::new(); + included.insert(10); + included.insert(40); + + let disclosure = super::compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + // We should have 4 missing hashes for omitted types: + // - type 0 (single leaf) + // - types 20+30 (combined branch, min_type=20) + // - type 50 (single leaf) + // - type 60 (single leaf) + // + // The spec example only shows 3, but that appears to be incomplete + // (missing hash for type 60). Our implementation should produce 4. + assert_eq!( + disclosure.missing_hashes.len(), + 4, + "Expected 4 missing hashes for omitted types [0, 20+30, 50, 60]" + ); + + // Verify the round-trip still works with the correct ordering + let included_records: Vec<(u64, &[u8])> = TlvStream::new(&tlv_bytes) + .filter(|r| included.contains(&r.r#type)) + .map(|r| (r.r#type, r.record_bytes)) + .collect(); + + let reconstructed = super::reconstruct_merkle_root( + &included_records, + &disclosure.leaf_hashes, + &disclosure.omitted_tlvs, + &disclosure.missing_hashes, + ) + .unwrap(); + + assert_eq!(reconstructed, disclosure.merkle_root); + } + + /// Test that reconstruction fails with wrong number of missing_hashes. + #[test] + fn test_reconstruction_fails_with_wrong_missing_hashes() { + use alloc::collections::BTreeSet; + + let mut tlv_bytes = Vec::new(); + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // TLV 0 + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); // TLV 10 + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); // TLV 20 + + let mut included = BTreeSet::new(); + included.insert(10); + + let disclosure = super::compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + let included_records: Vec<(u64, &[u8])> = TlvStream::new(&tlv_bytes) + .filter(|r| included.contains(&r.r#type)) + .map(|r| (r.r#type, r.record_bytes)) + .collect(); + + // Try with empty missing_hashes (should fail) + let result = super::reconstruct_merkle_root( + &included_records, + &disclosure.leaf_hashes, + &disclosure.omitted_tlvs, + &[], // Wrong! + ); + + assert!(result.is_err()); + } } diff --git a/lightning/src/offers/mod.rs b/lightning/src/offers/mod.rs index 5b5cf6cdc78..bbbf91a1f1c 100644 --- a/lightning/src/offers/mod.rs +++ b/lightning/src/offers/mod.rs @@ -25,6 +25,7 @@ pub mod merkle; pub mod nonce; pub mod parse; mod payer; +pub mod payer_proof; pub mod refund; pub(crate) mod signer; pub mod static_invoice; diff --git a/lightning/src/offers/payer_proof.rs b/lightning/src/offers/payer_proof.rs new file mode 100644 index 00000000000..d5115ac6c70 --- /dev/null +++ b/lightning/src/offers/payer_proof.rs @@ -0,0 +1,987 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! Payer proofs for BOLT 12 invoices. +//! +//! A [`PayerProof`] cryptographically proves that a BOLT 12 invoice was paid by demonstrating: +//! - Possession of the payment preimage (proving the payment occurred) +//! - A valid invoice signature over a merkle root (proving the invoice is authentic) +//! - The payer's signature (proving who authorized the payment) +//! +//! This implements the payer proof extension to BOLT 12 as specified in +//! . + +use alloc::collections::BTreeSet; + +use crate::io; +use crate::io::Read; +use crate::offers::invoice::{Bolt12Invoice, SIGNATURE_TAG}; +use crate::offers::merkle::{ + self, SelectiveDisclosure, SelectiveDisclosureError, TaggedHash, TlvStream, SIGNATURE_TYPES, +}; +use crate::offers::parse::Bech32Encode; +use crate::types::features::Bolt12InvoiceFeatures; +use crate::types::payment::{PaymentHash, PaymentPreimage}; +use crate::util::ser::{BigSize, Readable, Writeable}; + +use bitcoin::hashes::{sha256, Hash, HashEngine}; +use bitcoin::secp256k1::schnorr::Signature; +use bitcoin::secp256k1::{Message, PublicKey, Secp256k1}; + +use core::convert::TryFrom; +use core::time::Duration; + +#[allow(unused_imports)] +use crate::prelude::*; + +const TLV_SIGNATURE: u64 = 240; +const TLV_PREIMAGE: u64 = 242; +const TLV_OMITTED_TLVS: u64 = 244; +const TLV_MISSING_HASHES: u64 = 246; +const TLV_LEAF_HASHES: u64 = 248; +const TLV_PAYER_SIGNATURE: u64 = 250; + +const TLV_INVREQ_METADATA: u64 = 0; +const TLV_INVREQ_PAYER_ID: u64 = 88; +const TLV_INVOICE_PAYMENT_HASH: u64 = 168; +const TLV_INVOICE_FEATURES: u64 = 174; +const TLV_INVOICE_NODE_ID: u64 = 176; + +/// Human-readable prefix for payer proofs in bech32 encoding. +pub const PAYER_PROOF_HRP: &str = "lnp"; + +/// Tag for payer signature computation per BOLT 12 signature calculation. +/// Format: "lightning" || messagename || fieldname +const PAYER_SIGNATURE_TAG: &str = concat!("lightning", "payer_proof", "payer_signature"); + +/// Error when building or verifying a payer proof. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PayerProofError { + /// The preimage doesn't match the invoice's payment hash. + PreimageMismatch, + /// Error during merkle tree operations. + MerkleError(SelectiveDisclosureError), + /// The invoice signature is invalid. + InvalidInvoiceSignature, + /// The payer signature is invalid. + InvalidPayerSignature, + /// Error during signing. + SigningError, + /// Missing required field in the proof. + MissingRequiredField(&'static str), + /// The proof contains invalid data. + InvalidData(&'static str), + /// The invreq_metadata field cannot be included (per spec). + InvreqMetadataNotAllowed, + /// The omitted_tlvs contains an included TLV type. + OmittedTlvsContainsIncluded, + /// The omitted_tlvs has too many trailing markers. + TooManyTrailingOmittedMarkers, + /// Error decoding the payer proof. + DecodeError(crate::ln::msgs::DecodeError), +} + +impl From for PayerProofError { + fn from(e: SelectiveDisclosureError) -> Self { + PayerProofError::MerkleError(e) + } +} + +impl From for PayerProofError { + fn from(e: crate::ln::msgs::DecodeError) -> Self { + PayerProofError::DecodeError(e) + } +} + +/// A cryptographic proof that a BOLT 12 invoice was paid. +/// +/// Contains the payment preimage, selective disclosure of invoice fields, +/// the invoice signature, and a payer signature proving who paid. +#[derive(Clone, Debug)] +pub struct PayerProof { + bytes: Vec, + contents: PayerProofContents, + merkle_root: sha256::Hash, +} + +#[derive(Clone, Debug)] +struct PayerProofContents { + payer_id: PublicKey, + payment_hash: PaymentHash, + invoice_node_id: PublicKey, + preimage: PaymentPreimage, + invoice_signature: Signature, + payer_signature: Signature, + payer_note: Option, + + #[allow(dead_code)] + leaf_hashes: Vec, + #[allow(dead_code)] + omitted_tlvs: Vec, + #[allow(dead_code)] + missing_hashes: Vec, + + #[allow(dead_code)] + offer_description: Option, + #[allow(dead_code)] + offer_issuer: Option, + #[allow(dead_code)] + invoice_amount: Option, + #[allow(dead_code)] + invoice_created_at: Option, + #[allow(dead_code)] + invoice_features: Option, +} + +/// Builds a [`PayerProof`] from a paid invoice and its preimage. +/// +/// By default, only the required fields are included (payer_id, payment_hash, +/// invoice_node_id). Additional fields can be included for selective disclosure +/// using the `include_*` methods. +pub struct PayerProofBuilder<'a> { + invoice: &'a Bolt12Invoice, + preimage: PaymentPreimage, + included_types: BTreeSet, +} + +impl<'a> PayerProofBuilder<'a> { + /// Create a new builder from a paid invoice and its preimage. + /// + /// Returns an error if the preimage doesn't match the invoice's payment hash. + pub fn new( + invoice: &'a Bolt12Invoice, preimage: PaymentPreimage, + ) -> Result { + let computed_hash = sha256::Hash::hash(&preimage.0); + if computed_hash.as_byte_array() != &invoice.payment_hash().0 { + return Err(PayerProofError::PreimageMismatch); + } + + let mut included_types = BTreeSet::new(); + included_types.insert(TLV_INVREQ_PAYER_ID); + included_types.insert(TLV_INVOICE_PAYMENT_HASH); + included_types.insert(TLV_INVOICE_NODE_ID); + + if invoice.invoice_features() != &Bolt12InvoiceFeatures::empty() { + included_types.insert(TLV_INVOICE_FEATURES); + } + + Ok(Self { invoice, preimage, included_types }) + } + + /// Check if a TLV type is allowed to be included in the payer proof. + /// + /// Per spec: MUST NOT include invreq_metadata (type 0). + fn is_type_allowed(tlv_type: u64) -> bool { + tlv_type != TLV_INVREQ_METADATA + } + + /// Include a specific TLV type in the proof. + /// + /// Returns an error if the type is not allowed (e.g., invreq_metadata). + pub fn include_type(mut self, tlv_type: u64) -> Result { + if !Self::is_type_allowed(tlv_type) { + return Err(PayerProofError::InvreqMetadataNotAllowed); + } + self.included_types.insert(tlv_type); + Ok(self) + } + + /// Include the offer description in the proof. + pub fn include_offer_description(mut self) -> Self { + self.included_types.insert(10); + self + } + + /// Include the offer issuer in the proof. + pub fn include_offer_issuer(mut self) -> Self { + self.included_types.insert(18); + self + } + + /// Include the invoice amount in the proof. + pub fn include_invoice_amount(mut self) -> Self { + self.included_types.insert(170); + self + } + + /// Include the invoice creation timestamp in the proof. + pub fn include_invoice_created_at(mut self) -> Self { + self.included_types.insert(164); + self + } + + /// Build an unsigned payer proof. + pub fn build(self) -> Result { + let mut invoice_bytes = Vec::new(); + self.invoice.write(&mut invoice_bytes).expect("Vec write should not fail"); + + let bytes_without_sig: Vec = TlvStream::new(&invoice_bytes) + .filter(|r| !SIGNATURE_TYPES.contains(&r.r#type)) + .flat_map(|r| r.record_bytes.to_vec()) + .collect(); + + let disclosure = + merkle::compute_selective_disclosure(&bytes_without_sig, &self.included_types)?; + + let included_records: Vec<(u64, Vec)> = TlvStream::new(&invoice_bytes) + .filter(|r| self.included_types.contains(&r.r#type)) + .map(|r| (r.r#type, r.record_bytes.to_vec())) + .collect(); + + let invoice_signature = self.invoice.signature(); + + Ok(UnsignedPayerProof { + invoice_signature, + preimage: self.preimage, + payer_id: self.invoice.payer_signing_pubkey(), + payment_hash: self.invoice.payment_hash().clone(), + invoice_node_id: self.invoice.signing_pubkey(), + included_records, + disclosure, + invoice_features: if self.included_types.contains(&174) { + Some(self.invoice.invoice_features().clone()) + } else { + None + }, + }) + } +} + +/// An unsigned [`PayerProof`] ready for signing. +pub struct UnsignedPayerProof { + invoice_signature: Signature, + preimage: PaymentPreimage, + payer_id: PublicKey, + payment_hash: PaymentHash, + invoice_node_id: PublicKey, + included_records: Vec<(u64, Vec)>, + disclosure: SelectiveDisclosure, + invoice_features: Option, +} + +impl UnsignedPayerProof { + /// Returns the merkle root of the invoice. + pub fn merkle_root(&self) -> sha256::Hash { + self.disclosure.merkle_root + } + + /// Sign the proof with the payer's key to create a complete proof. + pub fn sign(self, sign_fn: F, note: Option<&str>) -> Result + where + F: FnOnce(&Message) -> Result, + { + let message = Self::compute_payer_signature_message(note, &self.disclosure.merkle_root); + let payer_signature = sign_fn(&message).map_err(|_| PayerProofError::SigningError)?; + + let secp_ctx = Secp256k1::verification_only(); + secp_ctx + .verify_schnorr(&payer_signature, &message, &self.payer_id.into()) + .map_err(|_| PayerProofError::InvalidPayerSignature)?; + + let bytes = self.serialize_payer_proof(&payer_signature, note); + + Ok(PayerProof { + bytes, + contents: PayerProofContents { + payer_id: self.payer_id, + payment_hash: self.payment_hash, + invoice_node_id: self.invoice_node_id, + preimage: self.preimage, + invoice_signature: self.invoice_signature, + payer_signature, + payer_note: note.map(String::from), + leaf_hashes: self.disclosure.leaf_hashes, + omitted_tlvs: self.disclosure.omitted_tlvs, + missing_hashes: self.disclosure.missing_hashes, + offer_description: None, + offer_issuer: None, + invoice_amount: None, + invoice_created_at: None, + invoice_features: self.invoice_features, + }, + merkle_root: self.disclosure.merkle_root, + }) + } + + /// Compute the payer signature message per BOLT 12 signature calculation. + fn compute_payer_signature_message(note: Option<&str>, merkle_root: &sha256::Hash) -> Message { + let mut inner_hasher = sha256::Hash::engine(); + if let Some(n) = note { + inner_hasher.input(n.as_bytes()); + } + inner_hasher.input(merkle_root.as_ref()); + let inner_msg = sha256::Hash::from_engine(inner_hasher); + + let tag_hash = sha256::Hash::hash(PAYER_SIGNATURE_TAG.as_bytes()); + + let mut final_hasher = sha256::Hash::engine(); + final_hasher.input(tag_hash.as_ref()); + final_hasher.input(tag_hash.as_ref()); + final_hasher.input(inner_msg.as_ref()); + let final_digest = sha256::Hash::from_engine(final_hasher); + + Message::from_digest(*final_digest.as_byte_array()) + } + + fn serialize_payer_proof(&self, payer_signature: &Signature, note: Option<&str>) -> Vec { + let mut bytes = Vec::new(); + + for (_, record_bytes) in &self.included_records { + bytes.extend_from_slice(record_bytes); + } + + BigSize(TLV_SIGNATURE).write(&mut bytes).expect("Vec write should not fail"); + BigSize(64).write(&mut bytes).expect("Vec write should not fail"); + self.invoice_signature.write(&mut bytes).expect("Vec write should not fail"); + + BigSize(TLV_PREIMAGE).write(&mut bytes).expect("Vec write should not fail"); + BigSize(32).write(&mut bytes).expect("Vec write should not fail"); + bytes.extend_from_slice(&self.preimage.0); + + if !self.disclosure.omitted_tlvs.is_empty() { + let mut omitted_bytes = Vec::new(); + for marker in &self.disclosure.omitted_tlvs { + BigSize(*marker).write(&mut omitted_bytes).expect("Vec write should not fail"); + } + BigSize(TLV_OMITTED_TLVS).write(&mut bytes).expect("Vec write should not fail"); + BigSize(omitted_bytes.len() as u64) + .write(&mut bytes) + .expect("Vec write should not fail"); + bytes.extend_from_slice(&omitted_bytes); + } + + if !self.disclosure.missing_hashes.is_empty() { + let len = self.disclosure.missing_hashes.len() * 32; + BigSize(TLV_MISSING_HASHES).write(&mut bytes).expect("Vec write should not fail"); + BigSize(len as u64).write(&mut bytes).expect("Vec write should not fail"); + for hash in &self.disclosure.missing_hashes { + bytes.extend_from_slice(hash.as_ref()); + } + } + + if !self.disclosure.leaf_hashes.is_empty() { + let len = self.disclosure.leaf_hashes.len() * 32; + BigSize(TLV_LEAF_HASHES).write(&mut bytes).expect("Vec write should not fail"); + BigSize(len as u64).write(&mut bytes).expect("Vec write should not fail"); + for hash in &self.disclosure.leaf_hashes { + bytes.extend_from_slice(hash.as_ref()); + } + } + + let note_bytes = note.map(|n| n.as_bytes()).unwrap_or(&[]); + let payer_sig_len = 64 + note_bytes.len(); + BigSize(TLV_PAYER_SIGNATURE).write(&mut bytes).expect("Vec write should not fail"); + BigSize(payer_sig_len as u64).write(&mut bytes).expect("Vec write should not fail"); + payer_signature.write(&mut bytes).expect("Vec write should not fail"); + bytes.extend_from_slice(note_bytes); + + bytes + } +} + +impl PayerProof { + /// Verify the payer proof. + pub fn verify(&self) -> Result<(), PayerProofError> { + let computed = sha256::Hash::hash(&self.contents.preimage.0); + if computed.as_byte_array() != &self.contents.payment_hash.0 { + return Err(PayerProofError::PreimageMismatch); + } + + let tagged_hash = TaggedHash::from_merkle_root(SIGNATURE_TAG, self.merkle_root); + merkle::verify_signature( + &self.contents.invoice_signature, + &tagged_hash, + self.contents.invoice_node_id, + ) + .map_err(|_| PayerProofError::InvalidInvoiceSignature)?; + + let message = UnsignedPayerProof::compute_payer_signature_message( + self.contents.payer_note.as_deref(), + &self.merkle_root, + ); + + let secp_ctx = Secp256k1::verification_only(); + secp_ctx + .verify_schnorr( + &self.contents.payer_signature, + &message, + &self.contents.payer_id.into(), + ) + .map_err(|_| PayerProofError::InvalidPayerSignature)?; + + Ok(()) + } + + /// The payment preimage proving the invoice was paid. + pub fn preimage(&self) -> PaymentPreimage { + self.contents.preimage + } + + /// The payer's public key (who paid). + pub fn payer_id(&self) -> PublicKey { + self.contents.payer_id + } + + /// The invoice node ID (who was paid). + pub fn invoice_node_id(&self) -> PublicKey { + self.contents.invoice_node_id + } + + /// The payment hash. + pub fn payment_hash(&self) -> PaymentHash { + self.contents.payment_hash + } + + /// The payer's note, if any. + pub fn payer_note(&self) -> Option<&str> { + self.contents.payer_note.as_deref() + } + + /// The merkle root of the original invoice. + pub fn merkle_root(&self) -> sha256::Hash { + self.merkle_root + } + + /// The raw bytes of the payer proof. + pub fn bytes(&self) -> &[u8] { + &self.bytes + } +} + +impl Bech32Encode for PayerProof { + const BECH32_HRP: &'static str = PAYER_PROOF_HRP; +} + +impl AsRef<[u8]> for PayerProof { + fn as_ref(&self) -> &[u8] { + &self.bytes + } +} + +impl TryFrom> for PayerProof { + type Error = crate::offers::parse::Bolt12ParseError; + + fn try_from(bytes: Vec) -> Result { + use crate::ln::msgs::DecodeError; + use crate::offers::parse::Bolt12ParseError; + + let mut payer_id: Option = None; + let mut payment_hash: Option = None; + let mut invoice_node_id: Option = None; + let mut invoice_signature: Option = None; + let mut preimage: Option = None; + let mut payer_signature: Option = None; + let mut payer_note: Option = None; + let mut invoice_features: Option = None; + + let mut leaf_hashes: Vec = Vec::new(); + let mut omitted_tlvs: Vec = Vec::new(); + let mut missing_hashes: Vec = Vec::new(); + + let mut included_types: BTreeSet = BTreeSet::new(); + let mut included_records: Vec<(u64, Vec)> = Vec::new(); + + let mut prev_tlv_type: u64 = 0; + let mut seen_tlv_types: BTreeSet = BTreeSet::new(); + + for record in TlvStream::new(&bytes) { + let tlv_type = record.r#type; + + if tlv_type <= prev_tlv_type && prev_tlv_type != 0 { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + + if seen_tlv_types.contains(&tlv_type) { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + seen_tlv_types.insert(tlv_type); + prev_tlv_type = tlv_type; + + match tlv_type { + TLV_INVREQ_PAYER_ID => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let _len: BigSize = Readable::read(&mut record_cursor)?; + payer_id = Some(Readable::read(&mut record_cursor)?); + included_types.insert(tlv_type); + included_records.push((tlv_type, record.record_bytes.to_vec())); + }, + TLV_INVOICE_PAYMENT_HASH => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let _len: BigSize = Readable::read(&mut record_cursor)?; + payment_hash = Some(Readable::read(&mut record_cursor)?); + included_types.insert(tlv_type); + included_records.push((tlv_type, record.record_bytes.to_vec())); + }, + TLV_INVOICE_FEATURES => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let len: BigSize = Readable::read(&mut record_cursor)?; + let mut feature_bytes = vec![0u8; len.0 as usize]; + record_cursor + .read_exact(&mut feature_bytes) + .map_err(|_| DecodeError::ShortRead)?; + invoice_features = Some(Bolt12InvoiceFeatures::from_le_bytes(feature_bytes)); + included_types.insert(tlv_type); + included_records.push((tlv_type, record.record_bytes.to_vec())); + }, + TLV_INVOICE_NODE_ID => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let _len: BigSize = Readable::read(&mut record_cursor)?; + invoice_node_id = Some(Readable::read(&mut record_cursor)?); + included_types.insert(tlv_type); + included_records.push((tlv_type, record.record_bytes.to_vec())); + }, + TLV_SIGNATURE => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let _len: BigSize = Readable::read(&mut record_cursor)?; + invoice_signature = Some(Readable::read(&mut record_cursor)?); + }, + TLV_PREIMAGE => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let _len: BigSize = Readable::read(&mut record_cursor)?; + let mut preimage_bytes = [0u8; 32]; + record_cursor + .read_exact(&mut preimage_bytes) + .map_err(|_| DecodeError::ShortRead)?; + preimage = Some(PaymentPreimage(preimage_bytes)); + }, + TLV_OMITTED_TLVS => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let len: BigSize = Readable::read(&mut record_cursor)?; + let end_pos = record_cursor.position() + len.0; + while record_cursor.position() < end_pos { + let marker: BigSize = Readable::read(&mut record_cursor)?; + omitted_tlvs.push(marker.0); + } + }, + TLV_MISSING_HASHES => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let len: BigSize = Readable::read(&mut record_cursor)?; + if len.0 % 32 != 0 { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + let num_hashes = len.0 / 32; + for _ in 0..num_hashes { + let mut hash_bytes = [0u8; 32]; + record_cursor + .read_exact(&mut hash_bytes) + .map_err(|_| DecodeError::ShortRead)?; + missing_hashes.push(sha256::Hash::from_byte_array(hash_bytes)); + } + }, + TLV_LEAF_HASHES => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let len: BigSize = Readable::read(&mut record_cursor)?; + if len.0 % 32 != 0 { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + let num_hashes = len.0 / 32; + for _ in 0..num_hashes { + let mut hash_bytes = [0u8; 32]; + record_cursor + .read_exact(&mut hash_bytes) + .map_err(|_| DecodeError::ShortRead)?; + leaf_hashes.push(sha256::Hash::from_byte_array(hash_bytes)); + } + }, + TLV_PAYER_SIGNATURE => { + let mut record_cursor = io::Cursor::new(record.record_bytes); + let _type: BigSize = Readable::read(&mut record_cursor)?; + let len: BigSize = Readable::read(&mut record_cursor)?; + payer_signature = Some(Readable::read(&mut record_cursor)?); + let note_len = len.0.saturating_sub(64); + if note_len > 0 { + let mut note_bytes = vec![0u8; note_len as usize]; + record_cursor + .read_exact(&mut note_bytes) + .map_err(|_| DecodeError::ShortRead)?; + payer_note = Some( + String::from_utf8(note_bytes).map_err(|_| DecodeError::InvalidValue)?, + ); + } + }, + _ => { + if tlv_type == TLV_INVREQ_METADATA { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + if !SIGNATURE_TYPES.contains(&tlv_type) { + included_types.insert(tlv_type); + included_records.push((tlv_type, record.record_bytes.to_vec())); + } + }, + } + } + + let payer_id = payer_id.ok_or(Bolt12ParseError::InvalidSemantics( + crate::offers::parse::Bolt12SemanticError::MissingPayerSigningPubkey, + ))?; + let payment_hash = payment_hash.ok_or(Bolt12ParseError::InvalidSemantics( + crate::offers::parse::Bolt12SemanticError::MissingPaymentHash, + ))?; + let invoice_node_id = invoice_node_id.ok_or(Bolt12ParseError::InvalidSemantics( + crate::offers::parse::Bolt12SemanticError::MissingSigningPubkey, + ))?; + let invoice_signature = invoice_signature.ok_or(Bolt12ParseError::InvalidSemantics( + crate::offers::parse::Bolt12SemanticError::MissingSignature, + ))?; + let preimage = preimage.ok_or(Bolt12ParseError::Decode(DecodeError::InvalidValue))?; + let payer_signature = payer_signature.ok_or(Bolt12ParseError::InvalidSemantics( + crate::offers::parse::Bolt12SemanticError::MissingSignature, + ))?; + + validate_omitted_tlvs_for_parsing(&omitted_tlvs, &included_types) + .map_err(|_| Bolt12ParseError::Decode(DecodeError::InvalidValue))?; + + if leaf_hashes.len() != included_records.len() { + return Err(Bolt12ParseError::Decode(DecodeError::InvalidValue)); + } + + let included_refs: Vec<(u64, &[u8])> = + included_records.iter().map(|(t, b)| (*t, b.as_slice())).collect(); + let merkle_root = merkle::reconstruct_merkle_root( + &included_refs, + &leaf_hashes, + &omitted_tlvs, + &missing_hashes, + ) + .map_err(|_| Bolt12ParseError::Decode(DecodeError::InvalidValue))?; + + Ok(PayerProof { + bytes, + contents: PayerProofContents { + payer_id, + payment_hash, + invoice_node_id, + preimage, + invoice_signature, + payer_signature, + payer_note, + leaf_hashes, + omitted_tlvs, + missing_hashes, + offer_description: None, + offer_issuer: None, + invoice_amount: None, + invoice_created_at: None, + invoice_features, + }, + merkle_root, + }) + } +} + +/// Validate omitted_tlvs markers during parsing. +/// +/// Per spec: +/// - MUST NOT contain 0 +/// - MUST NOT contain signature TLV element numbers (240-1000) +/// - MUST be in strict ascending order +/// - MUST NOT contain the number of an included TLV field +/// - MUST NOT contain more than one number larger than the largest included non-signature TLV +fn validate_omitted_tlvs_for_parsing( + omitted_tlvs: &[u64], included_types: &BTreeSet, +) -> Result<(), PayerProofError> { + let mut prev = 0u64; + let mut trailing_count = 0; + let max_included = included_types.iter().copied().max().unwrap_or(0); + + for &marker in omitted_tlvs { + // MUST NOT contain 0 + if marker == 0 { + return Err(PayerProofError::InvalidData("omitted_tlvs contains 0")); + } + + // MUST NOT contain signature TLV types + if SIGNATURE_TYPES.contains(&marker) { + return Err(PayerProofError::InvalidData("omitted_tlvs contains signature type")); + } + + // MUST be strictly ascending + if marker <= prev { + return Err(PayerProofError::InvalidData("omitted_tlvs not strictly ascending")); + } + + // MUST NOT contain included TLV types + if included_types.contains(&marker) { + return Err(PayerProofError::OmittedTlvsContainsIncluded); + } + + // Count markers larger than largest included + if marker > max_included { + trailing_count += 1; + } + + prev = marker; + } + + // MUST NOT contain more than one number larger than largest included + if trailing_count > 1 { + return Err(PayerProofError::TooManyTrailingOmittedMarkers); + } + + Ok(()) +} + +impl core::fmt::Display for PayerProof { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + self.fmt_bech32_str(f) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::offers::merkle::compute_selective_disclosure; + + #[test] + fn test_selective_disclosure_computation() { + // Test that the merkle selective disclosure works correctly + // Simple TLV stream with types 1, 2 + let tlv_bytes = vec![ + 0x01, 0x03, 0xe8, 0x03, 0xe8, // type 1, length 3, value + 0x02, 0x08, 0x00, 0x00, 0x01, 0x00, 0x00, 0x02, 0x00, 0x03, // type 2 + ]; + + let mut included = BTreeSet::new(); + included.insert(1); + + let result = compute_selective_disclosure(&tlv_bytes, &included); + assert!(result.is_ok()); + + let disclosure = result.unwrap(); + assert_eq!(disclosure.leaf_hashes.len(), 1); // One included TLV + assert!(!disclosure.missing_hashes.is_empty()); // Should have missing hashes for omitted + } + + /// Test the omitted_tlvs marker algorithm per BOLT 12 payer proof spec. + /// + /// From the spec example: + /// TLVs: 0 (omitted), 10 (included), 20 (omitted), 30 (omitted), + /// 40 (included), 50 (omitted), 60 (omitted), 240 (signature) + /// + /// Expected markers: [11, 12, 41, 42] + /// + /// The algorithm: + /// - TLV 0 is always omitted and implicit (not in markers) + /// - For omitted TLV after included: marker = prev_included_type + 1 + /// - For consecutive omitted TLVs: marker = prev_marker + 1 + #[test] + fn test_omitted_markers_spec_example() { + // Build a synthetic TLV stream matching the spec example + // TLV format: type (BigSize) || length (BigSize) || value + let mut tlv_bytes = Vec::new(); + + // TLV 0: type=0, len=4, value=dummy + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); + // TLV 10: type=10, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); + // TLV 20: type=20, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); + // TLV 30: type=30, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x1e, 0x02, 0x00, 0x00]); + // TLV 40: type=40, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x28, 0x02, 0x00, 0x00]); + // TLV 50: type=50, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x32, 0x02, 0x00, 0x00]); + // TLV 60: type=60, len=2, value=dummy + tlv_bytes.extend_from_slice(&[0x3c, 0x02, 0x00, 0x00]); + + // Include types 10 and 40 + let mut included = BTreeSet::new(); + included.insert(10); + included.insert(40); + + let disclosure = compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + // Per spec example, omitted_tlvs should be [11, 12, 41, 42] + assert_eq!(disclosure.omitted_tlvs, vec![11, 12, 41, 42]); + + // leaf_hashes should have 2 entries (one for each included TLV) + assert_eq!(disclosure.leaf_hashes.len(), 2); + } + + /// Test that the marker algorithm handles edge cases correctly. + #[test] + fn test_omitted_markers_edge_cases() { + // Test with only one included TLV at the start + let mut tlv_bytes = Vec::new(); + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // TLV 0 + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); // TLV 10 + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); // TLV 20 + tlv_bytes.extend_from_slice(&[0x1e, 0x02, 0x00, 0x00]); // TLV 30 + + let mut included = BTreeSet::new(); + included.insert(10); + + let disclosure = compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + // After included type 10, omitted types 20 and 30 get markers 11 and 12 + assert_eq!(disclosure.omitted_tlvs, vec![11, 12]); + } + + /// Test that all included TLVs produce no omitted markers (except implicit TLV0). + #[test] + fn test_omitted_markers_all_included() { + let mut tlv_bytes = Vec::new(); + tlv_bytes.extend_from_slice(&[0x00, 0x04, 0x00, 0x00, 0x00, 0x00]); // TLV 0 (always omitted) + tlv_bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); // TLV 10 + tlv_bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); // TLV 20 + + let mut included = BTreeSet::new(); + included.insert(10); + included.insert(20); + + let disclosure = compute_selective_disclosure(&tlv_bytes, &included).unwrap(); + + // Only TLV 0 is omitted (implicit), so no markers needed + assert!(disclosure.omitted_tlvs.is_empty()); + } + + /// Test validation of omitted_tlvs - must not contain 0. + #[test] + fn test_validate_omitted_tlvs_rejects_zero() { + let omitted = vec![0, 5, 10]; + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(matches!(result, Err(PayerProofError::InvalidData(_)))); + } + + /// Test validation of omitted_tlvs - must not contain signature types. + #[test] + fn test_validate_omitted_tlvs_rejects_signature_types() { + let omitted = vec![5, 10, 250]; // 250 is a signature type + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(matches!(result, Err(PayerProofError::InvalidData(_)))); + } + + /// Test validation of omitted_tlvs - must be strictly ascending. + #[test] + fn test_validate_omitted_tlvs_rejects_non_ascending() { + let omitted = vec![5, 10, 8]; // 8 is not strictly ascending after 10 + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(matches!(result, Err(PayerProofError::InvalidData(_)))); + } + + /// Test validation of omitted_tlvs - must not contain included types. + #[test] + fn test_validate_omitted_tlvs_rejects_included_types() { + let omitted = vec![5, 20, 25]; // 20 is in included set + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(matches!(result, Err(PayerProofError::OmittedTlvsContainsIncluded))); + } + + /// Test validation of omitted_tlvs - must not have too many trailing markers. + #[test] + fn test_validate_omitted_tlvs_rejects_too_many_trailing() { + let omitted = vec![5, 100, 101]; // 100 and 101 are both > max included (30) + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(matches!(result, Err(PayerProofError::TooManyTrailingOmittedMarkers))); + } + + /// Test that valid omitted_tlvs pass validation. + #[test] + fn test_validate_omitted_tlvs_accepts_valid() { + let omitted = vec![5, 10, 35]; // All valid: ascending, no 0, no sig types, one trailing + let included: BTreeSet = [20, 30].iter().copied().collect(); + + let result = validate_omitted_tlvs_for_parsing(&omitted, &included); + assert!(result.is_ok()); + } + + /// Test that invreq_metadata (type 0) cannot be explicitly included. + #[test] + fn test_invreq_metadata_not_allowed() { + assert!(!PayerProofBuilder::<'_>::is_type_allowed(TLV_INVREQ_METADATA)); + assert!(PayerProofBuilder::<'_>::is_type_allowed(TLV_INVREQ_PAYER_ID)); + } + + /// Test that out-of-order TLVs are rejected during parsing. + #[test] + fn test_parsing_rejects_out_of_order_tlvs() { + use core::convert::TryFrom; + + // Create a malformed TLV stream with out-of-order types (20 before 10) + // TLV format: type (BigSize) || length (BigSize) || value + let mut bytes = Vec::new(); + // TLV type 20, length 2, value + bytes.extend_from_slice(&[0x14, 0x02, 0x00, 0x00]); + // TLV type 10, length 2, value (OUT OF ORDER!) + bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); + + let result = PayerProof::try_from(bytes); + assert!(result.is_err()); + } + + /// Test that duplicate TLVs are rejected during parsing. + #[test] + fn test_parsing_rejects_duplicate_tlvs() { + use core::convert::TryFrom; + + // Create a malformed TLV stream with duplicate type 10 + let mut bytes = Vec::new(); + // TLV type 10, length 2, value + bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); + // TLV type 10 again (DUPLICATE!) + bytes.extend_from_slice(&[0x0a, 0x02, 0x00, 0x00]); + + let result = PayerProof::try_from(bytes); + assert!(result.is_err()); + } + + /// Test that invalid hash lengths (not multiple of 32) are rejected. + #[test] + fn test_parsing_rejects_invalid_hash_length() { + use core::convert::TryFrom; + + // Create a TLV stream with missing_hashes (type 246) that has invalid length + // BigSize encoding: values 0-252 are single byte, 253-65535 use 0xFD prefix + let mut bytes = Vec::new(); + // TLV type 246 (missing_hashes) - 246 < 253 so single byte + bytes.push(0xf6); // type 246 + bytes.push(0x21); // length 33 (not multiple of 32!) + bytes.extend_from_slice(&[0x00; 33]); // 33 bytes of zeros + + let result = PayerProof::try_from(bytes); + assert!(result.is_err()); + } + + /// Test that invalid leaf_hashes length (not multiple of 32) is rejected. + #[test] + fn test_parsing_rejects_invalid_leaf_hashes_length() { + use core::convert::TryFrom; + + // Create a TLV stream with leaf_hashes (type 248) that has invalid length + // BigSize encoding: values 0-252 are single byte, 253-65535 use 0xFD prefix + let mut bytes = Vec::new(); + // TLV type 248 (leaf_hashes) - 248 < 253 so single byte + bytes.push(0xf8); // type 248 + bytes.push(0x1f); // length 31 (not multiple of 32!) + bytes.extend_from_slice(&[0x00; 31]); // 31 bytes of zeros + + let result = PayerProof::try_from(bytes); + assert!(result.is_err()); + } +}