diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 5ff336c181..fea1ec3023 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -851,6 +851,78 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: return name +def _split_shell_arguments(args_string: str) -> list[str]: + """ + Split shell-style arguments while preserving quotes inside unquoted JSON tokens. + + This supports the shell usage expected by `pyrit shell` for values like: + - `--initialization-scripts "/tmp/my script.py"` + - `--memory-labels '{"experiment": "test 1"}'` + + Quotes are only treated as grouping characters when they start a new token. + This preserves internal JSON quotes for unquoted tokens such as + `--memory-labels {"key":"value"}`. + + Args: + args_string: Raw argument string entered in shell mode. + + Returns: + list[str]: Tokenized arguments with surrounding quotes removed. + + Raises: + ValueError: If a quoted argument is not terminated. + """ + parts: list[str] = [] + current: list[str] = [] + quote_char: str | None = None + i = 0 + + while i < len(args_string): + char = args_string[i] + + if quote_char: + if char == "\\" and i + 1 < len(args_string): + next_char = args_string[i + 1] + if next_char in (quote_char, "\\"): + current.append(next_char) + i += 2 + continue + if char == quote_char: + quote_char = None + else: + current.append(char) + i += 1 + continue + + if char.isspace(): + if current: + parts.append("".join(current)) + current = [] + i += 1 + continue + + if char in ("'", '"') and not current: + quote_char = char + i += 1 + continue + + current.append(char) + i += 1 + + if quote_char: + raise ValueError("Unterminated quoted argument") + + if current: + parts.append("".join(current)) + + return parts + + +def _is_run_argument_flag(part: str) -> bool: + """Return whether a token is a recognized run-command flag.""" + return part.startswith("--") or part == "-s" + + def parse_run_arguments(*, args_string: str) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -875,7 +947,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: Raises: ValueError: If parsing or validation fails. """ - parts = args_string.split() + parts = _split_shell_arguments(args_string) if not parts: raise ValueError("No scenario name provided") @@ -901,28 +973,28 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: # Collect initializers until next flag, parsing name:key=val syntax result["initializers"] = [] i += 1 - while i < len(parts) and not parts[i].startswith("--"): + while i < len(parts) and not _is_run_argument_flag(parts[i]): result["initializers"].append(_parse_initializer_arg(parts[i])) i += 1 elif parts[i] == "--initialization-scripts": # Collect script paths until next flag result["initialization_scripts"] = [] i += 1 - while i < len(parts) and not parts[i].startswith("--"): + while i < len(parts) and not _is_run_argument_flag(parts[i]): result["initialization_scripts"].append(parts[i]) i += 1 elif parts[i] == "--env-files": # Collect env file paths until next flag result["env_files"] = [] i += 1 - while i < len(parts) and not parts[i].startswith("--"): + while i < len(parts) and not _is_run_argument_flag(parts[i]): result["env_files"].append(parts[i]) i += 1 elif parts[i] in ("--strategies", "-s"): # Collect strategies until next flag result["scenario_strategies"] = [] i += 1 - while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": + while i < len(parts) and not _is_run_argument_flag(parts[i]): result["scenario_strategies"].append(parts[i]) i += 1 elif parts[i] == "--max-concurrency": @@ -959,7 +1031,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: # Collect dataset names until next flag result["dataset_names"] = [] i += 1 - while i < len(parts) and not parts[i].startswith("--"): + while i < len(parts) and not _is_run_argument_flag(parts[i]): result["dataset_names"].append(parts[i]) i += 1 elif parts[i] == "--max-dataset-size": diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 7c040deb55..b0a7477e61 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -616,6 +616,15 @@ def test_parse_run_arguments_with_short_strategies(self): assert result["scenario_strategies"] == ["s1", "s2"] + def test_parse_run_arguments_with_short_strategies_after_initializers(self): + """Test that -s is treated as a flag after multi-value initializers.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --initializers init1 -s s1 s2" + ) + + assert result["initializers"] == ["init1"] + assert result["scenario_strategies"] == ["s1", "s2"] + def test_parse_run_arguments_with_max_concurrency(self): """Test parsing with max-concurrency.""" result = frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency 5") @@ -654,6 +663,23 @@ def test_parse_run_arguments_with_initialization_scripts(self): assert result["initialization_scripts"] == ["script1.py", "script2.py"] + def test_parse_run_arguments_with_quoted_paths(self): + """Test parsing quoted paths with spaces for shell mode.""" + result = frontend_core.parse_run_arguments( + args_string='test_scenario --initialization-scripts "/tmp/my script.py" --env-files "/tmp/dev env.env"' + ) + + assert result["initialization_scripts"] == ["/tmp/my script.py"] + assert result["env_files"] == ["/tmp/dev env.env"] + + def test_parse_run_arguments_with_quoted_memory_labels(self): + """Test parsing quoted JSON for memory-labels in shell mode.""" + result = frontend_core.parse_run_arguments( + args_string="""test_scenario --memory-labels '{"experiment": "test 1"}'""" + ) + + assert result["memory_labels"] == {"experiment": "test 1"} + def test_parse_run_arguments_complex(self): """Test parsing complex argument combination.""" args = "test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10"