diff --git a/rust/crates/cli/src/commands/server/start.rs b/rust/crates/cli/src/commands/server/start.rs index 6c2f0b78..44088a47 100644 --- a/rust/crates/cli/src/commands/server/start.rs +++ b/rust/crates/cli/src/commands/server/start.rs @@ -8,6 +8,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::{any, get, post}; use owo_colors::OwoColorize; use pay_core::PaymentState; +use pay_core::ReplayStore; use pay_core::accounts::AccountsStore; use pay_core::server::session::SessionMpp; use pay_core::server::telemetry::FeePayerWallet; @@ -91,6 +92,7 @@ struct AppState { session_mpp: Option>, browser_rpc_url: Option, fee_payer_wallet: Option, + replay_store: ReplayStore, } impl PaymentState for AppState { @@ -112,6 +114,9 @@ impl PaymentState for AppState { fn fee_payer_wallet(&self) -> Option<&FeePayerWallet> { self.fee_payer_wallet.as_ref() } + fn replay_store(&self) -> Option<&ReplayStore> { + Some(&self.replay_store) + } } fn should_use_auto_fee_payer_signer( @@ -769,6 +774,7 @@ impl StartCommand { session_mpp, browser_rpc_url: Some(BROWSER_RPC_PROXY_PATH.to_string()), fee_payer_wallet, + replay_store: ReplayStore::default(), }; let pdb_state = if debugger { diff --git a/rust/crates/core/src/lib.rs b/rust/crates/core/src/lib.rs index 91978dd7..53ac2159 100644 --- a/rust/crates/core/src/lib.rs +++ b/rust/crates/core/src/lib.rs @@ -34,6 +34,8 @@ pub use server::{AccountingKey, AccountingStore, InMemoryStore, current_period}; #[cfg(feature = "server")] use pay_types::metering::ApiSpec; #[cfg(feature = "server")] +pub use server::payment::ReplayStore; +#[cfg(feature = "server")] pub use solana_mpp; #[cfg(feature = "server")] use solana_mpp::server::Mpp; @@ -55,4 +57,7 @@ pub trait PaymentState: Clone + Send + Sync + 'static { fn fee_payer_wallet(&self) -> Option<&server::telemetry::FeePayerWallet> { None } + fn replay_store(&self) -> Option<&server::payment::ReplayStore> { + None + } } diff --git a/rust/crates/core/src/server/payment.rs b/rust/crates/core/src/server/payment.rs index 85e6f363..aa4e5fa7 100644 --- a/rust/crates/core/src/server/payment.rs +++ b/rust/crates/core/src/server/payment.rs @@ -4,10 +4,15 @@ //! - No payment header → 402 with MPP challenge (WWW-Authenticate) //! - Payment header → verify with solana-mpp, then forward upstream +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + use axum::body::Body; use axum::http::{HeaderMap, Method, Request, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; +use base64::Engine; use serde_json::json; use solana_mpp::{ AUTHORIZATION_HEADER, PAYMENT_RECEIPT_HEADER, WWW_AUTHENTICATE_HEADER, format_receipt, @@ -37,6 +42,60 @@ const PAYMENT_PAGE_CONTENT_SECURITY_POLICY: &str = "\ img-src 'self' data: blob: https:; \ connect-src 'self' http://localhost:* http://127.0.0.1:* https:; \ worker-src 'self'"; +const DEFAULT_REPLAY_TTL: Duration = Duration::from_secs(30 * 60); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ReplayRecord { + used_at: Instant, +} + +#[derive(Clone)] +pub struct ReplayStore { + inner: Arc>>, + ttl: Duration, +} + +impl Default for ReplayStore { + fn default() -> Self { + Self::new(DEFAULT_REPLAY_TTL) + } +} + +impl ReplayStore { + pub fn new(ttl: Duration) -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + ttl, + } + } + + pub fn contains_recent(&self, key: &str) -> bool { + let mut guard = self.inner.lock().unwrap(); + cleanup_expired_locked(&mut guard, self.ttl); + guard.contains_key(key) + } + + pub fn mark_used(&self, key: String) { + let mut guard = self.inner.lock().unwrap(); + cleanup_expired_locked(&mut guard, self.ttl); + guard.insert( + key, + ReplayRecord { + used_at: Instant::now(), + }, + ); + } + + pub fn cleanup(&self) { + let mut guard = self.inner.lock().unwrap(); + cleanup_expired_locked(&mut guard, self.ttl); + } + + #[cfg(test)] + fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } +} /// Axum middleware that gates metered endpoints behind MPP payment. pub async fn payment_middleware( @@ -189,6 +248,8 @@ pub async fn payment_middleware( auth_value, subdomain, &path, + &method, + state.replay_store().cloned(), state.fee_payer_wallet().cloned(), req, next, @@ -443,7 +504,7 @@ fn resolve_charge_splits( #[tracing::instrument( name = "charge_authorization", - skip(mpps, auth_value, fee_payer_wallet, req, next), + skip(mpps, auth_value, replay_store, fee_payer_wallet, req, next), fields(subdomain = %subdomain, path = %path) )] async fn handle_charge_authorization( @@ -451,6 +512,8 @@ async fn handle_charge_authorization( auth_value: &str, subdomain: &str, path: &str, + method: &Method, + replay_store: Option, fee_payer_wallet: Option, req: Request, next: Next, @@ -471,6 +534,28 @@ async fn handle_charge_authorization( for mpp in mpps { match mpp.verify_credential(&credential).await { Ok(receipt) => { + let replay_key = + build_mpp_replay_key(&credential, &receipt.reference, method, path); + if let Some(store) = replay_store.as_ref() + && let Some(key) = replay_key.as_deref() + && store.contains_recent(key) + { + tracing::warn!( + subdomain = %subdomain, + path = %path, + method = %method, + replay_key = %key, + "Payment proof replay detected" + ); + telemetry::record_settlement_error( + "mpp", + subdomain, + path, + "payment proof already used", + false, + ); + return replay_failed_response(mpps); + } let payment = decode_payment_amount(&credential, mpp.decimals() as u8); telemetry::record_payment_collected( "mpp", @@ -501,6 +586,12 @@ async fn handle_charge_authorization( { response.headers_mut().insert(PAYMENT_RECEIPT_HEADER, v); } + if response.status().is_success() + && let Some(store) = replay_store.as_ref() + && let Some(key) = replay_key + { + store.mark_used(key); + } return response; } Err(e) => last_error = Some(e), @@ -547,6 +638,26 @@ fn verification_failed_response( response } +fn replay_failed_response(mpps: &[&solana_mpp::server::Mpp]) -> Response { + let mut response = ( + StatusCode::PAYMENT_REQUIRED, + axum::Json(json!({ + "error": "verification_failed", + "message": "payment proof already used", + "retryable": false, + })), + ) + .into_response(); + let challenges: Vec<_> = mpps + .iter() + .filter_map(|mpp| mpp.charge("0.01").ok()) + .collect(); + if let Ok(www_auths) = format_www_authenticate_many(&challenges) { + append_www_authenticate_headers(response.headers_mut(), &www_auths); + } + response +} + pub fn readable_verification_message(error: &solana_mpp::server::VerificationError) -> String { let message = error.to_string(); if message.contains("Fee payer cannot authorize the SPL payment transfer") { @@ -561,6 +672,55 @@ pub fn readable_verification_message(error: &solana_mpp::server::VerificationErr message } +fn build_mpp_replay_key( + credential: &solana_mpp::PaymentCredential, + receipt_reference: &str, + method: &Method, + path: &str, +) -> Option { + let payer = extract_payer_from_credential(credential)?; + let canonical = format!( + "mpp:{}:{payer}:{receipt_reference}:{}:{path}", + credential.challenge.id, + method.as_str() + ); + Some(format!( + "mpp:{}", + blake3::hash(canonical.as_bytes()).to_hex() + )) +} + +fn extract_payer_from_credential(credential: &solana_mpp::PaymentCredential) -> Option { + if let Some(tx_b64) = credential + .payload + .get("transaction") + .and_then(|value| value.as_str()) + { + let tx_bytes = base64::engine::general_purpose::STANDARD + .decode(tx_b64) + .ok()?; + let tx: solana_transaction::Transaction = bincode::deserialize(&tx_bytes).ok()?; + let zero_sig = [0u8; 64]; + for (index, sig) in tx.signatures.iter().enumerate() { + if sig.as_ref() != zero_sig && index < tx.message.account_keys.len() { + return Some(tx.message.account_keys[index].to_string()); + } + } + return tx.message.account_keys.first().map(ToString::to_string); + } + + credential + .payload + .get("source") + .and_then(|value| value.as_str()) + .map(ToString::to_string) +} + +fn cleanup_expired_locked(entries: &mut HashMap, ttl: Duration) { + let now = Instant::now(); + entries.retain(|_, record| now.duration_since(record.used_at) < ttl); +} + fn challenge_json_response(body: serde_json::Value, www_auths: &[String]) -> Response { let mut response = (StatusCode::PAYMENT_REQUIRED, axum::Json(body)).into_response(); append_www_authenticate_headers(response.headers_mut(), www_auths); @@ -640,9 +800,13 @@ fn extract_variant_hint(path: &str) -> Option { #[cfg(test)] mod tests { use super::*; + use base64::Engine; use solana_mpp::WWW_AUTHENTICATE_HEADER; use solana_mpp::server::Mpp; use solana_mpp::server::session::SessionConfig; + use solana_signature::Signature; + use solana_transaction::Transaction; + use std::thread; fn test_mpp() -> Mpp { Mpp::new(solana_mpp::server::Config { @@ -773,6 +937,81 @@ mod tests { ); } + #[test] + fn replay_store_marks_and_finds_recent_keys() { + let store = ReplayStore::new(Duration::from_secs(60)); + assert!(!store.contains_recent("proof-1")); + store.mark_used("proof-1".to_string()); + assert!(store.contains_recent("proof-1")); + } + + #[test] + fn replay_store_expires_old_entries() { + let store = ReplayStore::new(Duration::from_millis(5)); + store.mark_used("proof-1".to_string()); + thread::sleep(Duration::from_millis(10)); + assert!(!store.contains_recent("proof-1")); + assert_eq!(store.len(), 0); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn replay_failed_response_is_non_retryable_verification_error() { + let mpp = test_mpp(); + let response = replay_failed_response(&[&mpp]); + assert_eq!(response.status(), StatusCode::PAYMENT_REQUIRED); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let body: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(body["error"], "verification_failed"); + assert_eq!(body["message"], "payment proof already used"); + assert_eq!(body["retryable"], false); + } + + #[test] + fn build_mpp_replay_key_is_stable_for_same_verified_inputs() { + let mpp = test_mpp(); + let challenge = mpp.charge("0.01").expect("challenge should build"); + let tx = Transaction { + signatures: vec![Signature::new_unique()], + message: solana_message::Message { + header: solana_message::MessageHeader { + num_required_signatures: 1, + num_readonly_signed_accounts: 0, + num_readonly_unsigned_accounts: 0, + }, + account_keys: vec![solana_pubkey::Pubkey::new_unique()], + recent_blockhash: solana_hash::Hash::new_unique(), + instructions: vec![], + }, + }; + let tx_b64 = + base64::engine::general_purpose::STANDARD.encode(bincode::serialize(&tx).unwrap()); + let credential = solana_mpp::PaymentCredential::new( + challenge.to_echo(), + serde_json::json!({ + "type": "transaction", + "transaction": tx_b64, + }), + ); + + let key_a = build_mpp_replay_key( + &credential, + "receipt-ref-1", + &Method::POST, + "v1/simple/echo", + ); + let key_b = build_mpp_replay_key( + &credential, + "receipt-ref-1", + &Method::POST, + "v1/simple/echo", + ); + + assert_eq!(key_a, key_b); + } + #[tokio::test] async fn session_challenge_response_sets_session_header() { let response = session_challenge_response(