diff --git a/crates/libmlx/src/firmware/source.rs b/crates/libmlx/src/firmware/source.rs index b9dc5b2305..9a664932ca 100644 --- a/crates/libmlx/src/firmware/source.rs +++ b/crates/libmlx/src/firmware/source.rs @@ -21,7 +21,7 @@ use std::path::{Path, PathBuf}; -use forge_ssh::ssh_client::{AuthConfig, SshClientConfig}; +use forge_ssh::ssh_client::{AuthConfig, HostKeyVerification, SshClientConfig}; use tokio::io::AsyncWriteExt; use tracing; @@ -287,12 +287,17 @@ async fn resolve_ssh( let auth = credentials.clone().map(AuthConfig::try_from).transpose()?; + let host_key_verification = match known_hosts_file { + Some(path) => HostKeyVerification::KnownHostsFile(path.to_path_buf()), + None => HostKeyVerification::DefaultKnownHostsFile, + }; + let client = SshClientConfig { host, port: *port, username, auth: auth.as_ref(), - known_hosts_file, + host_key_verification, } .make_authenticated_client() .await?; diff --git a/crates/ssh/src/ssh.rs b/crates/ssh/src/ssh.rs index 629702c6b6..e8a38f0c42 100644 --- a/crates/ssh/src/ssh.rs +++ b/crates/ssh/src/ssh.rs @@ -19,7 +19,7 @@ use std::net::SocketAddr; use std::time::Duration; use crate::SshError; -use crate::ssh_client::AuthConfig; +use crate::ssh_client::{AuthConfig, HostKeyVerification}; async fn execute_command( command: &str, @@ -34,7 +34,7 @@ async fn execute_command( port: 22, username: &username, auth: Some(&auth), - known_hosts_file: None, + host_key_verification: HostKeyVerification::Insecure, } .make_authenticated_client() .await?; diff --git a/crates/ssh/src/ssh_client.rs b/crates/ssh/src/ssh_client.rs index c3b5dcee8a..a358d81956 100644 --- a/crates/ssh/src/ssh_client.rs +++ b/crates/ssh/src/ssh_client.rs @@ -52,18 +52,28 @@ pub struct SshClientConfig<'a> { pub port: u16, pub username: &'a str, pub auth: Option<&'a AuthConfig>, - pub known_hosts_file: Option<&'a Path>, + pub host_key_verification: HostKeyVerification, +} + +/// How to verify the SSH server's host key. +pub enum HostKeyVerification { + /// Skip host key verification. + Insecure, + /// Verify against the system default known_hosts file (`~/.ssh/known_hosts`). + DefaultKnownHostsFile, + /// Verify against an explicit known_hosts file. + KnownHostsFile(PathBuf), } impl<'a> SshClientConfig<'a> { - pub async fn make_authenticated_client(&'a self) -> SshResult { + pub async fn make_authenticated_client(self) -> SshResult { let mut client = russh::client::connect( russh_client_config(), (self.host, self.port), KnownHostsCheck { host: self.host.to_string(), port: self.port, - known_hosts_file: self.known_hosts_file.map(|p| p.to_path_buf()), + check: self.host_key_verification, }, ) .await?; @@ -165,7 +175,7 @@ fn russh_client_config() -> Arc { struct KnownHostsCheck { host: String, port: u16, - known_hosts_file: Option, + check: HostKeyVerification, } impl russh::client::Handler for KnownHostsCheck { @@ -175,18 +185,17 @@ impl russh::client::Handler for KnownHostsCheck { &mut self, server_public_key: &PublicKey, ) -> Result { - if let Some(path) = self.known_hosts_file.as_ref() { - return russh::keys::check_known_hosts_path( - &self.host, - self.port, - server_public_key, - path, - ) - .map_err(russh::Error::from); + match &self.check { + HostKeyVerification::Insecure => Ok(true), + HostKeyVerification::KnownHostsFile(path) => { + russh::keys::check_known_hosts_path(&self.host, self.port, server_public_key, path) + .map_err(russh::Error::from) + } + HostKeyVerification::DefaultKnownHostsFile => { + russh::keys::check_known_hosts(&self.host, self.port, server_public_key) + .map_err(russh::Error::from) + } } - - russh::keys::check_known_hosts(&self.host, self.port, server_public_key) - .map_err(russh::Error::from) } }