diff --git a/broker/Cargo.lock b/broker/Cargo.lock index d91ea53..fad88fe 100644 --- a/broker/Cargo.lock +++ b/broker/Cargo.lock @@ -1718,7 +1718,7 @@ dependencies = [ [[package]] name = "lss-connector" version = "0.1.0" -source = "git+https://github.com/stakwork/sphinx-rs?rev=1abce4dedfc6be8cb261e4faa11d9a753ee323ce#1abce4dedfc6be8cb261e4faa11d9a753ee323ce" +source = "git+https://github.com/stakwork/sphinx-rs?rev=64d5c8aa166c4ff46b0817bc4011f39a3c949de7#64d5c8aa166c4ff46b0817bc4011f39a3c949de7" dependencies = [ "anyhow", "lightning-storage-server", @@ -2720,7 +2720,7 @@ dependencies = [ [[package]] name = "rmp-utils" version = "0.1.0" -source = "git+https://github.com/stakwork/sphinx-rs?rev=1abce4dedfc6be8cb261e4faa11d9a753ee323ce#1abce4dedfc6be8cb261e4faa11d9a753ee323ce" +source = "git+https://github.com/stakwork/sphinx-rs?rev=64d5c8aa166c4ff46b0817bc4011f39a3c949de7#64d5c8aa166c4ff46b0817bc4011f39a3c949de7" dependencies = [ "anyhow", "log", @@ -3311,7 +3311,7 @@ dependencies = [ [[package]] name = "sphinx-auther" version = "0.1.12" -source = "git+https://github.com/stakwork/sphinx-rs?rev=1abce4dedfc6be8cb261e4faa11d9a753ee323ce#1abce4dedfc6be8cb261e4faa11d9a753ee323ce" +source = "git+https://github.com/stakwork/sphinx-rs?rev=64d5c8aa166c4ff46b0817bc4011f39a3c949de7#64d5c8aa166c4ff46b0817bc4011f39a3c949de7" dependencies = [ "anyhow", "base64 0.21.2", @@ -3323,7 +3323,7 @@ dependencies = [ [[package]] name = "sphinx-glyph" version = "0.1.2" -source = "git+https://github.com/stakwork/sphinx-rs?rev=1abce4dedfc6be8cb261e4faa11d9a753ee323ce#1abce4dedfc6be8cb261e4faa11d9a753ee323ce" +source = "git+https://github.com/stakwork/sphinx-rs?rev=64d5c8aa166c4ff46b0817bc4011f39a3c949de7#64d5c8aa166c4ff46b0817bc4011f39a3c949de7" dependencies = [ "anyhow", "hex", @@ -3369,7 +3369,7 @@ dependencies = [ [[package]] name = "sphinx-signer" version = "0.1.0" -source = "git+https://github.com/stakwork/sphinx-rs?rev=1abce4dedfc6be8cb261e4faa11d9a753ee323ce#1abce4dedfc6be8cb261e4faa11d9a753ee323ce" +source = "git+https://github.com/stakwork/sphinx-rs?rev=64d5c8aa166c4ff46b0817bc4011f39a3c949de7#64d5c8aa166c4ff46b0817bc4011f39a3c949de7" dependencies = [ "anyhow", "bip39", diff --git a/broker/Cargo.toml b/broker/Cargo.toml index 5e84b25..41bc45e 100644 --- a/broker/Cargo.toml +++ b/broker/Cargo.toml @@ -39,8 +39,8 @@ vls-proxy = { git = "https://gitlab.com/lightning-signer/validating-li # vls-protocol-client = { path = "../../vls/vls-protocol-client" } # vls-proxy = { path = "../../vls/vls-proxy" } -lss-connector = { git = "https://github.com/stakwork/sphinx-rs", rev = "1abce4dedfc6be8cb261e4faa11d9a753ee323ce" } -sphinx-signer = { git = "https://github.com/stakwork/sphinx-rs", rev = "1abce4dedfc6be8cb261e4faa11d9a753ee323ce" } +lss-connector = { git = "https://github.com/stakwork/sphinx-rs", rev = "64d5c8aa166c4ff46b0817bc4011f39a3c949de7" } +sphinx-signer = { git = "https://github.com/stakwork/sphinx-rs", rev = "64d5c8aa166c4ff46b0817bc4011f39a3c949de7" } # lss-connector = { path = "../../sphinx-rs/lss-connector" } # sphinx-signer = { path = "../../sphinx-rs/signer" } diff --git a/broker/src/conn.rs b/broker/src/conn.rs index c9ed0e5..b606b1b 100644 --- a/broker/src/conn.rs +++ b/broker/src/conn.rs @@ -1,12 +1,13 @@ use anyhow::Result; use rocket::tokio::sync::{mpsc, oneshot}; use serde::{Deserialize, Serialize}; +use sphinx_signer::sphinx_glyph::types::SignerType; use std::collections::HashMap; #[derive(Debug, Serialize, Deserialize)] pub struct Connections { pub pubkey: Option, - pub clients: HashMap, + pub clients: HashMap, pub current: Option, } @@ -27,21 +28,16 @@ impl Connections { pub fn set_current(&mut self, cid: String) { self.current = Some(cid); } - fn add_client(&mut self, cid: &str) { - self.clients.insert(cid.to_string(), true); + pub fn add_client(&mut self, cid: &str, signer_type: SignerType) { + self.clients.insert(cid.to_string(), signer_type); self.current = Some(cid.to_string()); } - fn remove_client(&mut self, cid: &str) { + pub fn remove_client(&mut self, cid: &str) { self.clients.remove(cid); - if self.current == Some(cid.to_string()) { - self.current = None; - } - } - pub fn client_action(&mut self, cid: &str, connected: bool) { - if connected { - self.add_client(cid); - } else { - self.remove_client(cid); + if let Some(id) = &self.current { + if id == cid { + self.current = None; + } } } } @@ -58,6 +54,7 @@ pub struct ChannelRequest { pub message: Vec, pub reply_tx: oneshot::Sender, pub cid: Option, // if it exists, only try the one client + pub signer_type: Option, // if it exists, only try clients of these types } impl ChannelRequest { pub fn new(topic: &str, message: Vec) -> (Self, oneshot::Receiver) { @@ -67,6 +64,7 @@ impl ChannelRequest { message, reply_tx, cid: None, + signer_type: None, }; (cr, reply_rx) } @@ -81,6 +79,7 @@ impl ChannelRequest { message, reply_tx, cid: None, + signer_type: None, }; let _ = sender.send(req).await; let reply = reply_rx.await?; @@ -98,13 +97,14 @@ impl ChannelRequest { message, reply_tx, cid: Some(cid.to_string()), + signer_type: None, }; let _ = sender.send(req).await; let reply = reply_rx.await?; Ok(reply.reply) } pub fn for_cid(&mut self, cid: &str) { - self.cid = Some(cid.to_string()) + self.cid = Some(cid.to_string()); } pub fn new_for( cid: &str, @@ -115,6 +115,18 @@ impl ChannelRequest { cr.for_cid(cid); (cr, reply_rx) } + pub fn for_type(&mut self, signer_type: SignerType) { + self.signer_type = Some(signer_type); + } + pub fn new_for_type( + signer_type: SignerType, + topic: &str, + message: Vec, + ) -> (Self, oneshot::Receiver) { + let (mut cr, reply_rx) = ChannelRequest::new(topic, message); + cr.for_type(signer_type); + (cr, reply_rx) + } } // mpsc reply diff --git a/broker/src/looper.rs b/broker/src/looper.rs index 808eb3f..2072882 100644 --- a/broker/src/looper.rs +++ b/broker/src/looper.rs @@ -4,7 +4,10 @@ use bitcoin::blockdata::constants::ChainHash; use log::*; use rocket::tokio::sync::mpsc; use secp256k1::PublicKey; -use sphinx_signer::{parser, sphinx_glyph::topics}; +use sphinx_signer::{ + parser, + sphinx_glyph::{topics, types::SignerType}, +}; use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::thread; use std::time::Duration; @@ -123,7 +126,7 @@ impl SignerLoop { } msg => { let mut catch_init = false; - if let Message::HsmdInit(m) = msg { + if let Message::HsmdInit(ref m) = msg { catch_init = true; if let Some(set) = settings { if ChainHash::using_genesis_block(set.network).as_bytes() @@ -135,7 +138,14 @@ impl SignerLoop { panic!("Got HsmdInit without settings - likely because HsmdInit was sent after startup"); } } - let reply = self.handle_message(raw_msg, catch_init)?; + let reply = if let Message::PreapproveInvoice(_) + | Message::PreapproveKeysend(_) = msg + { + self.handle_message(raw_msg, catch_init, Some(SignerType::ReceiveSend))? + } else { + // None for signer_type means no restrictions on which signer type to send the message to + self.handle_message(raw_msg, catch_init, None)? + }; // Write the reply to CLN self.client.write_vec(reply)?; } @@ -143,7 +153,12 @@ impl SignerLoop { } } - fn handle_message(&mut self, message: Vec, catch_init: bool) -> Result> { + fn handle_message( + &mut self, + message: Vec, + catch_init: bool, + signer_type: Option, + ) -> Result> { // wait until not busy loop { match try_to_get_busy() { @@ -166,7 +181,7 @@ impl SignerLoop { )?; // send to signer log::info!("SEND ON {}", topics::VLS); - let (res_topic, res) = self.send_request_wait(topics::VLS, md)?; + let (res_topic, res) = self.send_request_wait(topics::VLS, md, signer_type)?; log::info!("GOT ON {}", res_topic); let the_res = if res_topic == topics::LSS_RES { // send reply to LSS to store muts @@ -174,7 +189,7 @@ impl SignerLoop { log::info!("LSS REPLY LEN {}", &lss_reply.len()); // send to signer for HMAC validation, and get final reply log::info!("SEND ON {}", topics::LSS_MSG); - let (res_topic2, res2) = self.send_request_wait(topics::LSS_MSG, lss_reply)?; + let (res_topic2, res2) = self.send_request_wait(topics::LSS_MSG, lss_reply, None)?; log::info!("GOT ON {}, send to CLN", res_topic2); if res_topic2 != topics::VLS_RES { log::warn!("got a topic NOT on {}", topics::VLS_RES); @@ -213,9 +228,17 @@ impl SignerLoop { // returns (topic, payload) // might halt if signer is offline - fn send_request_wait(&mut self, topic: &str, message: Vec) -> Result<(String, Vec)> { + fn send_request_wait( + &mut self, + topic: &str, + message: Vec, + signer_type: Option, + ) -> Result<(String, Vec)> { // Send a request to the MQTT handler to send to signer - let (request, reply_rx) = ChannelRequest::new(topic, message); + let (request, reply_rx) = match signer_type { + Some(st) => ChannelRequest::new_for_type(st, topic, message), + None => ChannelRequest::new(topic, message), + }; // This can fail if MQTT shuts down self.chan .sender diff --git a/broker/src/main.rs b/broker/src/main.rs index 042fb09..210a4c2 100644 --- a/broker/src/main.rs +++ b/broker/src/main.rs @@ -163,27 +163,36 @@ pub async fn broker_setup( let conns_ = conns.clone(); std::thread::spawn(move || { log::info!("=> waiting first connection..."); - while let Ok((cid, connected)) = status_rx.recv() { + while let Ok((cid, connected, signer_type_opt)) = status_rx.recv() { log::info!("=> connection status: {}: {}", cid, connected); let mut cs = conns_.lock().unwrap(); // drop it from list until ready - cs.client_action(&cid, false); + cs.remove_client(&cid); drop(cs); if connected { + // In mqtt.rs, we always send a signer type if connected == true + let signer_type = signer_type_opt.unwrap(); let (dance_complete_tx, dance_complete_rx) = std_oneshot::channel::(); let _ = conn_tx.blocking_send((cid.clone(), dance_complete_tx)); let dance_complete = dance_complete_rx.recv().unwrap_or_else(|e| { - log::info!( + log::warn!( "dance_complete channel died before receiving response: {}", e ); false }); - log::info!("adding client to the list? {}", dance_complete); - let mut cs = conns_.lock().unwrap(); - cs.client_action(&cid, dance_complete); - log::debug!("List: {:?}, action: {}", cs, dance_complete); - drop(cs); + if dance_complete { + log::info!( + "adding client to the list: {}, type: {:?}", + &cid, + signer_type + ); + let mut cs = conns_.lock().unwrap(); + cs.add_client(&cid, signer_type); + drop(cs); + } else { + log::warn!("dance failed, client not added to the list"); + } } } }); diff --git a/broker/src/mqtt.rs b/broker/src/mqtt.rs index b188413..e758a93 100644 --- a/broker/src/mqtt.rs +++ b/broker/src/mqtt.rs @@ -4,7 +4,7 @@ use crate::util::Settings; use rocket::tokio::{sync::broadcast, sync::mpsc}; use rumqttd::{local::LinkTx, AuthMsg, Broker, Config, Notification}; use sphinx_signer::sphinx_glyph::sphinx_auther::token::Token; -use sphinx_signer::sphinx_glyph::topics; +use sphinx_signer::sphinx_glyph::{topics, types::SignerType}; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -15,7 +15,7 @@ pub fn start_broker( settings: Settings, mut receiver: mpsc::Receiver, mut init_receiver: mpsc::Receiver, - status_sender: std::sync::mpsc::Sender<(String, bool)>, + status_sender: std::sync::mpsc::Sender<(String, bool, Option)>, error_sender: broadcast::Sender>, auth_sender: std::sync::mpsc::Sender, connections: Arc>, @@ -39,18 +39,19 @@ pub fn start_broker( }); // connected/disconnected status alerts - let (internal_status_tx, internal_status_rx) = std::sync::mpsc::channel::<(bool, String)>(); + let (internal_status_tx, internal_status_rx) = + std::sync::mpsc::channel::<(bool, String, Option)>(); // track connections let link_tx_ = link_tx.clone(); let _conns_task = std::thread::spawn(move || { - while let Ok((is, cid)) = internal_status_rx.recv() { + while let Ok((is, cid, signer_type)) = internal_status_rx.recv() { if is { subs(&cid, link_tx_.clone()); } else { unsubs(&cid, link_tx_.clone()); } - let _ = status_sender.send((cid, is)); + let _ = status_sender.send((cid, is, signer_type)); } }); @@ -112,9 +113,25 @@ pub fn start_broker( let topic_end = ts[1].to_string(); if topic.ends_with(topics::HELLO) { - let _ = internal_status_tx.send((true, cid)); + let signer_type = match f.publish.payload.get(0) { + Some(byte) => match SignerType::from_byte(*byte) { + Ok(signer_type) => signer_type, + Err(e) => { + log::warn!("Could not deserialize signer type: {}", e); + continue; + } + }, + // This is the ReceiveSend signer type + None => SignerType::default(), + }; + log::debug!( + "caught hello message for id: {}, type: {:?}", + cid, + signer_type + ); + let _ = internal_status_tx.send((true, cid, Some(signer_type))); } else if topic.ends_with(topics::BYE) { - let _ = internal_status_tx.send((false, cid)); + let _ = internal_status_tx.send((false, cid, None)); } else { // VLS, CONTROL, LSS let pld = f.publish.payload.to_vec(); @@ -174,10 +191,25 @@ fn pub_and_wait( } else { let current = current.unwrap(); // Try the current connection - let mut rep = pub_timeout(¤t, &msg.topic, &msg.message, &msg_rx, link_tx); + // This returns None if 1) signer_type is set, and not equal to the current signer + // 2) If pub_timeout times out + let mut rep = if client_list.get(¤t).unwrap() + == msg + .signer_type + .as_ref() + .unwrap_or(client_list.get(¤t).unwrap()) + { + pub_timeout(¤t, &msg.topic, &msg.message, &msg_rx, link_tx) + } else { + None + }; + // If that failed, try looking for some other signer if rep.is_none() { - for cid in client_list.into_keys().filter(|k| k != ¤t) { + // If signer_type is set, we also filter for only these types + for (cid, _) in client_list.into_iter().filter(|(k, v)| { + k != ¤t && v == msg.signer_type.as_ref().unwrap_or(v) + }) { rep = pub_timeout(&cid, &msg.topic, &msg.message, &msg_rx, link_tx); if rep.is_some() { let mut cs = conns_.lock().unwrap(); @@ -199,6 +231,7 @@ fn pub_and_wait( break; } else { log::debug!("couldn't reach any clients..."); + std::thread::sleep(Duration::from_secs(1)); } if let Some(max) = retries { log::debug!("counter: {}, retries: {}", counter, max);