diff --git a/packages/jumpstarter/jumpstarter/common/utils.py b/packages/jumpstarter/jumpstarter/common/utils.py index 236aa04ef..f732c942a 100644 --- a/packages/jumpstarter/jumpstarter/common/utils.py +++ b/packages/jumpstarter/jumpstarter/common/utils.py @@ -49,7 +49,7 @@ def serve(root_device: Driver): def launch_shell( host: str, context: str, - allow: [str], + allow: list[str], unsafe: bool, *, command: tuple[str, ...] | None = None, @@ -63,20 +63,52 @@ def launch_shell( unsafe: Whether to allow drivers outside of the allow list """ - env = os.environ | { + shell = os.environ.get("SHELL", "bash") + shell_name = os.path.basename(shell) + + common_env = os.environ | { JUMPSTARTER_HOST: host, JMP_DRIVERS_ALLOW: "UNSAFE" if unsafe else ",".join(allow), - "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", } if command: - process = Popen(command, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) - else: - cmd = [os.environ.get("SHELL", "bash")] - if cmd[0].endswith("bash"): - cmd.append("--norc") - cmd.append("--noprofile") - + process = Popen(command, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) + return process.wait() + + if shell_name.endswith("bash"): + env = common_env | { + "PS1": f"{ANSI_GRAY}{PROMPT_CWD} {ANSI_YELLOW}⚡{ANSI_WHITE}{context} {ANSI_YELLOW}➤{ANSI_RESET} ", + } + cmd = [shell, "--norc", "--noprofile"] process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) + return process.wait() + + elif shell_name == "fish": + fish_fn = ( + "function fish_prompt; " + "set_color grey; " + 'printf "%s" (basename $PWD); ' + "set_color yellow; " + 'printf "⚡"; ' + "set_color white; " + f'printf "{context}"; ' + "set_color yellow; " + 'printf "➤ "; ' + "set_color normal; " + "end" + ) + cmd = [shell, "--init-command", fish_fn] + process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) + return process.wait() + + elif shell_name == "zsh": + env = common_env | { + "PS1": f"%F{{8}}%1~ %F{{yellow}}⚡%F{{white}}{context} %F{{yellow}}➤%f ", + } + cmd = [shell, "--no-rcs"] + process = Popen(cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=env) + return process.wait() - return process.wait() + else: + process = Popen([shell], stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr, env=common_env) + return process.wait()