From af8f8e08f73ad52a9121b6af61857d93d43c1009 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 1 Apr 2026 13:25:32 -0700 Subject: [PATCH 1/8] updates --- pyrit/cli/_banner.py | 2 +- pyrit/cli/_cli_args.py | 42 +++++++++++++ pyrit/cli/frontend_core.py | 32 +++++----- pyrit/cli/pyrit_scan.py | 18 +++--- pyrit/cli/pyrit_shell.py | 59 ++++++++++++++++-- tests/unit/cli/test_frontend_core.py | 68 ++++++++++++++++++++ tests/unit/cli/test_pyrit_scan.py | 49 +++++++++++++++ tests/unit/cli/test_pyrit_shell.py | 93 ++++++++++++++++++++++++++++ 8 files changed, 332 insertions(+), 31 deletions(-) diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index bd5a3d40fe..a76d286b27 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo "Commands:", " • list-scenarios - See all available scenarios", " • list-initializers - See all available initializers", - " • list-targets - See all available targets in the registry", + " • list-targets [opts] - See all available targets in the registry", " • run [opts] - Execute a security scenario", " • scenario-history - View your session history", " • print-scenario [N] - Display detailed results", diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 1264956ccb..80e383dfe5 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -457,6 +457,48 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: return result +def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse list-targets command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "--initializers target"). + + Returns: + Dictionary with parsed arguments: + - initializers: Optional[list[str | dict[str, Any]]] + - initialization_scripts: Optional[list[str]] + + Raises: + ValueError: If parsing or validation fails. + """ + parts = args_string.split() + + result: dict[str, Any] = { + "initializers": None, + "initialization_scripts": None, + } + + i = 0 + while i < len(parts): + if parts[i] == "--initializers": + result["initializers"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initializers"].append(_parse_initializer_arg(parts[i])) + i += 1 + elif parts[i] == "--initialization-scripts": + result["initialization_scripts"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initialization_scripts"].append(parts[i]) + i += 1 + else: + raise ValueError(f"Unknown argument: {parts[i]}") + + return result + + # --------------------------------------------------------------------------- # Shared argparse builder # --------------------------------------------------------------------------- diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 3db5552011..8bcf5f4655 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -27,6 +27,7 @@ from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg from pyrit.cli._cli_args import add_common_arguments as add_common_arguments from pyrit.cli._cli_args import non_negative_int as non_negative_int +from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments from pyrit.cli._cli_args import positive_int as positive_int @@ -254,18 +255,16 @@ async def list_initializers_async( async def list_targets_async( *, context: FrontendCore, - initializer_names: Optional[list[Any]] = None, ) -> list[str]: """ List available target names from the TargetRegistry. Since targets are registered by initializers, this function requires initializers - to have been run first. If initializer_names are provided, they will be resolved - and run before querying the registry. + to have been run first. Configure initializers on the FrontendCore context + (via initializer_names or initialization_scripts) before calling this function. Args: context: PyRIT context with loaded registries. - initializer_names: Optional list of initializer entries to run before listing. Returns: Sorted list of registered target names. @@ -273,25 +272,24 @@ async def list_targets_async( if not context._initialized: await context.initialize_async() - # If initializer names are provided, run them to populate the target registry - if initializer_names or context._initializer_configs: - configs = context._initializer_configs - if configs: - initializer_instances = [] - for config in configs: + # Run initializers and/or initialization scripts to populate the target registry + if context._initializer_configs or context._initialization_scripts: + initializer_instances = [] + if context._initializer_configs: + for config in context._initializer_configs: initializer_class = context.initializer_registry.get_class(config.name) instance = initializer_class() if config.args: instance.set_params_from_args(args=config.args) initializer_instances.append(instance) - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances or None, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) target_registry = TargetRegistry.get_registry_singleton() return target_registry.get_names() diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index aefdfa5f22..f6f77bea2b 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -197,17 +197,19 @@ def main(args: Optional[list[str]] = None) -> int: if parsed_args.list_targets: # Need initializers to populate target registry - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_targets_list_async(context=context)) + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 - if parsed_args.list_targets: - # Need initializers to populate target registry context = frontend_core.FrontendCore( config_file=parsed_args.config_file, + initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, log_level=parsed_args.log_level, ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index f19602bee0..6707f14e95 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -33,7 +33,7 @@ class PyRITShell(cmd.Cmd): Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers - list-targets - List all available targets from the registry + list-targets [opts] - List all available targets from the registry run [opts] - Run a scenario with optional parameters scenario-history - List all previous scenario runs print-scenario [N] - Print detailed results for scenario run(s) @@ -189,6 +189,9 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" + if arg.strip(): + print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) @@ -197,6 +200,9 @@ def do_list_scenarios(self, arg: str) -> None: def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" + if arg.strip(): + print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_initializers_list_async(context=self.context)) @@ -204,10 +210,49 @@ def do_list_initializers(self, arg: str) -> None: print(f"Error listing initializers: {e}") def do_list_targets(self, arg: str) -> None: - """List all available targets from the TargetRegistry.""" + """ + List all available targets from the TargetRegistry. + + Usage: + list-targets + list-targets --initializers [ ...] + list-targets --initialization-scripts [ ...] + + Options: + --initializers ... Built-in initializers to run first + --initialization-scripts <...> Custom Python scripts to run first + + Examples: + list-targets --initializers target + list-targets --initializers target:tags=default,scorer + """ self._ensure_initialized() try: - asyncio.run(self._fc.print_targets_list_async(context=self.context)) + context_to_use = self.context + + if arg.strip(): + args = self._fc.parse_list_targets_arguments(args_string=arg) + + resolved_scripts = None + if args["initialization_scripts"]: + resolved_scripts = self._fc.resolve_initialization_scripts( + script_paths=args["initialization_scripts"] + ) + + context_to_use = self._fc.FrontendCore( + initialization_scripts=resolved_scripts, + initializer_names=args["initializers"], + log_level=self.default_log_level, + ) + context_to_use._scenario_registry = self.context._scenario_registry + context_to_use._initializer_registry = self.context._initializer_registry + context_to_use._initialized = True + + asyncio.run(self._fc.print_targets_list_async(context=context_to_use)) + except ValueError as e: + print(f"Error: {e}") + except FileNotFoundError as e: + print(f"Error: {e}") except Exception as e: print(f"Error listing targets: {e}") @@ -338,6 +383,9 @@ def do_scenario_history(self, arg: str) -> None: Shows a numbered list of all scenario runs with the commands used. """ + if arg.strip(): + print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}") + return if not self._scenario_history: print("No scenario runs in history.") return @@ -467,8 +515,9 @@ def do_help(self, arg: str) -> None: print(" pyrit_shell") print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") else: - # Show help for specific command - super().do_help(arg) + # Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup + normalized_arg = arg.replace("-", "_") + super().do_help(normalized_arg) def do_exit(self, arg: str) -> bool: """ diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 61b3c7bb50..e3cbbca8d0 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -672,6 +672,46 @@ def test_parse_run_arguments_missing_value(self): frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") +class TestParseListTargetsArguments: + """Tests for parse_list_targets_arguments function.""" + + def test_parse_list_targets_arguments_empty(self): + """Test parsing empty string returns defaults.""" + result = frontend_core.parse_list_targets_arguments(args_string="") + assert result["initializers"] is None + assert result["initialization_scripts"] is None + + def test_parse_list_targets_arguments_with_initializers(self): + """Test parsing with initializers.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2") + assert result["initializers"] == ["target", "init2"] + + def test_parse_list_targets_arguments_with_initializer_params(self): + """Test parsing initializers with key=value params.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer") + assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}] + + def test_parse_list_targets_arguments_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initialization-scripts script1.py script2.py" + ) + assert result["initialization_scripts"] == ["script1.py", "script2.py"] + + def test_parse_list_targets_arguments_with_both(self): + """Test parsing with both initializers and scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initializers target --initialization-scripts script1.py" + ) + assert result["initializers"] == ["target"] + assert result["initialization_scripts"] == ["script1.py"] + + def test_parse_list_targets_arguments_unknown_arg_raises(self): + """Test parsing with unknown argument raises ValueError.""" + with pytest.raises(ValueError, match="Unknown argument"): + frontend_core.parse_list_targets_arguments(args_string="--unknown-flag") + + @pytest.mark.asyncio @pytest.mark.usefixtures("patch_central_database") class TestRunScenarioAsync: @@ -1141,3 +1181,31 @@ async def test_print_targets_list_empty( captured = capsys.readouterr() assert "No targets found" in captured.out assert "--initializers target" in captured.out + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_list_targets_with_initialization_scripts_calls_initialize( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test list_targets_async calls initialize_pyrit_async when only scripts are configured.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = ["script_target"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + context._initialization_scripts = ["/path/to/script.py"] + context._initializer_configs = None + + result = await frontend_core.list_targets_async(context=context) + + assert result == ["script_target"] + # Verify initialize_pyrit_async was called with the scripts + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"] + assert call_kwargs["initializers"] is None diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 34a8b8ad52..a4c3620ca7 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -214,6 +214,55 @@ def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: Mag assert result == 1 + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_initializers( + self, + mock_frontend_core: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initializers passes initializers to FrontendCore.""" + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initializers", "target"]) + + assert result == 0 + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initialization-scripts passes scripts to FrontendCore.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initialization_scripts"] == [Path("/test/script.py")] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_list_targets_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with --list-targets and missing script file.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "missing.py"]) + + assert result == 1 + def test_main_no_scenario_specified(self, capsys): """Test main without scenario name.""" result = pyrit_scan.main([]) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 89a218644d..eca05ec159 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -151,6 +151,15 @@ def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, captured = capsys.readouterr() assert "Error listing scenarios" in captured.out + def test_do_list_scenarios_rejects_args(self, shell, capsys): + """Test do_list_scenarios rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_scenarios("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) def test_do_list_initializers(self, mock_print_initializers: AsyncMock, shell): """Test do_list_initializers command.""" @@ -171,6 +180,73 @@ def test_do_list_initializers_with_exception(self, mock_print_initializers: Asyn captured = capsys.readouterr() assert "Error listing initializers" in captured.out + def test_do_list_initializers_rejects_args(self, shell, capsys): + """Test do_list_initializers rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_initializers("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): + """Test do_list_targets with no arguments uses the default context.""" + s, ctx, _ = shell + + s.do_list_targets("") + + mock_print_targets.assert_called_once_with(context=ctx) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") + def test_do_list_targets_with_initializers( + self, + mock_parse: MagicMock, + mock_fc_class: MagicMock, + mock_print_targets: AsyncMock, + shell, + ): + """Test do_list_targets with --initializers creates a new context.""" + s, ctx, _ = shell + mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} + mock_run_context = MagicMock() + mock_fc_class.return_value = mock_run_context + + s.do_list_targets("--initializers target") + + mock_parse.assert_called_once_with(args_string="--initializers target") + mock_fc_class.assert_called_once_with( + initialization_scripts=None, + initializer_names=["target"], + log_level=s.default_log_level, + ) + assert mock_run_context._scenario_registry == ctx._scenario_registry + assert mock_run_context._initializer_registry == ctx._initializer_registry + assert mock_run_context._initialized is True + mock_print_targets.assert_called_once_with(context=mock_run_context) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): + """Test do_list_targets handles exceptions.""" + s, ctx, _ = shell + mock_print_targets.side_effect = RuntimeError("Test error") + + s.do_list_targets("") + + captured = capsys.readouterr() + assert "Error listing targets" in captured.out + + def test_do_list_targets_parse_error(self, shell, capsys): + """Test do_list_targets shows error for invalid args.""" + s, ctx, _ = shell + + s.do_list_targets("--unknown-flag") + + captured = capsys.readouterr() + assert "Error" in captured.out + def test_do_run_empty_line(self, shell, capsys): """Test do_run with empty line.""" s, ctx, _ = shell @@ -380,6 +456,15 @@ def test_do_scenario_history_empty(self, shell, capsys): captured = capsys.readouterr() assert "No scenario runs in history" in captured.out + def test_do_scenario_history_rejects_args(self, shell, capsys): + """Test do_scenario_history rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_scenario_history("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + def test_do_scenario_history_with_runs(self, shell, capsys): """Test do_scenario_history with scenario runs.""" s, ctx, _ = shell @@ -502,6 +587,14 @@ def test_do_help_with_arg(self, shell): s.do_help("run") mock_parent_help.assert_called_with("run") + def test_do_help_with_hyphenated_arg(self, shell): + """Test do_help converts hyphens to underscores for command lookup.""" + s, ctx, _ = shell + + with patch("cmd.Cmd.do_help") as mock_parent_help: + s.do_help("list-targets") + mock_parent_help.assert_called_with("list_targets") + @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell): From 642c0d28a25d382ee3423d5eef233cb8ab553237 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 8 Apr 2026 10:53:02 -0700 Subject: [PATCH 2/8] fix bug --- pyrit/cli/pyrit_shell.py | 61 +++++++++++----- tests/unit/cli/test_pyrit_shell.py | 110 ++++++++++++++++++++++++++--- 2 files changed, 146 insertions(+), 25 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 6707f14e95..d3fb73f5f2 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -166,6 +166,42 @@ def _ensure_initialized(self) -> None: self._init_complete.wait() self._raise_init_error() + def _rebuild_context( + self, + *, + initializer_names: Optional[list[Any]] = None, + initialization_scripts: Optional[list[Path]] = None, + log_level: Optional[int] = None, + ) -> frontend_core.FrontendCore: + """ + Create a per-command FrontendCore that inherits the shell's startup config. + + Propagates config_file, database, and env_files from the shell's startup + kwargs, then overrides initializer_names, initialization_scripts, and + log_level for the current command. Shares registries with the shell + context to avoid redundant re-discovery. + + Args: + initializer_names (Optional[list[Any]]): Per-command initializer overrides. + initialization_scripts (Optional[list[Path]]): Per-command script overrides. + log_level (Optional[int]): Per-command log level override. + + Returns: + frontend_core.FrontendCore: A new context ready for use in a command. + """ + cmd_kwargs = dict(self._context_kwargs) + cmd_kwargs["initializer_names"] = initializer_names + cmd_kwargs["initialization_scripts"] = initialization_scripts + cmd_kwargs["log_level"] = log_level if log_level is not None else self.default_log_level + + cmd_context = self._fc.FrontendCore(**cmd_kwargs) + cmd_context._scenario_registry = self.context._scenario_registry + cmd_context._initializer_registry = self.context._initializer_registry + cmd_context._initialized = True + cmd_context._silent_reinit = True + + return cmd_context + def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: @@ -228,8 +264,7 @@ def do_list_targets(self, arg: str) -> None: """ self._ensure_initialized() try: - context_to_use = self.context - + list_targets_context = self.context if arg.strip(): args = self._fc.parse_list_targets_arguments(args_string=arg) @@ -238,17 +273,12 @@ def do_list_targets(self, arg: str) -> None: resolved_scripts = self._fc.resolve_initialization_scripts( script_paths=args["initialization_scripts"] ) - - context_to_use = self._fc.FrontendCore( + list_targets_context = self._rebuild_context( initialization_scripts=resolved_scripts, initializer_names=args["initializers"], - log_level=self.default_log_level, ) - context_to_use._scenario_registry = self.context._scenario_registry - context_to_use._initializer_registry = self.context._initializer_registry - context_to_use._initialized = True - asyncio.run(self._fc.print_targets_list_async(context=context_to_use)) + asyncio.run(self._fc.print_targets_list_async(context=list_targets_context)) except ValueError as e: print(f"Error: {e}") except FileNotFoundError as e: @@ -337,16 +367,13 @@ def do_run(self, line: str) -> None: print(f"Error: {e}") return - # Create a context for this run with overrides - run_context = self._fc.FrontendCore( - initialization_scripts=resolved_scripts, + # Create a context for this run with per-command overrides, + # inheriting config_file, database, and env_files from startup. + run_context = self._rebuild_context( initializer_names=args["initializers"], - log_level=args["log_level"] if args["log_level"] else self.default_log_level, + initialization_scripts=resolved_scripts, + log_level=args["log_level"], ) - # Use the existing registries (don't reinitialize) - run_context._scenario_registry = self.context._scenario_registry - run_context._initializer_registry = self.context._initializer_registry - run_context._initialized = True try: result = asyncio.run( diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index eca05ec159..f948024476 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -208,7 +208,7 @@ def test_do_list_targets_with_initializers( mock_print_targets: AsyncMock, shell, ): - """Test do_list_targets with --initializers creates a new context.""" + """Test do_list_targets with --initializers uses _rebuild_context.""" s, ctx, _ = shell mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} mock_run_context = MagicMock() @@ -217,14 +217,13 @@ def test_do_list_targets_with_initializers( s.do_list_targets("--initializers target") mock_parse.assert_called_once_with(args_string="--initializers target") - mock_fc_class.assert_called_once_with( - initialization_scripts=None, - initializer_names=["target"], - log_level=s.default_log_level, - ) - assert mock_run_context._scenario_registry == ctx._scenario_registry - assert mock_run_context._initializer_registry == ctx._initializer_registry + # _rebuild_context passes _context_kwargs plus overrides to FrontendCore + call_kwargs = mock_fc_class.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["log_level"] == s.default_log_level assert mock_run_context._initialized is True + assert mock_run_context._silent_reinit is True mock_print_targets.assert_called_once_with(context=mock_run_context) @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) @@ -696,6 +695,101 @@ def test_default_unknown_command(self, shell, capsys): assert "Unknown command" in captured.out +class TestRebuildContext: + """Tests for _rebuild_context helper method.""" + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_propagates_startup_config(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context passes config_file, database, env_files from startup kwargs.""" + s, ctx, _ = shell + s._context_kwargs = { + "config_file": Path("/my/config.yaml"), + "database": "InMemory", + "env_files": [Path("/my/.env")], + } + mock_derived = MagicMock() + mock_fc_class.return_value = mock_derived + + result = s._rebuild_context( + initializer_names=["target"], + initialization_scripts=None, + log_level=logging.DEBUG, + ) + + call_kwargs = mock_fc_class.call_args[1] + assert call_kwargs["config_file"] == Path("/my/config.yaml") + assert call_kwargs["database"] == "InMemory" + assert call_kwargs["env_files"] == [Path("/my/.env")] + assert call_kwargs["initializer_names"] == ["target"] + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["log_level"] == logging.DEBUG + assert result is mock_derived + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_shares_registries(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context shares scenario and initializer registries from shell context.""" + s, ctx, _ = shell + mock_derived = MagicMock() + mock_fc_class.return_value = mock_derived + + s._rebuild_context(initializer_names=None) + + assert mock_derived._scenario_registry == ctx._scenario_registry + assert mock_derived._initializer_registry == ctx._initializer_registry + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_sets_initialized_and_silent(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context sets _initialized and _silent_reinit on the derived context.""" + s, ctx, _ = shell + mock_derived = MagicMock() + mock_fc_class.return_value = mock_derived + + s._rebuild_context(initializer_names=None) + + assert mock_derived._initialized is True + assert mock_derived._silent_reinit is True + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_defaults_log_level_to_shell_default(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context uses default_log_level when log_level is None.""" + s, ctx, _ = shell + s.default_log_level = logging.ERROR + mock_fc_class.return_value = MagicMock() + + s._rebuild_context(initializer_names=None, log_level=None) + + call_kwargs = mock_fc_class.call_args[1] + assert call_kwargs["log_level"] == logging.ERROR + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_overrides_do_not_mutate_context_kwargs(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context does not modify the original _context_kwargs dict.""" + s, ctx, _ = shell + s._context_kwargs = {"config_file": Path("/original.yaml")} + original_kwargs = dict(s._context_kwargs) + mock_fc_class.return_value = MagicMock() + + s._rebuild_context(initializer_names=["new_init"], initialization_scripts=[Path("/script.py")]) + + assert s._context_kwargs == original_kwargs + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_with_empty_startup_kwargs(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context works when shell was started with no extra kwargs.""" + s, ctx, _ = shell + s._context_kwargs = {} + mock_fc_class.return_value = MagicMock() + + s._rebuild_context(initializer_names=["target"]) + + call_kwargs = mock_fc_class.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + # config_file, database, env_files should not be in kwargs + assert "config_file" not in call_kwargs + assert "database" not in call_kwargs + assert "env_files" not in call_kwargs + + class TestMain: """Tests for main function.""" From 105ef1f2262162883d0c5a54b2a30c19ff8cd4f0 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 8 Apr 2026 12:32:57 -0700 Subject: [PATCH 3/8] copilot doc suggestion --- pyrit/cli/pyrit_scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index f6f77bea2b..9a8ca771c4 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -196,7 +196,7 @@ def main(args: Optional[list[str]] = None) -> int: return asyncio.run(frontend_core.print_initializers_list_async(context=context)) if parsed_args.list_targets: - # Need initializers to populate target registry + # Need initializers or initialization scripts to populate the target registry initialization_scripts = None if parsed_args.initialization_scripts: try: From 8e90714b004d41f897ebecfc7d299cf638fefb57 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Wed, 8 Apr 2026 12:51:24 -0700 Subject: [PATCH 4/8] small check for not None --- pyrit/cli/pyrit_shell.py | 6 ++++-- tests/unit/cli/test_pyrit_shell.py | 22 ++++++++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index d3fb73f5f2..83f5405090 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -190,8 +190,10 @@ def _rebuild_context( frontend_core.FrontendCore: A new context ready for use in a command. """ cmd_kwargs = dict(self._context_kwargs) - cmd_kwargs["initializer_names"] = initializer_names - cmd_kwargs["initialization_scripts"] = initialization_scripts + if initializer_names is not None: + cmd_kwargs["initializer_names"] = initializer_names + if initialization_scripts is not None: + cmd_kwargs["initialization_scripts"] = initialization_scripts cmd_kwargs["log_level"] = log_level if log_level is not None else self.default_log_level cmd_context = self._fc.FrontendCore(**cmd_kwargs) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index f948024476..70c55cec10 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -220,7 +220,8 @@ def test_do_list_targets_with_initializers( # _rebuild_context passes _context_kwargs plus overrides to FrontendCore call_kwargs = mock_fc_class.call_args[1] assert call_kwargs["initializer_names"] == ["target"] - assert call_kwargs["initialization_scripts"] is None + # initialization_scripts=None should not appear in kwargs (preserves startup value) + assert "initialization_scripts" not in call_kwargs assert call_kwargs["log_level"] == s.default_log_level assert mock_run_context._initialized is True assert mock_run_context._silent_reinit is True @@ -721,10 +722,27 @@ def test_rebuild_context_propagates_startup_config(self, mock_fc_class: MagicMoc assert call_kwargs["database"] == "InMemory" assert call_kwargs["env_files"] == [Path("/my/.env")] assert call_kwargs["initializer_names"] == ["target"] - assert call_kwargs["initialization_scripts"] is None + # initialization_scripts=None should NOT override startup kwargs + assert "initialization_scripts" not in call_kwargs assert call_kwargs["log_level"] == logging.DEBUG assert result is mock_derived + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_rebuild_context_none_does_not_override_startup_kwargs(self, mock_fc_class: MagicMock, shell): + """Test _rebuild_context preserves startup initializer_names/scripts when None is passed.""" + s, ctx, _ = shell + s._context_kwargs = { + "initializer_names": ["startup_init"], + "initialization_scripts": [Path("/startup_script.py")], + } + mock_fc_class.return_value = MagicMock() + + s._rebuild_context(initializer_names=None, initialization_scripts=None) + + call_kwargs = mock_fc_class.call_args[1] + assert call_kwargs["initializer_names"] == ["startup_init"] + assert call_kwargs["initialization_scripts"] == [Path("/startup_script.py")] + @patch("pyrit.cli.frontend_core.FrontendCore") def test_rebuild_context_shares_registries(self, mock_fc_class: MagicMock, shell): """Test _rebuild_context shares scenario and initializer registries from shell context.""" From 9b763db1dffc5c53e60017d120087868b328d156 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Thu, 9 Apr 2026 10:01:31 -0700 Subject: [PATCH 5/8] PR feedback --- pyrit/cli/frontend_core.py | 58 ++++++++++++ pyrit/cli/pyrit_shell.py | 42 +-------- tests/unit/cli/test_frontend_core.py | 123 +++++++++++++++++++++++- tests/unit/cli/test_pyrit_shell.py | 134 ++------------------------- 4 files changed, 188 insertions(+), 169 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 8bcf5f4655..9fd454dce4 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -188,6 +188,64 @@ async def initialize_async(self) -> None: self._initialized = True + def with_overrides( + self, + *, + initializer_names: Optional[list[Any]] = None, + initialization_scripts: Optional[list[Path]] = None, + log_level: Optional[int] = None, + ) -> FrontendCore: + """ + Create a derived FrontendCore with per-command overrides. + + Copies inherited state (database, env_files, operator, operation, config) + from this instance and applies the given overrides. Shares registries + with the parent to avoid redundant re-discovery and skips re-reading + config files. + + Args: + initializer_names (Optional[list[Any]]): Per-command initializer overrides. + Each entry can be a string name or a dict with 'name' and optional 'args'. + None keeps the parent's value. + initialization_scripts (Optional[list[Path]]): Per-command script overrides. + None keeps the parent's value. + log_level (Optional[int]): Per-command log level override. + None keeps the parent's value. + + Returns: + FrontendCore: A new context ready for use, without re-reading config files. + """ + derived = object.__new__(FrontendCore) + + # Inherit from parent + derived._database = self._database + derived._env_files = self._env_files + derived._operator = self._operator + derived._operation = self._operation + derived._config = self._config + + # Apply overrides or inherit + derived._log_level = log_level if log_level is not None else self._log_level + + if initializer_names is not None: + loader = ConfigurationLoader.from_dict({"initializers": initializer_names}) + derived._initializer_configs = loader._initializer_configs + else: + derived._initializer_configs = self._initializer_configs + + if initialization_scripts is not None: + derived._initialization_scripts = initialization_scripts + else: + derived._initialization_scripts = self._initialization_scripts + + # Share registries (singletons, no need to re-discover) + derived._scenario_registry = self._scenario_registry + derived._initializer_registry = self._initializer_registry + derived._initialized = True + derived._silent_reinit = True + + return derived + @property def scenario_registry(self) -> ScenarioRegistry: """ diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 83f5405090..ae38edcde8 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -166,44 +166,6 @@ def _ensure_initialized(self) -> None: self._init_complete.wait() self._raise_init_error() - def _rebuild_context( - self, - *, - initializer_names: Optional[list[Any]] = None, - initialization_scripts: Optional[list[Path]] = None, - log_level: Optional[int] = None, - ) -> frontend_core.FrontendCore: - """ - Create a per-command FrontendCore that inherits the shell's startup config. - - Propagates config_file, database, and env_files from the shell's startup - kwargs, then overrides initializer_names, initialization_scripts, and - log_level for the current command. Shares registries with the shell - context to avoid redundant re-discovery. - - Args: - initializer_names (Optional[list[Any]]): Per-command initializer overrides. - initialization_scripts (Optional[list[Path]]): Per-command script overrides. - log_level (Optional[int]): Per-command log level override. - - Returns: - frontend_core.FrontendCore: A new context ready for use in a command. - """ - cmd_kwargs = dict(self._context_kwargs) - if initializer_names is not None: - cmd_kwargs["initializer_names"] = initializer_names - if initialization_scripts is not None: - cmd_kwargs["initialization_scripts"] = initialization_scripts - cmd_kwargs["log_level"] = log_level if log_level is not None else self.default_log_level - - cmd_context = self._fc.FrontendCore(**cmd_kwargs) - cmd_context._scenario_registry = self.context._scenario_registry - cmd_context._initializer_registry = self.context._initializer_registry - cmd_context._initialized = True - cmd_context._silent_reinit = True - - return cmd_context - def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: @@ -275,7 +237,7 @@ def do_list_targets(self, arg: str) -> None: resolved_scripts = self._fc.resolve_initialization_scripts( script_paths=args["initialization_scripts"] ) - list_targets_context = self._rebuild_context( + list_targets_context = self.context.with_overrides( initialization_scripts=resolved_scripts, initializer_names=args["initializers"], ) @@ -371,7 +333,7 @@ def do_run(self, line: str) -> None: # Create a context for this run with per-command overrides, # inheriting config_file, database, and env_files from startup. - run_context = self._rebuild_context( + run_context = self.context.with_overrides( initializer_names=args["initializers"], initialization_scripts=resolved_scripts, log_level=args["log_level"], diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index e3cbbca8d0..fc0a12440d 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -973,9 +973,126 @@ def test_parse_run_arguments_target_with_other_args(self): args_string="test_scenario --target my_target --initializers init1 --max-concurrency 5" ) - assert result["target"] == "my_target" - assert result["initializers"] == ["init1"] - assert result["max_concurrency"] == 5 + +class TestWithOverrides: + """Tests for FrontendCore.with_overrides method.""" + + def _make_initialized_parent(self) -> frontend_core.FrontendCore: + """Create a fully-initialized FrontendCore for testing with_overrides.""" + parent = frontend_core.FrontendCore( + database=frontend_core.IN_MEMORY, + initializer_names=["parent_init"], + log_level=logging.WARNING, + ) + parent._scenario_registry = MagicMock() + parent._initializer_registry = MagicMock() + parent._initialized = True + parent._silent_reinit = True + return parent + + def test_with_overrides_inherits_fields(self): + """Test that derived context inherits database, env_files, operator, operation, config.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._database == parent._database + assert derived._env_files == parent._env_files + assert derived._operator == parent._operator + assert derived._operation == parent._operation + assert derived._config is parent._config + + def test_with_overrides_shares_registries(self): + """Test that derived context shares scenario and initializer registries.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._scenario_registry is parent._scenario_registry + assert derived._initializer_registry is parent._initializer_registry + + def test_with_overrides_sets_initialized_and_silent(self): + """Test that derived context is marked initialized with silent reinit.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._initialized is True + assert derived._silent_reinit is True + + def test_with_overrides_none_keeps_parent_values(self): + """Test that passing None for all overrides keeps parent's values.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides( + initializer_names=None, + initialization_scripts=None, + log_level=None, + ) + + assert derived._initializer_configs == parent._initializer_configs + assert derived._initialization_scripts == parent._initialization_scripts + assert derived._log_level == parent._log_level + + def test_with_overrides_initializer_names(self): + """Test that initializer_names override normalizes to InitializerConfig objects.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=["target", "dataset"]) + + assert derived._initializer_configs is not None + names = [ic.name for ic in derived._initializer_configs] + assert names == ["target", "dataset"] + # Parent should still have original + assert [ic.name for ic in parent._initializer_configs] == ["parent_init"] + + def test_with_overrides_initializer_names_dict(self): + """Test initializer_names with dict entries (name + args).""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=[{"name": "target", "args": {"tags": "default"}}]) + + assert derived._initializer_configs is not None + assert len(derived._initializer_configs) == 1 + assert derived._initializer_configs[0].name == "target" + assert derived._initializer_configs[0].args == {"tags": "default"} + + def test_with_overrides_initialization_scripts(self): + """Test that initialization_scripts override replaces parent's scripts.""" + parent = self._make_initialized_parent() + new_scripts = [Path("/new/script.py")] + + derived = parent.with_overrides(initialization_scripts=new_scripts) + + assert derived._initialization_scripts == new_scripts + # Parent should be unchanged + assert parent._initialization_scripts != new_scripts + + def test_with_overrides_log_level(self): + """Test that log_level override replaces parent's log level.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(log_level=logging.DEBUG) + + assert derived._log_level == logging.DEBUG + assert parent._log_level == logging.WARNING + + def test_with_overrides_does_not_mutate_parent(self): + """Test that with_overrides does not modify the parent context.""" + parent = self._make_initialized_parent() + original_configs = parent._initializer_configs + original_log_level = parent._log_level + original_scripts = parent._initialization_scripts + + parent.with_overrides( + initializer_names=["new_init"], + initialization_scripts=[Path("/new.py")], + log_level=logging.DEBUG, + ) + + assert parent._initializer_configs is original_configs + assert parent._log_level == original_log_level + assert parent._initialization_scripts is original_scripts def test_parse_run_arguments_target_missing_value(self): """Test parsing --target without a value raises ValueError.""" diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 70c55cec10..4f562f9917 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -199,33 +199,27 @@ def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): mock_print_targets.assert_called_once_with(context=ctx) @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.frontend_core.FrontendCore") @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") def test_do_list_targets_with_initializers( self, mock_parse: MagicMock, - mock_fc_class: MagicMock, mock_print_targets: AsyncMock, shell, ): - """Test do_list_targets with --initializers uses _rebuild_context.""" + """Test do_list_targets with --initializers uses context.with_overrides.""" s, ctx, _ = shell mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} - mock_run_context = MagicMock() - mock_fc_class.return_value = mock_run_context + mock_derived = MagicMock() + ctx.with_overrides = MagicMock(return_value=mock_derived) s.do_list_targets("--initializers target") mock_parse.assert_called_once_with(args_string="--initializers target") - # _rebuild_context passes _context_kwargs plus overrides to FrontendCore - call_kwargs = mock_fc_class.call_args[1] - assert call_kwargs["initializer_names"] == ["target"] - # initialization_scripts=None should not appear in kwargs (preserves startup value) - assert "initialization_scripts" not in call_kwargs - assert call_kwargs["log_level"] == s.default_log_level - assert mock_run_context._initialized is True - assert mock_run_context._silent_reinit is True - mock_print_targets.assert_called_once_with(context=mock_run_context) + ctx.with_overrides.assert_called_once_with( + initialization_scripts=None, + initializer_names=["target"], + ) + mock_print_targets.assert_called_once_with(context=mock_derived) @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): @@ -696,118 +690,6 @@ def test_default_unknown_command(self, shell, capsys): assert "Unknown command" in captured.out -class TestRebuildContext: - """Tests for _rebuild_context helper method.""" - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_propagates_startup_config(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context passes config_file, database, env_files from startup kwargs.""" - s, ctx, _ = shell - s._context_kwargs = { - "config_file": Path("/my/config.yaml"), - "database": "InMemory", - "env_files": [Path("/my/.env")], - } - mock_derived = MagicMock() - mock_fc_class.return_value = mock_derived - - result = s._rebuild_context( - initializer_names=["target"], - initialization_scripts=None, - log_level=logging.DEBUG, - ) - - call_kwargs = mock_fc_class.call_args[1] - assert call_kwargs["config_file"] == Path("/my/config.yaml") - assert call_kwargs["database"] == "InMemory" - assert call_kwargs["env_files"] == [Path("/my/.env")] - assert call_kwargs["initializer_names"] == ["target"] - # initialization_scripts=None should NOT override startup kwargs - assert "initialization_scripts" not in call_kwargs - assert call_kwargs["log_level"] == logging.DEBUG - assert result is mock_derived - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_none_does_not_override_startup_kwargs(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context preserves startup initializer_names/scripts when None is passed.""" - s, ctx, _ = shell - s._context_kwargs = { - "initializer_names": ["startup_init"], - "initialization_scripts": [Path("/startup_script.py")], - } - mock_fc_class.return_value = MagicMock() - - s._rebuild_context(initializer_names=None, initialization_scripts=None) - - call_kwargs = mock_fc_class.call_args[1] - assert call_kwargs["initializer_names"] == ["startup_init"] - assert call_kwargs["initialization_scripts"] == [Path("/startup_script.py")] - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_shares_registries(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context shares scenario and initializer registries from shell context.""" - s, ctx, _ = shell - mock_derived = MagicMock() - mock_fc_class.return_value = mock_derived - - s._rebuild_context(initializer_names=None) - - assert mock_derived._scenario_registry == ctx._scenario_registry - assert mock_derived._initializer_registry == ctx._initializer_registry - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_sets_initialized_and_silent(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context sets _initialized and _silent_reinit on the derived context.""" - s, ctx, _ = shell - mock_derived = MagicMock() - mock_fc_class.return_value = mock_derived - - s._rebuild_context(initializer_names=None) - - assert mock_derived._initialized is True - assert mock_derived._silent_reinit is True - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_defaults_log_level_to_shell_default(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context uses default_log_level when log_level is None.""" - s, ctx, _ = shell - s.default_log_level = logging.ERROR - mock_fc_class.return_value = MagicMock() - - s._rebuild_context(initializer_names=None, log_level=None) - - call_kwargs = mock_fc_class.call_args[1] - assert call_kwargs["log_level"] == logging.ERROR - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_overrides_do_not_mutate_context_kwargs(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context does not modify the original _context_kwargs dict.""" - s, ctx, _ = shell - s._context_kwargs = {"config_file": Path("/original.yaml")} - original_kwargs = dict(s._context_kwargs) - mock_fc_class.return_value = MagicMock() - - s._rebuild_context(initializer_names=["new_init"], initialization_scripts=[Path("/script.py")]) - - assert s._context_kwargs == original_kwargs - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_rebuild_context_with_empty_startup_kwargs(self, mock_fc_class: MagicMock, shell): - """Test _rebuild_context works when shell was started with no extra kwargs.""" - s, ctx, _ = shell - s._context_kwargs = {} - mock_fc_class.return_value = MagicMock() - - s._rebuild_context(initializer_names=["target"]) - - call_kwargs = mock_fc_class.call_args[1] - assert call_kwargs["initializer_names"] == ["target"] - # config_file, database, env_files should not be in kwargs - assert "config_file" not in call_kwargs - assert "database" not in call_kwargs - assert "env_files" not in call_kwargs - - class TestMain: """Tests for main function.""" From 7d3b50c05e9cbeeaeb8b9257c4906e780015d017 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Thu, 9 Apr 2026 10:19:47 -0700 Subject: [PATCH 6/8] remove ._config --- pyrit/cli/frontend_core.py | 4 ---- tests/unit/cli/test_frontend_core.py | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 9fd454dce4..bc9519052a 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -136,9 +136,6 @@ def __init__( ) from e raise - # Store the merged configuration - self._config = config - # Extract values from config for internal use # Use canonical mapping from configuration_loader self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] @@ -222,7 +219,6 @@ def with_overrides( derived._env_files = self._env_files derived._operator = self._operator derived._operation = self._operation - derived._config = self._config # Apply overrides or inherit derived._log_level = log_level if log_level is not None else self._log_level diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index fc0a12440d..422bc6a2d3 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -991,7 +991,7 @@ def _make_initialized_parent(self) -> frontend_core.FrontendCore: return parent def test_with_overrides_inherits_fields(self): - """Test that derived context inherits database, env_files, operator, operation, config.""" + """Test that derived context inherits database, env_files, operator, operation.""" parent = self._make_initialized_parent() derived = parent.with_overrides() @@ -1000,7 +1000,6 @@ def test_with_overrides_inherits_fields(self): assert derived._env_files == parent._env_files assert derived._operator == parent._operator assert derived._operation == parent._operation - assert derived._config is parent._config def test_with_overrides_shares_registries(self): """Test that derived context shares scenario and initializer registries.""" From 6d59facca40b9f0261db0afdea06915f11a01eb9 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Thu, 9 Apr 2026 18:01:18 -0700 Subject: [PATCH 7/8] cli_arg_spec --- pyrit/cli/_cli_args.py | 273 ++++++++++++++++----------- tests/unit/cli/test_frontend_core.py | 89 +++++++++ 2 files changed, 255 insertions(+), 107 deletions(-) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 80e383dfe5..e0e8ab7129 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -14,6 +14,7 @@ from __future__ import annotations import argparse +import dataclasses import inspect import json import logging @@ -342,6 +343,168 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: return name +# --------------------------------------------------------------------------- +# Shell argument specification +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class _ArgSpec: + """ + Declarative specification for a single shell-mode CLI argument. + + Each instance describes one CLI flag (or set of aliases) and how its + value(s) should be collected and validated. A list of ``_ArgSpec`` objects + is passed to ``_parse_shell_arguments`` which handles the actual parsing + loop. Adding a new flag only requires defining a new ``_ArgSpec`` + constant, not editing any parsing logic. + + Attributes: + flags: CLI flag strings that trigger this argument (e.g., ``["--strategies", "-s"]``). + result_key: Key name in the returned dict (e.g., ``"scenario_strategies"``). + multi_value: If True, collect values until the next flag. + If False, consume exactly one value. + parser: Optional callable to transform each raw string value. + Applied per-item for multi-value args, or to the single value otherwise. + """ + + flags: list[str] + result_key: str + multi_value: bool = False + parser: Callable[[str], Any] | None = None + + +_INITIALIZERS_ARG = _ArgSpec( + flags=["--initializers"], + result_key="initializers", + multi_value=True, + parser=_parse_initializer_arg, +) +_INIT_SCRIPTS_ARG = _ArgSpec( + flags=["--initialization-scripts"], + result_key="initialization_scripts", + multi_value=True, +) + +_STRATEGIES_ARG = _ArgSpec( + flags=["--strategies", "-s"], + result_key="scenario_strategies", + multi_value=True, +) +_MAX_CONCURRENCY_ARG = _ArgSpec( + flags=["--max-concurrency"], + result_key="max_concurrency", + parser=lambda v: validate_integer(v, name="--max-concurrency", min_value=1), +) +_MAX_RETRIES_ARG = _ArgSpec( + flags=["--max-retries"], + result_key="max_retries", + parser=lambda v: validate_integer(v, name="--max-retries", min_value=0), +) +_MEMORY_LABELS_ARG = _ArgSpec( + flags=["--memory-labels"], + result_key="memory_labels", + parser=parse_memory_labels, +) +_LOG_LEVEL_ARG = _ArgSpec( + flags=["--log-level"], + result_key="log_level", + parser=lambda v: validate_log_level(log_level=v), +) +_DATASET_NAMES_ARG = _ArgSpec( + flags=["--dataset-names"], + result_key="dataset_names", + multi_value=True, +) +_MAX_DATASET_SIZE_ARG = _ArgSpec( + flags=["--max-dataset-size"], + result_key="max_dataset_size", + parser=lambda v: validate_integer(v, name="--max-dataset-size", min_value=1), +) +_TARGET_ARG = _ArgSpec( + flags=["--target"], + result_key="target", +) + +_RUN_ARG_SPECS: list[_ArgSpec] = [ + _INITIALIZERS_ARG, + _INIT_SCRIPTS_ARG, + _STRATEGIES_ARG, + _MAX_CONCURRENCY_ARG, + _MAX_RETRIES_ARG, + _MEMORY_LABELS_ARG, + _LOG_LEVEL_ARG, + _DATASET_NAMES_ARG, + _MAX_DATASET_SIZE_ARG, + _TARGET_ARG, +] + +_LIST_TARGETS_ARG_SPECS: list[_ArgSpec] = [ + _INITIALIZERS_ARG, + _INIT_SCRIPTS_ARG, +] + + +# --------------------------------------------------------------------------- +# Generic shell argument parser +# --------------------------------------------------------------------------- + + +def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> dict[str, Any]: + """ + Parse a list of shell tokens against a set of argument specifications. + + Each ``_ArgSpec`` in *arg_specs* declares how its flag(s) should be handled + (multi-value collection vs. single-value consumption) and what validation + or transformation to apply. + + Args: + parts: Token list (already split on whitespace, positional args removed). + arg_specs: Argument specifications that this command accepts. + + Returns: + Dictionary mapping each spec's ``result_key`` to its parsed value, + defaulting to ``None`` for arguments not present in *parts*. + + Raises: + ValueError: On unknown flags or missing values. + """ + # Build lookup: flag string → spec + flag_to_spec: dict[str, _ArgSpec] = {} + for spec in arg_specs: + for flag in spec.flags: + flag_to_spec[flag] = spec + + # Initialise result with None defaults + result: dict[str, Any] = {spec.result_key: None for spec in arg_specs} + + i = 0 + while i < len(parts): + token = parts[i] + spec = flag_to_spec.get(token) + + if spec is None: + raise ValueError(f"Unknown argument: {token}") + + i += 1 + + if spec.multi_value: + values: list[Any] = [] + while i < len(parts) and not parts[i].startswith("--"): + item = spec.parser(parts[i]) if spec.parser else parts[i] + values.append(item) + i += 1 + result[spec.result_key] = values + else: + if i >= len(parts): + raise ValueError(f"{spec.flags[0]} requires a value") + raw = parts[i] + result[spec.result_key] = spec.parser(raw) if spec.parser else raw + i += 1 + + return result + + def parse_run_arguments(*, args_string: str) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -371,89 +534,8 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: if not parts: raise ValueError("No scenario name provided") - result: dict[str, Any] = { - "scenario_name": parts[0], - "initializers": None, - "initialization_scripts": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - i = 1 - while i < len(parts): - if parts[i] == "--initializers": - # Collect initializers until next flag, parsing name:key=val syntax - result["initializers"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(_parse_initializer_arg(parts[i])) - i += 1 - elif parts[i] == "--initialization-scripts": - # Collect script paths until next flag - result["initialization_scripts"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initialization_scripts"].append(parts[i]) - i += 1 - elif parts[i] in ("--strategies", "-s"): - # Collect strategies until next flag - result["scenario_strategies"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": - result["scenario_strategies"].append(parts[i]) - i += 1 - elif parts[i] == "--max-concurrency": - i += 1 - if i >= len(parts): - raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) - i += 1 - elif parts[i] == "--max-retries": - i += 1 - if i >= len(parts): - raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) - i += 1 - elif parts[i] == "--memory-labels": - i += 1 - if i >= len(parts): - raise ValueError("--memory-labels requires a value") - result["memory_labels"] = parse_memory_labels(parts[i]) - i += 1 - elif parts[i] == "--log-level": - i += 1 - if i >= len(parts): - raise ValueError("--log-level requires a value") - result["log_level"] = validate_log_level(log_level=parts[i]) - i += 1 - elif parts[i] == "--dataset-names": - # Collect dataset names until next flag - result["dataset_names"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["dataset_names"].append(parts[i]) - i += 1 - elif parts[i] == "--max-dataset-size": - i += 1 - if i >= len(parts): - raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) - i += 1 - elif parts[i] == "--target": - i += 1 - if i >= len(parts): - raise ValueError("--target requires a value") - result["target"] = parts[i] - i += 1 - else: - raise ValueError(f"Unknown argument: {parts[i]}") - + result = _parse_shell_arguments(parts=parts[1:], arg_specs=_RUN_ARG_SPECS) + result["scenario_name"] = parts[0] return result @@ -473,30 +555,7 @@ def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: ValueError: If parsing or validation fails. """ parts = args_string.split() - - result: dict[str, Any] = { - "initializers": None, - "initialization_scripts": None, - } - - i = 0 - while i < len(parts): - if parts[i] == "--initializers": - result["initializers"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(_parse_initializer_arg(parts[i])) - i += 1 - elif parts[i] == "--initialization-scripts": - result["initialization_scripts"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initialization_scripts"].append(parts[i]) - i += 1 - else: - raise ValueError(f"Unknown argument: {parts[i]}") - - return result + return _parse_shell_arguments(parts=parts, arg_specs=_LIST_TARGETS_ARG_SPECS) # --------------------------------------------------------------------------- diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 422bc6a2d3..bb5febe4f5 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -12,6 +12,7 @@ import pytest from pyrit.cli import frontend_core +from pyrit.cli._cli_args import _ArgSpec, _parse_shell_arguments from pyrit.registry import InitializerMetadata, ScenarioMetadata @@ -554,6 +555,94 @@ def test_colon_but_no_params_returns_string(self) -> None: assert result == "target" +class TestParseShellArguments: + """Tests for the generic _parse_shell_arguments function.""" + + def test_empty_parts_returns_none_defaults(self): + """Test that empty input returns None for all result keys.""" + spec = _ArgSpec(flags=["--foo"], result_key="foo") + result = _parse_shell_arguments(parts=[], arg_specs=[spec]) + assert result == {"foo": None} + + def test_single_value_arg(self): + """Test parsing a single-value argument.""" + spec = _ArgSpec(flags=["--name"], result_key="name") + result = _parse_shell_arguments(parts=["--name", "alice"], arg_specs=[spec]) + assert result["name"] == "alice" + + def test_single_value_with_parser(self): + """Test that single-value parser is applied.""" + spec = _ArgSpec(flags=["--count"], result_key="count", parser=int) + result = _parse_shell_arguments(parts=["--count", "42"], arg_specs=[spec]) + assert result["count"] == 42 + + def test_single_value_missing_raises(self): + """Test that missing value for single-value arg raises ValueError.""" + spec = _ArgSpec(flags=["--name"], result_key="name") + with pytest.raises(ValueError, match="--name requires a value"): + _parse_shell_arguments(parts=["--name"], arg_specs=[spec]) + + def test_multi_value_arg(self): + """Test collecting multiple values until next flag.""" + spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + result = _parse_shell_arguments(parts=["--items", "a", "b", "c"], arg_specs=[spec]) + assert result["items"] == ["a", "b", "c"] + + def test_multi_value_stops_at_next_flag(self): + """Test that multi-value collection stops at the next known flag.""" + items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + name_spec = _ArgSpec(flags=["--name"], result_key="name") + result = _parse_shell_arguments( + parts=["--items", "a", "b", "--name", "alice"], + arg_specs=[items_spec, name_spec], + ) + assert result["items"] == ["a", "b"] + assert result["name"] == "alice" + + def test_multi_value_stops_at_short_flag_alias(self): + """Test that multi-value collection stops at a short flag alias like -s.""" + long_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + short_spec = _ArgSpec(flags=["-s", "--short"], result_key="short", multi_value=True) + result = _parse_shell_arguments( + parts=["--items", "a", "b", "-s", "x"], + arg_specs=[long_spec, short_spec], + ) + assert result["items"] == ["a", "b"] + assert result["short"] == ["x"] + + def test_multi_value_with_parser(self): + """Test that parser transforms each collected value.""" + spec = _ArgSpec(flags=["--nums"], result_key="nums", multi_value=True, parser=int) + result = _parse_shell_arguments(parts=["--nums", "1", "2", "3"], arg_specs=[spec]) + assert result["nums"] == [1, 2, 3] + + def test_multi_value_empty_list_when_no_values(self): + """Test that multi-value arg with no values produces an empty list.""" + items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + name_spec = _ArgSpec(flags=["--name"], result_key="name") + result = _parse_shell_arguments( + parts=["--items", "--name", "alice"], + arg_specs=[items_spec, name_spec], + ) + assert result["items"] == [] + assert result["name"] == "alice" + + def test_unknown_flag_raises(self): + """Test that an unknown flag raises ValueError.""" + spec = _ArgSpec(flags=["--known"], result_key="known") + with pytest.raises(ValueError, match="Unknown argument: --unknown"): + _parse_shell_arguments(parts=["--unknown"], arg_specs=[spec]) + + def test_multiple_specs_all_none_when_unused(self): + """Test that unused specs default to None.""" + specs = [ + _ArgSpec(flags=["--a"], result_key="a"), + _ArgSpec(flags=["--b"], result_key="b", multi_value=True), + ] + result = _parse_shell_arguments(parts=[], arg_specs=specs) + assert result == {"a": None, "b": None} + + class TestParseRunArguments: """Tests for parse_run_arguments function.""" From f7d1eb7a5006a718dce3db921d6c4fce36554294 Mon Sep 17 00:00:00 2001 From: jsong468 Date: Fri, 10 Apr 2026 14:38:40 -0700 Subject: [PATCH 8/8] pr feedback --- pyrit/cli/_cli_args.py | 8 ++++++-- tests/unit/cli/test_frontend_core.py | 15 +++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index e0e8ab7129..2131f72a4a 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -484,16 +484,20 @@ def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> di spec = flag_to_spec.get(token) if spec is None: - raise ValueError(f"Unknown argument: {token}") + valid = sorted(flag_to_spec.keys()) + raise ValueError(f"Unknown argument: {token}. Valid arguments: {', '.join(valid)}") i += 1 if spec.multi_value: values: list[Any] = [] - while i < len(parts) and not parts[i].startswith("--"): + # Collect values until the next flag (whether valid or invalid) + while i < len(parts) and not (parts[i].startswith("--") or parts[i] in flag_to_spec): item = spec.parser(parts[i]) if spec.parser else parts[i] values.append(item) i += 1 + if len(values) == 0: + raise ValueError(f"{spec.flags[0]} requires at least one value") result[spec.result_key] = values else: if i >= len(parts): diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index bb5febe4f5..2507f4eb4c 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -616,16 +616,15 @@ def test_multi_value_with_parser(self): result = _parse_shell_arguments(parts=["--nums", "1", "2", "3"], arg_specs=[spec]) assert result["nums"] == [1, 2, 3] - def test_multi_value_empty_list_when_no_values(self): - """Test that multi-value arg with no values produces an empty list.""" + def test_multi_value_no_values_raises(self): + """Test that multi-value arg with no values raises ValueError.""" items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) name_spec = _ArgSpec(flags=["--name"], result_key="name") - result = _parse_shell_arguments( - parts=["--items", "--name", "alice"], - arg_specs=[items_spec, name_spec], - ) - assert result["items"] == [] - assert result["name"] == "alice" + with pytest.raises(ValueError, match="--items requires at least one value"): + _parse_shell_arguments( + parts=["--items", "--name", "alice"], + arg_specs=[items_spec, name_spec], + ) def test_unknown_flag_raises(self): """Test that an unknown flag raises ValueError."""