diff --git a/.local.env b/.local.env index db32f308..e11552ac 100644 --- a/.local.env +++ b/.local.env @@ -1,4 +1,4 @@ -SENTRIUS_VERSION=1.1.371 +SENTRIUS_VERSION=1.1.375 SENTRIUS_SSH_VERSION=1.1.41 SENTRIUS_KEYCLOAK_VERSION=1.1.53 SENTRIUS_AGENT_VERSION=1.1.42 @@ -6,4 +6,4 @@ SENTRIUS_AI_AGENT_VERSION=1.1.264 LLMPROXY_VERSION=1.0.78 LAUNCHER_VERSION=1.0.82 AGENTPROXY_VERSION=1.0.85 -SSHPROXY_VERSION=1.0.40 \ No newline at end of file +SSHPROXY_VERSION=1.0.87 \ No newline at end of file diff --git a/.local.env.bak b/.local.env.bak index db32f308..e11552ac 100644 --- a/.local.env.bak +++ b/.local.env.bak @@ -1,4 +1,4 @@ -SENTRIUS_VERSION=1.1.371 +SENTRIUS_VERSION=1.1.375 SENTRIUS_SSH_VERSION=1.1.41 SENTRIUS_KEYCLOAK_VERSION=1.1.53 SENTRIUS_AGENT_VERSION=1.1.42 @@ -6,4 +6,4 @@ SENTRIUS_AI_AGENT_VERSION=1.1.264 LLMPROXY_VERSION=1.0.78 LAUNCHER_VERSION=1.0.82 AGENTPROXY_VERSION=1.0.85 -SSHPROXY_VERSION=1.0.40 \ No newline at end of file +SSHPROXY_VERSION=1.0.87 \ No newline at end of file diff --git a/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java b/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java index 26721f4e..9bbb337d 100644 --- a/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java +++ b/api/src/main/java/io/sentrius/sso/controllers/api/SystemApiController.java @@ -128,6 +128,9 @@ public String setOption(HttpServletRequest request, HttpServletResponse response case "java.lang.Float": results.add(systemOptions.setValue(option.getName(), Float.valueOf(entry.getValue()[0]), false)); break; + case "java.lang.Double": + results.add(systemOptions.setValue(option.getName(), Double.valueOf(entry.getValue()[0]), false)); + break; default: log.error("Unsupported type: {}", option.getClosestType()); } diff --git a/dataplane/src/main/java/io/sentrius/sso/core/model/sessions/SessionOutput.java b/dataplane/src/main/java/io/sentrius/sso/core/model/sessions/SessionOutput.java index 49325d78..bfd74cd0 100644 --- a/dataplane/src/main/java/io/sentrius/sso/core/model/sessions/SessionOutput.java +++ b/dataplane/src/main/java/io/sentrius/sso/core/model/sessions/SessionOutput.java @@ -248,11 +248,13 @@ public AuditOutput waitForOutput(Long time, } if (!persistentMessage.isEmpty()){ + log.info("Persistent Message: {}", persistentMessage); var trigger = persistentMessage.pop(); triggers.add( getTrigger(trigger)); } if (!prompt.isEmpty()){ + log.info("Prompt: {}", prompt); var trigger = prompt.pop(); triggers.add( getTrigger(trigger)); } diff --git a/dataplane/src/main/java/io/sentrius/sso/core/services/SshListenerService.java b/dataplane/src/main/java/io/sentrius/sso/core/services/SshListenerService.java index d3170d55..9a09b003 100644 --- a/dataplane/src/main/java/io/sentrius/sso/core/services/SshListenerService.java +++ b/dataplane/src/main/java/io/sentrius/sso/core/services/SshListenerService.java @@ -169,6 +169,7 @@ public void sendToTerminalSession(String terminalSessionId, ConnectedSystem conn public void processTerminalMessage( ConnectedSystem terminalSessionId, Session.TerminalMessage terminalMessage) { + log.info("process terminal messsage"); if (!terminalSessionId.getSession().getClosed() && terminalMessage.getType() != Session.MessageType.HEARTBEAT) { try { diff --git a/llm-dataplane/src/main/java/io/sentrius/sso/automation/auditing/rules/TwoPartyAIMonitor.java b/llm-dataplane/src/main/java/io/sentrius/sso/automation/auditing/rules/TwoPartyAIMonitor.java index 7ce681c7..9478f2cf 100644 --- a/llm-dataplane/src/main/java/io/sentrius/sso/automation/auditing/rules/TwoPartyAIMonitor.java +++ b/llm-dataplane/src/main/java/io/sentrius/sso/automation/auditing/rules/TwoPartyAIMonitor.java @@ -120,6 +120,7 @@ public Optional trigger(String cmd) { } if (llmResponse.get() != null) { + log.info("OpenAI analysis completed. Malicious: {}, response: {}, question: {}", flaggedAsMalicious, llmResponse.get(), llmQuestion.get()); Trigger trg = llmQuestion.get() != null ? new Trigger(TriggerAction.PROMPT_ACTION, llmResponse.get(), llmQuestion.get()) : new Trigger(TriggerAction.PERSISTENT_MESSAGE, llmResponse.get()); @@ -183,6 +184,7 @@ public Optional onMessage(Session.TerminalMessage text) { analysis.get(); if (llmResponse.get() != null && llmQuestion.get() != null) { + log.info("OpenAI analysis completed. Malicious: {}, response: {}, question: {}", flaggedAsMalicious, llmResponse.get(), llmQuestion.get()); Trigger trg = llmQuestion.get() != null && enableLLMQuestions ? new Trigger(TriggerAction.PROMPT_ACTION, llmResponse.get(), llmQuestion.get()) : @@ -197,11 +199,14 @@ public Optional onMessage(Session.TerminalMessage text) { } } if ((connectedSystem.getWebsocketListenerSessionId() == null || connectedSystem.getWebsocketListenerSessionId().isEmpty() ) && flaggedAsMalicious) { + log.info("Flagged as malicious but no websocket session ID available. Returning JIT action."); if (llmQuestion.get()!= null){ + log.info("Flagged as malicious but no websocket session ID available. Returning prompt action."); Trigger trg = new Trigger(TriggerAction.PROMPT_ACTION, DESCRIPTION); return Optional.of(trg); } else { + log.info("Flagged as malicious but no websocket session ID available. Returning JIT action."); Trigger trg = new Trigger(TriggerAction.JIT_ACTION, DESCRIPTION); return Optional.of(trg); } diff --git a/ssh-proxy/pom.xml b/ssh-proxy/pom.xml index 79f81382..bf8e52a1 100644 --- a/ssh-proxy/pom.xml +++ b/ssh-proxy/pom.xml @@ -34,7 +34,11 @@ sentrius-dataplane 1.0.0-SNAPSHOT - + + io.sentrius + llm-dataplane + 1.0.0-SNAPSHOT + io.kubernetes diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/config/TaskConfig.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/config/TaskConfig.java index f9ea4aed..2011615f 100644 --- a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/config/TaskConfig.java +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/config/TaskConfig.java @@ -1,8 +1,11 @@ package io.sentrius.sso.sshproxy.config; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadPoolExecutor; import io.sentrius.sso.core.services.TerminalService; import jakarta.annotation.PreDestroy; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -13,31 +16,37 @@ @Slf4j @Configuration @EnableAsync +@RequiredArgsConstructor public class TaskConfig { + private final TerminalService terminalService; + + // Keep a reference so we can shut it down explicitly on destroy, if desired. private ThreadPoolTaskExecutor executor; @Bean(name = "taskExecutor") - public Executor taskExecutor() { - ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); - executor.setCorePoolSize(15); - executor.setMaxPoolSize(20); - executor.setQueueCapacity(100); - executor.setThreadNamePrefix("SentriusTask-"); - executor.initialize(); - return executor; + public ThreadPoolTaskExecutor taskExecutor() { + ThreadPoolTaskExecutor exec = new ThreadPoolTaskExecutor(); + exec.setCorePoolSize(15); + exec.setMaxPoolSize(20); + exec.setQueueCapacity(100); + exec.setThreadNamePrefix("ProxySession-"); + exec.setWaitForTasksToCompleteOnShutdown(true); + exec.setAwaitTerminationSeconds(30); + exec.initialize(); + + this.executor = exec; // assign the field, not a shadowed local + return exec; // expose as Executor for @Async } @PreDestroy public void shutdownExecutor() { if (executor != null) { + log.info("Shutting down task executor"); executor.shutdown(); } - log.info("Shutting down executor"); - // Call shutdown on SshListenerService to close streams + // If you truly want this on application shutdown: + log.info("Shutting down TerminalService"); terminalService.shutdown(); } - - @Autowired - private TerminalService terminalService; -} +} \ No newline at end of file diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ResponseServiceSession.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ResponseServiceSession.java index bfaefe57..ff2da5d3 100644 --- a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ResponseServiceSession.java +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ResponseServiceSession.java @@ -22,6 +22,7 @@ public class ResponseServiceSession implements DataSession { private final InputStream in; private final OutputStream out; private final BaseAccessTokenAuditor auditor; + private String persistentMessage = ""; private ConnectedSystem connectedSystem; @@ -83,7 +84,13 @@ public void sendMessage(WebSocketMessage message) throws IOException { break; case PERSISTENT_MESSAGE: - msg = formatPersistentMessage(trigger, auditLog); + if (!persistentMessage.equals(trigger.getDescription())) { + log.info(ANSI_BOLD + "Persistent message: " + ANSI_RESET + trigger.getDescription()); + msg = formatPersistentMessage(trigger, auditLog); + } + else { + log.info(ANSI_BOLD + "Persistent message: samesies" + ANSI_RESET); + } break; case APPROVE_ACTION: msg = formatApproveMessage(trigger, auditLog); @@ -106,6 +113,7 @@ public void sendMessage(WebSocketMessage message) throws IOException { + } } @@ -163,11 +171,15 @@ private String formatRecordMessage(Session.Trigger trigger, Session.TerminalMess } private String formatPersistentMessage(Session.Trigger trigger, Session.TerminalMessage auditLog) { + if (trigger.getDescription() == null || trigger.getDescription().isEmpty()) { + return ""; + } StringBuilder sb = new StringBuilder(); sb.append("\r\n"); - sb.append(ANSI_BLUE).append(ANSI_BOLD).append("💬 MESSAGE").append(ANSI_RESET).append("\r\n"); + sb.append(ANSI_BLUE).append(ANSI_BOLD).append("💬 AI Monitor").append(ANSI_RESET).append("\r\n"); sb.append(ANSI_BLUE).append(trigger.getDescription()).append(ANSI_RESET).append("\r\n"); sb.append("\r\n"); + persistentMessage = trigger.getDescription(); return sb.toString(); } diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ShellHandlerRunnable.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ShellHandlerRunnable.java new file mode 100644 index 00000000..145ff7cc --- /dev/null +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/ShellHandlerRunnable.java @@ -0,0 +1,339 @@ +package io.sentrius.sso.sshproxy.handler; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; +import java.security.GeneralSecurityException; +import java.sql.SQLException; +import java.util.concurrent.atomic.AtomicReference; +import io.sentrius.sso.automation.auditing.Trigger; +import io.sentrius.sso.automation.auditing.TriggerAction; +import io.sentrius.sso.core.model.ConnectedSystem; +import io.sentrius.sso.core.model.HostSystem; +import io.sentrius.sso.core.services.SshListenerService; +import io.sentrius.sso.core.services.terminal.SessionTrackingService; +import io.sentrius.sso.protobuf.Session; +import io.sentrius.sso.sshproxy.service.HostSystemSelectionService; +import io.sentrius.sso.sshproxy.service.InlineTerminalResponseService; +import io.sentrius.sso.sshproxy.streams.SessionRoute; +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.apache.sshd.server.ExitCallback; +import org.apache.sshd.server.session.ServerSession; + + +@Getter +@Setter +@Builder +@Slf4j +public class ShellHandlerRunnable implements Runnable { + + + protected volatile boolean running = true; + private final SshListenerService sshListenerService; + private final SessionTrackingService sessionTrackingService; + private InputStream in; + private final ExitCallback callback; + private final SessionRoute sessionRoute; + + private final HostSystemSelectionService hostSystemSelectionService; + private final InlineTerminalResponseService terminalResponseService; + + + private HostSystem selectedHostSystem; + private ServerSession session; + + @Builder.Default + private final AtomicReference commandBuffer = new AtomicReference<>(new StringBuilder()); + + + + @Override + public void run() { + log.info("Starting ShellHandlerRunnable for user: {}", session.getUsername()); + try { + byte[] buffer = new byte[1024]; + var auditLog = + Session.TerminalMessage.newBuilder(); + commandBuffer.set(new StringBuilder()); + while (running) { + + int bytesRead = in.read(buffer); + if (bytesRead == -1) { + // EOF reached + break; + } + + if (bytesRead > 0) { + log.info("Read {} bytes from SSH input stream", bytesRead); + } + + for (int i = 0; i < bytesRead; i++) { + byte b = buffer[i]; + char c = (char) b; + + + // Process input character and send audit log + if (c >= 32 && c <= 126) { + log.trace("Processing printable character: {}", c); + // Printable characters + auditLog.setCommand(String.valueOf(c)); + commandBuffer.get().append(c); + auditLog.setType(Session.MessageType.USER_DATA); + auditLog.setKeycode(-1); + getSshListenerService().processTerminalMessage( + sessionRoute.getCurrent().get(), + auditLog.build() + ); + log.info("Appending printable character to command buffer: {}", c); + auditLog = Session.TerminalMessage.newBuilder(); + } else { + // Control characters and special keys + if (handleBuiltinCommand(commandBuffer.toString())) { + log.info("Handled built-in command: {}", commandBuffer); + commandBuffer.set(new StringBuilder()); + auditLog.setKeycode(c); + + + boolean allNoAction = true; + auditLog.setType(Session.MessageType.USER_DATA); + + var auditLogSend = auditLog.build(); + for (var action : sessionRoute.getCurrent().get().getSessionStartupActions()) { + var trigger = action.onMessage(auditLogSend); + if (trigger.get().getAction() == TriggerAction.JIT_ACTION) { + allNoAction = false; + // drop the message + sessionRoute.getCurrent().get().getTerminalAuditor().setSessionTrigger(trigger.get()); + log.debug("**** Setting JIT Trigger: {}", trigger.get()); + sessionTrackingService.addSystemTrigger(sessionRoute.getCurrent().get(), trigger.get()); + return; + } else if (trigger.get().getAction() == TriggerAction.WARN_ACTION) { + allNoAction = false; + // send the message + log.debug("**** Setting WARN Trigger: {}", trigger.get()); + sessionRoute.getCurrent().get().getTerminalAuditor().setSessionTrigger(trigger.get()); + sessionTrackingService.addSystemTrigger(sessionRoute.getCurrent().get(), trigger.get()); + } else if (trigger.get().getAction() == TriggerAction.PROMPT_ACTION) { + sessionTrackingService.addTrigger(sessionRoute.getCurrent().get(), trigger.get()); + return; + } + } + if (allNoAction && sessionRoute.getCurrent().get().getSessionStartupActions().size() > 0) { + log.debug("**** Setting NO_ACTION Trigger"); + var noActionTrigger = new Trigger(TriggerAction.NO_ACTION, ""); + sessionTrackingService.addSystemTrigger(sessionRoute.getCurrent().get(), noActionTrigger); + sessionRoute.getCurrent().get().getTerminalAuditor().setSessionTrigger(noActionTrigger); + } + + log.debug("Sending terminal keycode to session"); + + getSshListenerService().processTerminalMessage( + sessionRoute.getCurrent().get(), + auditLogSend + ); + auditLog = Session.TerminalMessage.newBuilder(); + } else { + + log.info("Appending control character to command buffer: {}", (int) c); + // Forward command to target SSH server + sessionRoute.getCurrent().get().getCommander().write(SshListenerService.keyMap.get(3)); + sessionRoute.getCurrent().get().getTerminalAuditor().clear(0); // clear in case + } + + } + } + } + }catch (Exception e) { + log.error("error",e); + e.printStackTrace(); + throw new RuntimeException(e); + } finally { + + try { + in.close(); + } catch (Exception e) { + log.error("Error closing input stream: {}", e.getMessage()); + } + } + } + + private boolean handleBuiltinCommand(String command) + throws IOException, SQLException, GeneralSecurityException, ClassNotFoundException, InvocationTargetException, + NoSuchMethodException, InstantiationException, IllegalAccessException { + String cmd = command.toLowerCase().trim(); + String[] parts = command.trim().split("\\s+"); + + log.info("Processing built-in command: '{}'", cmd); + switch (cmd) { + case "exit": + case "quit": + terminalResponseService.sendMessage("Goodbye!\r\n", sessionRoute.getOut()); + running = false; + callback.onExit(0); + return true; + + case "help": + showHelp(); + return true; + + case "status": + showStatus(); + return false; + + case "hosts": + showAvailableHosts(); + return false; + + default: + if (parts.length >= 2 && "connect".equals(parts[0].toLowerCase())) { + log.info("Handling connect command to switch target host"); + return handleConnectCommand(parts); + } + log.info("Unknown command '{}'", cmd); + return true; + } + } + + private void showHelp() throws IOException { + String help = "\r\n" + + "Sentrius SSH Proxy - Built-in Commands:\r\n" + + " help - Show this help message\r\n" + + " status - Show session status\r\n" + + " hosts - List available target hosts\r\n" + + " connect - Connect to HostSystem by ID\r\n" + + " connect - Connect to HostSystem by display name\r\n" + + " exit - Close SSH session\r\n" + + "\r\n" + + "All other commands are forwarded to the target SSH server\r\n" + + "and subject to Sentrius security policies.\r\n\r\n"; + + terminalResponseService.sendMessage(help, sessionRoute.getOut()); + } + + private void showStatus() throws IOException { + String hostInfo = selectedHostSystem != null + ? String.format( + "%s (%s:%d)", selectedHostSystem.getDisplayName(), + selectedHostSystem.getHost(), selectedHostSystem.getPort() + ) + : "No target host configured"; + + String status = String.format( + "\r\n" + + "Sentrius SSH Proxy Status:\r\n" + + " User: %s\r\n" + + " Target Host: %s\r\n" + + " Session Active: %s\r\n" + + " Safeguards: ENABLED\r\n\r\n", + session.getUsername(), + hostInfo, + running ? "YES" : "NO" + ); + + terminalResponseService.sendMessage(status, sessionRoute.getOut()); + } + + private void showAvailableHosts() throws IOException { + var hostSystems = hostSystemSelectionService.getAllHostSystems(); + + StringBuilder hostList = new StringBuilder("\r\nAvailable HostSystems:\r\n"); + hostList.append("ID\tName\t\t\tHost:Port\t\tStatus\r\n"); + hostList.append("────────────────────────────────────────────────────────────\r\n"); + + if (hostSystems.isEmpty()) { + hostList.append("No HostSystems configured in database.\r\n"); + } else { + for (HostSystem hs : hostSystems) { + String name = hs.getDisplayName() != null ? hs.getDisplayName() : "N/A"; + String hostPort = String.format("%s:%d", hs.getHost(), hs.getPort()); + String status = + hostSystemSelectionService.isHostSystemValid(hs) ? "Valid" : "Invalid"; + String current = + (selectedHostSystem != null && selectedHostSystem.getId().equals(hs.getId())) ? " *" : ""; + + hostList.append(String.format( + "%d\t%-15s\t%-15s\t%s%s\r\n", + hs.getId(), name, hostPort, status, current + )); + } + hostList.append("\r\n* = Current selection\r\n"); + } + hostList.append("\r\n"); + + terminalResponseService.sendMessage(hostList.toString(), sessionRoute.getOut()); + } + + private boolean handleConnectCommand(String[] parts) + throws IOException, SQLException, GeneralSecurityException, ClassNotFoundException, InvocationTargetException, + NoSuchMethodException, InstantiationException, IllegalAccessException { + if (parts.length < 2) { + terminalResponseService.sendMessage("Usage: connect \r\n", sessionRoute.getOut()); + return true; + } + + String target = parts[1]; + HostSystem targetHost = null; + + // Try to parse as ID first + try { + Long id = Long.parseLong(target); + targetHost = hostSystemSelectionService.getHostSystemById(id).orElse(null); + + + } catch (NumberFormatException e) { + // Not a number, try by display name + var hostsByName = hostSystemSelectionService.getHostSystemsByDisplayName(target); + if (!hostsByName.isEmpty()) { + targetHost = hostsByName.get(0); + if (hostsByName.size() > 1) { + terminalResponseService.sendMessage( + String.format("Warning: Multiple hosts found with name '%s', using first one.\r\n", target), + sessionRoute.getOut() + ); + } + } + } + + if (targetHost == null) { + terminalResponseService.sendMessage( + String.format("Error: HostSystem '%s' not found.\r\n", target), sessionRoute.getOut()); + return true; + } + + if (!hostSystemSelectionService.isHostSystemValid(targetHost)) { + terminalResponseService.sendMessage( + String.format("Error: HostSystem '%s' is not properly configured.\r\n", target), sessionRoute.getOut()); + return true; + } + + selectedHostSystem = targetHost; + + var connectedSystem = sessionRoute.connect(sessionRoute.getCurrent().get().getUser(), + targetHost.getHostGroups().get(0), in, targetHost.getId()); + + + sessionRoute.set(connectedSystem); + + commandBuffer.set(new StringBuilder()); + + + terminalResponseService.sendMessage( + String.format( + "Connected to HostSystem: %s (%s:%d)\r\n", + targetHost.getDisplayName(), targetHost.getHost(), targetHost.getPort() + ), sessionRoute.getOut() + ); + + log.info( + "SSH proxy session switched to HostSystem: {} ({}:{})", + targetHost.getDisplayName(), targetHost.getHost(), targetHost.getPort() + ); + + return false; + } + +} diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShell.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShell.java index 4f4af6c9..07340836 100644 --- a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShell.java +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShell.java @@ -3,31 +3,20 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.lang.reflect.InvocationTargetException; import java.security.GeneralSecurityException; -import java.sql.SQLException; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Future; +import io.sentrius.sso.core.config.SystemOptions; import io.sentrius.sso.core.model.ConnectedSystem; import io.sentrius.sso.core.model.HostSystem; -import io.sentrius.sso.core.model.hostgroup.HostGroup; -import io.sentrius.sso.core.model.hostgroup.ProfileConfiguration; -import io.sentrius.sso.core.model.metadata.TerminalSessionMetadata; -import io.sentrius.sso.core.model.users.User; -import io.sentrius.sso.core.services.HostGroupService; -import io.sentrius.sso.core.services.SessionService; import io.sentrius.sso.core.services.SshListenerService; -import io.sentrius.sso.core.services.TerminalService; import io.sentrius.sso.core.services.UserService; -import io.sentrius.sso.core.services.metadata.TerminalSessionMetadataService; import io.sentrius.sso.core.services.security.CryptoService; import io.sentrius.sso.core.services.terminal.SessionTrackingService; -import io.sentrius.sso.protobuf.Session; import io.sentrius.sso.sshproxy.config.SshProxyConfig; import io.sentrius.sso.sshproxy.service.HostSystemSelectionService; import io.sentrius.sso.sshproxy.service.InlineTerminalResponseService; import io.sentrius.sso.sshproxy.service.SshCommandProcessor; -import lombok.Builder; +import io.sentrius.sso.sshproxy.streams.SessionRoute; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.apache.sshd.server.Environment; @@ -35,7 +24,7 @@ import org.apache.sshd.server.channel.ChannelSession; import org.apache.sshd.server.command.Command; import org.apache.sshd.server.session.ServerSession; -import org.hibernate.Hibernate; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; /** * Individual SSH shell session that applies Sentrius safeguards @@ -50,43 +39,51 @@ public class SshProxyShell implements Command { final SshProxyConfig config; final SessionTrackingService sessionTrackingService; - final SessionService sessionService; final SshListenerService sshListenerService; final CryptoService cryptoService; - final TerminalSessionMetadataService terminalSessionMetadataService; - final HostGroupService hostGroupService; - final TerminalService terminalService; + final UserService userService; + private InputStream in; private OutputStream out; private OutputStream err; private ExitCallback callback; private Environment environment; private ServerSession session; - private ConnectedSystem connectedSystem; + private HostSystem selectedHostSystem; - private Thread shellThread; - private volatile boolean running = false; + private final SessionRoute sessionRoute; + + private ShellHandlerRunnable shellHandler; + + + + private final ThreadPoolTaskExecutor taskExecutor; // inject this + private Future shellFuture = null; + + - // Track active sessions - private static final ConcurrentMap activeSessions = new ConcurrentHashMap<>(); public SshProxyShell( - SshCommandProcessor commandProcessor, InlineTerminalResponseService terminalResponseService, HostSystemSelectionService hostSystemSelectionService, SshProxyConfig config, SessionTrackingService sessionTrackingService, SessionService sessionService, SshListenerService sshListenerService, CryptoService cryptoService, TerminalSessionMetadataService terminalSessionMetadataService, HostGroupService hostGroupService, TerminalService terminalService, UserService userService) { + SshCommandProcessor commandProcessor, InlineTerminalResponseService terminalResponseService, + HostSystemSelectionService hostSystemSelectionService, SshProxyConfig config, + SessionTrackingService sessionTrackingService, + SshListenerService sshListenerService, CryptoService cryptoService, + SessionRoute sessionRoute, UserService userService, + ThreadPoolTaskExecutor taskExecutor + ) { this.commandProcessor = commandProcessor; this.terminalResponseService = terminalResponseService; this.hostSystemSelectionService = hostSystemSelectionService; this.config = config; this.sessionTrackingService = sessionTrackingService; this.userService = userService; - this.sessionService = sessionService; this.sshListenerService = sshListenerService; this.cryptoService = cryptoService; - this.terminalSessionMetadataService = terminalSessionMetadataService; - this.hostGroupService = hostGroupService; - this.terminalService = terminalService; + this.taskExecutor = taskExecutor; + this.sessionRoute = sessionRoute; } @@ -98,6 +95,7 @@ public void setInputStream(InputStream in) { @Override public void setOutputStream(OutputStream out) { this.out = out; + sessionRoute.setOutputStream(out); } @Override @@ -128,7 +126,11 @@ public void start(ChannelSession channel, Environment env) throws IOException { initializeHostSystemSelection(); - var connectedSystem = connect(user, selectedHostSystem.getHostGroups().get(0), selectedHostSystem.getId()); + var connectedSystem = sessionRoute.connect(user, selectedHostSystem.getHostGroups().get(0), + in, + selectedHostSystem.getId()); + sessionRoute.set(connectedSystem); + sendWelcomeMessage(); startShellLoop(connectedSystem); } catch (Exception e) { @@ -155,44 +157,52 @@ private void initializeHostSystemSelection() { } } - public ConnectedSystem connect(User user, HostGroup hostGroup, Long hostId) - throws IOException, ClassNotFoundException, InvocationTargetException, NoSuchMethodException, - InstantiationException, IllegalAccessException, SQLException, GeneralSecurityException { - var hostSystem = getHostGroupService().getHostSystem(hostId); - Hibernate.initialize(hostSystem.get().getPublicKeyList()); - ProfileConfiguration config = hostGroup.getConfiguration(); + private void sendPrompt() throws IOException { + String hostname = selectedHostSystem != null ? selectedHostSystem.getHost() : "unknown"; + String prompt = String.format("[sentrius@%s]$ ", hostname); + terminalResponseService.sendMessage(prompt, out); + } - var sessionLog = getSessionService().createSession(user.getName(), "", user.getUsername(), - hostSystem.get().getHost()); + private void startShellLoop(ConnectedSystem connectedSystem) throws GeneralSecurityException { + shellHandler = ShellHandlerRunnable.builder(). + sessionRoute(sessionRoute).callback(callback).session(session). + in(in).running(true). + sshListenerService(sshListenerService). + sessionTrackingService(sessionTrackingService). + hostSystemSelectionService(hostSystemSelectionService).selectedHostSystem(selectedHostSystem).terminalResponseService(terminalResponseService). + build(); + log.info("Submitting shell handler to executor"); - var sessionRules = getTerminalService().createRules(config); + shellFuture = taskExecutor.submit(shellHandler); + } - var connectedSystem = getTerminalService().openTerminal(user, sessionLog, hostGroup, "", - hostSystem.get().getSshPassword(), - hostSystem.get(), - sessionRules); - TerminalSessionMetadata sessionMetadata = TerminalSessionMetadata.builder().sessionStatus("ACTIVE") - .hostSystem(hostSystem.get()) - .user(user) - .startTime(new java.sql.Timestamp(System.currentTimeMillis())) - .sessionLog(sessionLog) - .build(); + @Override + public void destroy(ChannelSession channel) throws Exception { + log.info("Destroying SSH proxy shell session"); + shellFuture.cancel(true); + cleanup(); + } - sessionMetadata = getTerminalSessionMetadataService().createSession(sessionMetadata); + private void cleanup() { + String sessionId = session.getIoSession().getId() + ""; + sessionRoute.cleanup(sessionId); - activeSessions.put(hostGroup.getId().toString(), connectedSystem); + if (callback != null) { + callback.onExit(0); + } - return connectedSystem; + log.info("SSH proxy shell session cleaned up"); } + private void sendWelcomeMessage() throws IOException { String hostInfo = selectedHostSystem != null ? String.format( @@ -216,261 +226,4 @@ private void sendWelcomeMessage() throws IOException { terminalResponseService.sendMessage(welcome, out); sendPrompt(); } - - private void sendPrompt() throws IOException { - String hostname = selectedHostSystem != null ? selectedHostSystem.getHost() : "unknown"; - String prompt = String.format("[sentrius@%s]$ ", hostname); - terminalResponseService.sendMessage(prompt, out); - } - - private void startShellLoop(ConnectedSystem connectedSystem) throws GeneralSecurityException { - var listenerThread = new ResponseServiceSession(connectedSystem, in, out); - var encryptedSessionId = cryptoService.encrypt(connectedSystem.getSession().getId().toString()); - getSshListenerService().startListeningToSshServer(encryptedSessionId, listenerThread); - running = true; - - shellThread = new Thread(() -> { - try { - byte[] buffer = new byte[1024]; - StringBuilder commandBuffer = new StringBuilder(); - var auditLog = - Session.TerminalMessage.newBuilder(); - - while (running) { - int bytesRead = in.read(buffer); - if (bytesRead == -1) { - // EOF reached - break; - } - - if (bytesRead > 0){ - log.info("Read {} bytes from SSH input stream", bytesRead); - } - - for (int i = 0; i < bytesRead; i++) { - byte b = buffer[i]; - char c = (char) b; - - // Process input character and send audit log - if (c >= 32 && c <= 126) { - // Printable characters - auditLog.setCommand(String.valueOf(c)); - commandBuffer.append(c); - auditLog.setType(Session.MessageType.USER_DATA); - auditLog.setKeycode(-1); - getSshListenerService().processTerminalMessage(connectedSystem, - auditLog.build()); - auditLog = Session.TerminalMessage.newBuilder(); - } else { - // Control characters and special keys - if ( handleBuiltinCommand(commandBuffer.toString()) ){ - commandBuffer = new StringBuilder(); - auditLog.setKeycode(c); - - auditLog.setType(Session.MessageType.USER_DATA); - getSshListenerService().processTerminalMessage( - connectedSystem, - auditLog.build() - ); - auditLog = Session.TerminalMessage.newBuilder(); - } else { - // Forward command to target SSH server - connectedSystem.getCommander().write(SshListenerService.keyMap.get(3)); - connectedSystem.getTerminalAuditor().clear(0); // clear in case - } - } - } - } - - } catch (IOException e) { - if (running) { - log.error("Error in SSH shell loop", e); - } - } finally { - cleanup(); - } - }); - - shellThread.start(); - } - - - private boolean handleBuiltinCommand(String command) throws IOException { - String cmd = command.toLowerCase().trim(); - String[] parts = command.trim().split("\\s+"); - - switch (cmd) { - case "exit": - case "quit": - terminalResponseService.sendMessage("Goodbye!\r\n", out); - running = false; - callback.onExit(0); - return true; - - case "help": - showHelp(); - return true; - - case "status": - showStatus(); - return false; - - case "hosts": - showAvailableHosts(); - return false; - - default: - if (parts.length >= 2 && "connect".equals(parts[0].toLowerCase())) { - return handleConnectCommand(parts); - } - return true; - } - } - - private void showHelp() throws IOException { - String help = "\r\n" + - "Sentrius SSH Proxy - Built-in Commands:\r\n" + - " help - Show this help message\r\n" + - " status - Show session status\r\n" + - " hosts - List available target hosts\r\n" + - " connect - Connect to HostSystem by ID\r\n" + - " connect - Connect to HostSystem by display name\r\n" + - " exit - Close SSH session\r\n" + - "\r\n" + - "All other commands are forwarded to the target SSH server\r\n" + - "and subject to Sentrius security policies.\r\n\r\n"; - - terminalResponseService.sendMessage(help, out); - } - - private void showStatus() throws IOException { - String hostInfo = selectedHostSystem != null - ? String.format( - "%s (%s:%d)", selectedHostSystem.getDisplayName(), - selectedHostSystem.getHost(), selectedHostSystem.getPort() - ) - : "No target host configured"; - - String status = String.format( - "\r\n" + - "Sentrius SSH Proxy Status:\r\n" + - " User: %s\r\n" + - " Target Host: %s\r\n" + - " Session Active: %s\r\n" + - " Safeguards: ENABLED\r\n\r\n", - session.getUsername(), - hostInfo, - running ? "YES" : "NO" - ); - - terminalResponseService.sendMessage(status, out); - } - - private void showAvailableHosts() throws IOException { - var hostSystems = hostSystemSelectionService.getAllHostSystems(); - - StringBuilder hostList = new StringBuilder("\r\nAvailable HostSystems:\r\n"); - hostList.append("ID\tName\t\t\tHost:Port\t\tStatus\r\n"); - hostList.append("────────────────────────────────────────────────────────────\r\n"); - - if (hostSystems.isEmpty()) { - hostList.append("No HostSystems configured in database.\r\n"); - } else { - for (HostSystem hs : hostSystems) { - String name = hs.getDisplayName() != null ? hs.getDisplayName() : "N/A"; - String hostPort = String.format("%s:%d", hs.getHost(), hs.getPort()); - String status = - hostSystemSelectionService.isHostSystemValid(hs) ? "Valid" : "Invalid"; - String current = - (selectedHostSystem != null && selectedHostSystem.getId().equals(hs.getId())) ? " *" : ""; - - hostList.append(String.format( - "%d\t%-15s\t%-15s\t%s%s\r\n", - hs.getId(), name, hostPort, status, current - )); - } - hostList.append("\r\n* = Current selection\r\n"); - } - hostList.append("\r\n"); - - terminalResponseService.sendMessage(hostList.toString(), out); - } - - private boolean handleConnectCommand(String[] parts) throws IOException { - if (parts.length < 2) { - terminalResponseService.sendMessage("Usage: connect \r\n", out); - return true; - } - - String target = parts[1]; - HostSystem targetHost = null; - - // Try to parse as ID first - try { - Long id = Long.parseLong(target); - targetHost = hostSystemSelectionService.getHostSystemById(id).orElse(null); - } catch (NumberFormatException e) { - // Not a number, try by display name - var hostsByName = hostSystemSelectionService.getHostSystemsByDisplayName(target); - if (!hostsByName.isEmpty()) { - targetHost = hostsByName.get(0); - if (hostsByName.size() > 1) { - terminalResponseService.sendMessage( - String.format("Warning: Multiple hosts found with name '%s', using first one.\r\n", target), - out - ); - } - } - } - - if (targetHost == null) { - terminalResponseService.sendMessage( - String.format("Error: HostSystem '%s' not found.\r\n", target), out); - return true; - } - - if (!hostSystemSelectionService.isHostSystemValid(targetHost)) { - terminalResponseService.sendMessage( - String.format("Error: HostSystem '%s' is not properly configured.\r\n", target), out); - return true; - } - - selectedHostSystem = targetHost; - terminalResponseService.sendMessage( - String.format( - "Connected to HostSystem: %s (%s:%d)\r\n", - targetHost.getDisplayName(), targetHost.getHost(), targetHost.getPort() - ), out - ); - - log.info( - "SSH proxy session switched to HostSystem: {} ({}:{})", - targetHost.getDisplayName(), targetHost.getHost(), targetHost.getPort() - ); - - return true; - } - - - @Override - public void destroy(ChannelSession channel) throws Exception { - log.info("Destroying SSH proxy shell session"); - running = false; - cleanup(); - } - - private void cleanup() { - String sessionId = session.getIoSession().getId() + ""; - activeSessions.remove(sessionId); - - if (shellThread != null && shellThread.isAlive()) { - shellThread.interrupt(); - } - - if (callback != null) { - callback.onExit(0); - } - - log.info("SSH proxy shell session cleaned up"); - } } diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShellHandler.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShellHandler.java index eb2c35fe..23e75efe 100644 --- a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShellHandler.java +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/handler/SshProxyShellHandler.java @@ -1,5 +1,7 @@ package io.sentrius.sso.sshproxy.handler; +import io.sentrius.sso.core.config.SystemOptions; +import io.sentrius.sso.core.config.ThreadSafeDynamicPropertiesService; import io.sentrius.sso.core.model.ConnectedSystem; import io.sentrius.sso.core.services.ChatService; import io.sentrius.sso.core.services.HostGroupService; @@ -14,15 +16,20 @@ import io.sentrius.sso.sshproxy.service.HostSystemSelectionService; import io.sentrius.sso.sshproxy.service.InlineTerminalResponseService; import io.sentrius.sso.sshproxy.service.SshCommandProcessor; +import io.sentrius.sso.sshproxy.streams.SessionRoute; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.sshd.common.Factory; import org.apache.sshd.server.command.Command; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.stereotype.Component; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; /** * SSH shell handler that integrates with Sentrius safeguards. @@ -48,21 +55,31 @@ public class SshProxyShellHandler implements Factory { final TerminalService terminalService; final UserService userService; + final ThreadSafeDynamicPropertiesService systemOptions; + + + @Qualifier("taskExecutor") // Specify the custom task executor to use + private final ThreadPoolTaskExecutor taskExecutor; + + @Override public Command create() { + if (Boolean.valueOf( systemOptions.getProperty("lockdownEnabled", "false"))) { + throw new RuntimeException("SSH access is disabled by system lockdown"); + } + var sessionRoute = + SessionRoute.builder().sshListenerService(sshListenerService).terminalSessionMetadataService(terminalSessionMetadataService).cryptoService(cryptoService).hostGroupService(hostGroupService).terminalService(terminalService).sessionService(sessionService).build(); return new SshProxyShell( commandProcessor, terminalResponseService, hostSystemSelectionService, config, sessionTrackingService, - sessionService, sshListenerService, cryptoService, - terminalSessionMetadataService, - hostGroupService, - terminalService, - userService + sessionRoute, + userService, + taskExecutor ); } diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/service/HostSystemSelectionService.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/service/HostSystemSelectionService.java index 6a93005b..b329270e 100644 --- a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/service/HostSystemSelectionService.java +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/service/HostSystemSelectionService.java @@ -3,11 +3,11 @@ import io.sentrius.sso.core.model.HostSystem; import io.sentrius.sso.core.model.hostgroup.HostGroup; import io.sentrius.sso.core.repository.SystemRepository; -import jakarta.transaction.Transactional; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.hibernate.Hibernate; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.List; import java.util.Optional; @@ -26,16 +26,25 @@ public class HostSystemSelectionService { /** * Get a HostSystem by ID for SSH proxy connection. */ + + @Transactional(readOnly = true) public Optional getHostSystemById(Long id) { try { - return systemRepository.findById(id); + var hostSystem = systemRepository.findById(id); + hostSystem.ifPresent(hs -> { + Hibernate.initialize(hs.getHostGroups()); + for(HostGroup group : hs.getHostGroups()) { + Hibernate.initialize(group.getRules()); + } + }); + return hostSystem; } catch (Exception e) { log.error("Error retrieving HostSystem with ID: {}", id, e); return Optional.empty(); } } - /** + /** * Get all available HostSystems for SSH proxy. */ public List getAllHostSystems() { @@ -50,9 +59,20 @@ public List getAllHostSystems() { /** * Find HostSystems by display name. */ + @Transactional(readOnly = true) public List getHostSystemsByDisplayName(String displayName) { try { - return systemRepository.findByDisplayName(displayName); + var listOfHostSystems = systemRepository.findByDisplayName(displayName); + if (!listOfHostSystems.isEmpty()) { + for (var hostSystem : listOfHostSystems) { + Hibernate.initialize(hostSystem.getHostGroups()); + for (HostGroup group : hostSystem.getHostGroups()) { + Hibernate.initialize(group.getRules()); + } + } + } + return listOfHostSystems; + } catch (Exception e) { log.error("Error retrieving HostSystems by display name: {}", displayName, e); return List.of(); diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SessionRoute.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SessionRoute.java new file mode 100644 index 00000000..4451ccf5 --- /dev/null +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SessionRoute.java @@ -0,0 +1,108 @@ +package io.sentrius.sso.sshproxy.streams; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.security.CryptoPrimitive; +import java.security.GeneralSecurityException; +import java.sql.SQLException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; +import io.sentrius.sso.core.model.ConnectedSystem; +import io.sentrius.sso.core.model.hostgroup.HostGroup; +import io.sentrius.sso.core.model.hostgroup.ProfileConfiguration; +import io.sentrius.sso.core.model.metadata.TerminalSessionMetadata; +import io.sentrius.sso.core.model.users.User; +import io.sentrius.sso.core.services.HostGroupService; +import io.sentrius.sso.core.services.SessionService; +import io.sentrius.sso.core.services.SshListenerService; +import io.sentrius.sso.core.services.TerminalService; +import io.sentrius.sso.core.services.metadata.TerminalSessionMetadataService; +import io.sentrius.sso.core.services.security.CryptoService; +import io.sentrius.sso.sshproxy.handler.ResponseServiceSession; +import lombok.Builder; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.hibernate.Hibernate; + +@Slf4j +@Getter +@Builder +public final class SessionRoute { + public final AtomicReference current = new AtomicReference<>(); + public final SwappableOutputStream out = new SwappableOutputStream(null); // to the SSH client + final HostGroupService hostGroupService; + final TerminalService terminalService; + final SessionService sessionService; + final CryptoService cryptoService; + final SshListenerService sshListenerService; + + final TerminalSessionMetadataService terminalSessionMetadataService; + + + // Track active sessions + private static final ConcurrentMap activeSessions = new ConcurrentHashMap<>(); + + + + public ConnectedSystem connect(User user, HostGroup hostGroup, InputStream in, Long hostId) + throws IOException, ClassNotFoundException, InvocationTargetException, NoSuchMethodException, + InstantiationException, IllegalAccessException, SQLException, GeneralSecurityException { + var hostSystem = getHostGroupService().getHostSystem(hostId); + + Hibernate.initialize(hostSystem.get().getPublicKeyList()); + + ProfileConfiguration config = hostGroup.getConfiguration(); + + var sessionLog = getSessionService().createSession(user.getName(), "", user.getUsername(), + hostSystem.get().getHost()); + + + log.info("** Session rule size {}", config.getSessionRules().size()); + config.getSessionRules().forEach( + rule -> { + log.info("** Adding session rule: {}", rule.getSessionRuleClass()); + } + + ); + + + var sessionRules = getTerminalService().createRules(config); + + + var connectedSystem = getTerminalService().openTerminal(user, sessionLog, hostGroup, "", + hostSystem.get().getSshPassword(), + hostSystem.get(), + sessionRules); + + + TerminalSessionMetadata sessionMetadata = TerminalSessionMetadata.builder().sessionStatus("ACTIVE") + .hostSystem(hostSystem.get()) + .user(user) + .startTime(new java.sql.Timestamp(System.currentTimeMillis())) + .sessionLog(sessionLog) + .build(); + + sessionMetadata = getTerminalSessionMetadataService().createSession(sessionMetadata); + + activeSessions.put(hostGroup.getId().toString(), connectedSystem); + + var listenerThread = new ResponseServiceSession(connectedSystem, in, out); + var encryptedSessionId = cryptoService.encrypt(connectedSystem.getSession().getId().toString()); + getSshListenerService().startListeningToSshServer(encryptedSessionId, listenerThread); + + current.set(connectedSystem); + + return connectedSystem; + } + + public void set(ConnectedSystem next) { current.set(next); } + + public void setOutputStream(OutputStream next) { out.set(next); } + + public void cleanup(String sessionId){ + activeSessions.remove(sessionId); + } +} \ No newline at end of file diff --git a/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SwappableOutputStream.java b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SwappableOutputStream.java new file mode 100644 index 00000000..f12cffcd --- /dev/null +++ b/ssh-proxy/src/main/java/io/sentrius/sso/sshproxy/streams/SwappableOutputStream.java @@ -0,0 +1,23 @@ +package io.sentrius.sso.sshproxy.streams; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicReference; + +public final class SwappableOutputStream extends OutputStream { + private final AtomicReference delegate = new AtomicReference<>(); + + public SwappableOutputStream(OutputStream initial) { + delegate.set(initial); + } + public void set(OutputStream next) { delegate.set(next); } + + @Override public void write(int b) throws IOException { delegate.get().write(b); } + @Override public void write(byte[] b, int off, int len) throws IOException { delegate.get().write(b, off, len); } + @Override public void flush() throws IOException { delegate.get().flush(); } + + @Override public void close() throws IOException { + // Do NOT close the underlying stream here; the SSH layer owns it. + // You can no-op or close the current delegate if that matches your lifecycle. + } +} \ No newline at end of file