Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 169 additions & 27 deletions crates/openshell-sandbox/src/ssh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,26 @@ fn hmac_sha256(key: &[u8], data: &[u8]) -> String {
hex::encode(result)
}

/// Per-channel state for tracking PTY resources and I/O senders.
///
/// Each SSH channel gets its own PTY master (if a PTY was requested) and input
/// sender. This allows `window_change_request` to resize the correct PTY when
/// multiple channels are open simultaneously (e.g. parallel shells, shell +
/// sftp, etc.).
struct ChannelState {
input_sender: Option<mpsc::Sender<Vec<u8>>>,
pty_master: Option<std::fs::File>,
pty_request: Option<PtyRequest>,
}

struct SshHandler {
policy: SandboxPolicy,
workdir: Option<String>,
netns_fd: Option<RawFd>,
proxy_url: Option<String>,
ca_file_paths: Option<Arc<(PathBuf, PathBuf)>>,
provider_env: HashMap<String, String>,
input_sender: Option<mpsc::Sender<Vec<u8>>>,
pty_master: Option<std::fs::File>,
pty_request: Option<PtyRequest>,
channels: HashMap<ChannelId, ChannelState>,
}

impl SshHandler {
Expand All @@ -291,9 +301,7 @@ impl SshHandler {
proxy_url,
ca_file_paths,
provider_env,
input_sender: None,
pty_master: None,
pty_request: None,
channels: HashMap::new(),
}
}
}
Expand Down Expand Up @@ -388,7 +396,12 @@ impl russh::server::Handler for SshHandler {
_modes: &[(russh::Pty, u32)],
session: &mut Session,
) -> Result<(), Self::Error> {
self.pty_request = Some(PtyRequest {
let state = self.channels.entry(channel).or_insert_with(|| ChannelState {
input_sender: None,
pty_master: None,
pty_request: None,
});
state.pty_request = Some(PtyRequest {
term: term.to_string(),
col_width,
row_height,
Expand All @@ -401,21 +414,25 @@ impl russh::server::Handler for SshHandler {

async fn window_change_request(
&mut self,
_channel: ChannelId,
channel: ChannelId,
col_width: u32,
row_height: u32,
pixel_width: u32,
pixel_height: u32,
_session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(master) = self.pty_master.as_ref() {
let winsize = Winsize {
ws_row: to_u16(row_height.max(1)),
ws_col: to_u16(col_width.max(1)),
ws_xpixel: to_u16(pixel_width),
ws_ypixel: to_u16(pixel_height),
};
let _ = unsafe_pty::set_winsize(master.as_raw_fd(), winsize);
if let Some(state) = self.channels.get(&channel) {
if let Some(master) = state.pty_master.as_ref() {
let winsize = Winsize {
ws_row: to_u16(row_height.max(1)),
ws_col: to_u16(col_width.max(1)),
ws_xpixel: to_u16(pixel_width),
ws_ypixel: to_u16(pixel_height),
};
if let Err(e) = unsafe_pty::set_winsize(master.as_raw_fd(), winsize) {
warn!("failed to resize PTY for channel {channel:?}: {e}");
}
}
}
Ok(())
}
Expand Down Expand Up @@ -474,7 +491,12 @@ impl russh::server::Handler for SshHandler {
self.ca_file_paths.clone(),
&self.provider_env,
)?;
self.input_sender = Some(input_sender);
let state = self.channels.entry(channel).or_insert_with(|| ChannelState {
input_sender: None,
pty_master: None,
pty_request: None,
});
state.input_sender = Some(input_sender);
} else {
warn!(subsystem = name, "unsupported subsystem requested");
session.channel_failure(channel)?;
Expand All @@ -499,26 +521,30 @@ impl russh::server::Handler for SshHandler {

async fn data(
&mut self,
_channel: ChannelId,
channel: ChannelId,
data: &[u8],
_session: &mut Session,
) -> Result<(), Self::Error> {
if let Some(sender) = self.input_sender.as_ref() {
let _ = sender.send(data.to_vec());
if let Some(state) = self.channels.get(&channel) {
if let Some(sender) = state.input_sender.as_ref() {
let _ = sender.send(data.to_vec());
}
}
Ok(())
}

async fn channel_eof(
&mut self,
_channel: ChannelId,
channel: ChannelId,
_session: &mut Session,
) -> Result<(), Self::Error> {
// Drop the input sender so the stdin writer thread sees a
// disconnected channel and closes the child's stdin pipe. This
// is essential for commands like `cat | tar xf -` which need
// stdin EOF to know the input stream is complete.
self.input_sender.take();
if let Some(state) = self.channels.get_mut(&channel) {
state.input_sender.take();
}
Ok(())
}
}
Expand All @@ -530,7 +556,12 @@ impl SshHandler {
handle: Handle,
command: Option<String>,
) -> anyhow::Result<()> {
if let Some(pty) = self.pty_request.take() {
let state = self.channels.entry(channel).or_insert_with(|| ChannelState {
input_sender: None,
pty_master: None,
pty_request: None,
});
if let Some(pty) = state.pty_request.take() {
// PTY was requested — allocate a real PTY (interactive shell or
// exec that explicitly asked for a terminal).
let (pty_master, input_sender) = spawn_pty_shell(
Expand All @@ -545,8 +576,8 @@ impl SshHandler {
self.ca_file_paths.clone(),
&self.provider_env,
)?;
self.pty_master = Some(pty_master);
self.input_sender = Some(input_sender);
state.pty_master = Some(pty_master);
state.input_sender = Some(input_sender);
} else {
// No PTY requested — use plain pipes so stdout/stderr are
// separate and output has clean LF line endings. This is the
Expand All @@ -562,7 +593,7 @@ impl SshHandler {
self.ca_file_paths.clone(),
&self.provider_env,
)?;
self.input_sender = Some(input_sender);
state.input_sender = Some(input_sender);
}
Ok(())
}
Expand Down Expand Up @@ -999,7 +1030,7 @@ mod unsafe_pty {

#[allow(unsafe_code)]
pub fn set_winsize(fd: RawFd, winsize: Winsize) -> std::io::Result<()> {
let rc = unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, winsize) };
let rc = unsafe { libc::ioctl(fd, libc::TIOCSWINSZ, &winsize) };
if rc != 0 {
return Err(std::io::Error::last_os_error());
}
Expand Down Expand Up @@ -1404,4 +1435,115 @@ mod tests {
assert!(!is_loopback_host("not-an-ip"));
assert!(!is_loopback_host("[]"));
}

// -----------------------------------------------------------------------
// Per-channel PTY state tests (#543)
// -----------------------------------------------------------------------

#[test]
fn set_winsize_applies_to_correct_pty() {
// Verify that set_winsize applies to a specific PTY master FD,
// which is the mechanism that per-channel tracking relies on.
// With the old single-pty_master design, a window_change_request
// for channel N would resize whatever PTY was stored last —
// potentially belonging to a different channel.
let pty_a = openpty(None, None).expect("openpty a");
let pty_b = openpty(None, None).expect("openpty b");
let master_a = std::fs::File::from(pty_a.master);
let master_b = std::fs::File::from(pty_b.master);
let fd_a = master_a.as_raw_fd();
let fd_b = master_b.as_raw_fd();
assert_ne!(fd_a, fd_b, "two PTYs must have distinct FDs");

// Close the slave ends to avoid leaking FDs in the test.
drop(std::fs::File::from(pty_a.slave));
drop(std::fs::File::from(pty_b.slave));

// Resize only PTY B.
let winsize_b = Winsize {
ws_row: 50,
ws_col: 120,
ws_xpixel: 0,
ws_ypixel: 0,
};
unsafe_pty::set_winsize(fd_b, winsize_b).expect("set_winsize on PTY B");

// Resize PTY A to a different size.
let winsize_a = Winsize {
ws_row: 24,
ws_col: 80,
ws_xpixel: 0,
ws_ypixel: 0,
};
unsafe_pty::set_winsize(fd_a, winsize_a).expect("set_winsize on PTY A");

// Read back sizes via ioctl to verify independence.
let mut actual_a: libc::winsize = unsafe { std::mem::zeroed() };
let mut actual_b: libc::winsize = unsafe { std::mem::zeroed() };
#[allow(unsafe_code)]
unsafe {
libc::ioctl(fd_a, libc::TIOCGWINSZ, &mut actual_a);
libc::ioctl(fd_b, libc::TIOCGWINSZ, &mut actual_b);
}

assert_eq!(actual_a.ws_row, 24, "PTY A should be 24 rows");
assert_eq!(actual_a.ws_col, 80, "PTY A should be 80 cols");
assert_eq!(actual_b.ws_row, 50, "PTY B should be 50 rows");
assert_eq!(actual_b.ws_col, 120, "PTY B should be 120 cols");
}

#[test]
fn channel_state_independent_input_senders() {
// Verify that each channel gets its own input sender so that
// data() and channel_eof() affect only the targeted channel.
// Uses the ChannelState struct directly without needing ChannelId
// constructors.
let (tx_a, rx_a) = mpsc::channel::<Vec<u8>>();
let (tx_b, rx_b) = mpsc::channel::<Vec<u8>>();

let mut state_a = ChannelState {
input_sender: Some(tx_a),
pty_master: None,
pty_request: None,
};
let state_b = ChannelState {
input_sender: Some(tx_b),
pty_master: None,
pty_request: None,
};

// Send data to channel A only.
state_a
.input_sender
.as_ref()
.unwrap()
.send(b"hello-a".to_vec())
.unwrap();
// Send data to channel B only.
state_b
.input_sender
.as_ref()
.unwrap()
.send(b"hello-b".to_vec())
.unwrap();

assert_eq!(rx_a.recv().unwrap(), b"hello-a");
assert_eq!(rx_b.recv().unwrap(), b"hello-b");

// EOF on channel A (drop sender) should not affect channel B.
state_a.input_sender.take();
assert!(
rx_a.recv().is_err(),
"channel A sender dropped, recv should fail"
);

// Channel B should still be functional.
state_b
.input_sender
.as_ref()
.unwrap()
.send(b"still-alive".to_vec())
.unwrap();
assert_eq!(rx_b.recv().unwrap(), b"still-alive");
}
}
Loading