diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index c37e2a765..43eb1a600 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -12,6 +12,7 @@ @click.command("shell") @opt_config() +@click.argument("command", nargs=-1) # client specific # TODO: warn if these are specified with exporter config @click.option("--lease", "lease_name") @@ -19,9 +20,17 @@ @opt_duration_partial(default=timedelta(minutes=30), show_default="00:30:00") # end client specific @handle_exceptions -def shell(config, lease_name, selector, duration): +def shell(config, command: tuple[str, ...], lease_name, selector, duration): """ - Spawns a shell connecting to a local or remote exporter + Spawns a shell (or custom command) connecting to a local or remote exporter + + COMMAND is the custom command to run instead of shell. + + Example: + + .. code-block:: bash + + $ jmp shell --exporter foo -- python bar.py """ match config: @@ -31,11 +40,23 @@ def shell(config, lease_name, selector, duration): with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease: with lease.serve_unix() as path: with lease.monitor(): - exit_code = launch_shell(path, "remote", config.drivers.allow, config.drivers.unsafe) + exit_code = launch_shell( + path, + "remote", + config.drivers.allow, + config.drivers.unsafe, + command=command, + ) sys.exit(exit_code) case ExporterConfigV1Alpha1(): with config.serve_unix() as path: # SAFETY: the exporter config is local thus considered trusted - launch_shell(path, "local", allow=[], unsafe=True) + launch_shell( + path, + "local", + allow=[], + unsafe=True, + command=command, + ) diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index 832dce4e3..6984e78ac 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -80,7 +80,14 @@ def env(): PROMPT_CWD = "\\W" -def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int: +def launch_shell( + host: str, + context: str, + allow: [str], + unsafe: bool, + *, + command: tuple[str, ...] | None = None, +) -> int: """Launch a shell with a custom prompt indicating the exporter type. Args: @@ -89,21 +96,21 @@ def launch_shell(host: str, context: str, allow: list[str], unsafe: bool) -> int allow: List of allowed drivers unsafe: Whether to allow drivers outside of the allow list """ - cmd = [os.environ.get("SHELL", "bash")] - if cmd[0].endswith("bash"): - cmd.append("--norc") - cmd.append("--noprofile") - - process = Popen( - cmd, - stdin=sys.stdin, - stdout=sys.stdout, - stderr=sys.stderr, - env=os.environ - | { - JUMPSTARTER_HOST: host, - JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), - "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", - }, - ) + + env = os.environ | { + JUMPSTARTER_HOST: host, + JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), + "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", + } + + if command: + process = Popen(command, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) + else: + cmd = [os.environ.get("SHELL", "bash")] + if cmd[0].endswith("bash"): + cmd.append("--norc") + cmd.append("--noprofile") + + process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) + return process.wait()