diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index bd5a3d40fe..a76d286b27 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo "Commands:", " • list-scenarios - See all available scenarios", " • list-initializers - See all available initializers", - " • list-targets - See all available targets in the registry", + " • list-targets [opts] - See all available targets in the registry", " • run [opts] - Execute a security scenario", " • scenario-history - View your session history", " • print-scenario [N] - Display detailed results", diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 1264956ccb..2131f72a4a 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -14,6 +14,7 @@ from __future__ import annotations import argparse +import dataclasses import inspect import json import logging @@ -342,6 +343,172 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: return name +# --------------------------------------------------------------------------- +# Shell argument specification +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class _ArgSpec: + """ + Declarative specification for a single shell-mode CLI argument. + + Each instance describes one CLI flag (or set of aliases) and how its + value(s) should be collected and validated. A list of ``_ArgSpec`` objects + is passed to ``_parse_shell_arguments`` which handles the actual parsing + loop. Adding a new flag only requires defining a new ``_ArgSpec`` + constant, not editing any parsing logic. + + Attributes: + flags: CLI flag strings that trigger this argument (e.g., ``["--strategies", "-s"]``). + result_key: Key name in the returned dict (e.g., ``"scenario_strategies"``). + multi_value: If True, collect values until the next flag. + If False, consume exactly one value. + parser: Optional callable to transform each raw string value. + Applied per-item for multi-value args, or to the single value otherwise. + """ + + flags: list[str] + result_key: str + multi_value: bool = False + parser: Callable[[str], Any] | None = None + + +_INITIALIZERS_ARG = _ArgSpec( + flags=["--initializers"], + result_key="initializers", + multi_value=True, + parser=_parse_initializer_arg, +) +_INIT_SCRIPTS_ARG = _ArgSpec( + flags=["--initialization-scripts"], + result_key="initialization_scripts", + multi_value=True, +) + +_STRATEGIES_ARG = _ArgSpec( + flags=["--strategies", "-s"], + result_key="scenario_strategies", + multi_value=True, +) +_MAX_CONCURRENCY_ARG = _ArgSpec( + flags=["--max-concurrency"], + result_key="max_concurrency", + parser=lambda v: validate_integer(v, name="--max-concurrency", min_value=1), +) +_MAX_RETRIES_ARG = _ArgSpec( + flags=["--max-retries"], + result_key="max_retries", + parser=lambda v: validate_integer(v, name="--max-retries", min_value=0), +) +_MEMORY_LABELS_ARG = _ArgSpec( + flags=["--memory-labels"], + result_key="memory_labels", + parser=parse_memory_labels, +) +_LOG_LEVEL_ARG = _ArgSpec( + flags=["--log-level"], + result_key="log_level", + parser=lambda v: validate_log_level(log_level=v), +) +_DATASET_NAMES_ARG = _ArgSpec( + flags=["--dataset-names"], + result_key="dataset_names", + multi_value=True, +) +_MAX_DATASET_SIZE_ARG = _ArgSpec( + flags=["--max-dataset-size"], + result_key="max_dataset_size", + parser=lambda v: validate_integer(v, name="--max-dataset-size", min_value=1), +) +_TARGET_ARG = _ArgSpec( + flags=["--target"], + result_key="target", +) + +_RUN_ARG_SPECS: list[_ArgSpec] = [ + _INITIALIZERS_ARG, + _INIT_SCRIPTS_ARG, + _STRATEGIES_ARG, + _MAX_CONCURRENCY_ARG, + _MAX_RETRIES_ARG, + _MEMORY_LABELS_ARG, + _LOG_LEVEL_ARG, + _DATASET_NAMES_ARG, + _MAX_DATASET_SIZE_ARG, + _TARGET_ARG, +] + +_LIST_TARGETS_ARG_SPECS: list[_ArgSpec] = [ + _INITIALIZERS_ARG, + _INIT_SCRIPTS_ARG, +] + + +# --------------------------------------------------------------------------- +# Generic shell argument parser +# --------------------------------------------------------------------------- + + +def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> dict[str, Any]: + """ + Parse a list of shell tokens against a set of argument specifications. + + Each ``_ArgSpec`` in *arg_specs* declares how its flag(s) should be handled + (multi-value collection vs. single-value consumption) and what validation + or transformation to apply. + + Args: + parts: Token list (already split on whitespace, positional args removed). + arg_specs: Argument specifications that this command accepts. + + Returns: + Dictionary mapping each spec's ``result_key`` to its parsed value, + defaulting to ``None`` for arguments not present in *parts*. + + Raises: + ValueError: On unknown flags or missing values. + """ + # Build lookup: flag string → spec + flag_to_spec: dict[str, _ArgSpec] = {} + for spec in arg_specs: + for flag in spec.flags: + flag_to_spec[flag] = spec + + # Initialise result with None defaults + result: dict[str, Any] = {spec.result_key: None for spec in arg_specs} + + i = 0 + while i < len(parts): + token = parts[i] + spec = flag_to_spec.get(token) + + if spec is None: + valid = sorted(flag_to_spec.keys()) + raise ValueError(f"Unknown argument: {token}. Valid arguments: {', '.join(valid)}") + + i += 1 + + if spec.multi_value: + values: list[Any] = [] + # Collect values until the next flag (whether valid or invalid) + while i < len(parts) and not (parts[i].startswith("--") or parts[i] in flag_to_spec): + item = spec.parser(parts[i]) if spec.parser else parts[i] + values.append(item) + i += 1 + if len(values) == 0: + raise ValueError(f"{spec.flags[0]} requires at least one value") + result[spec.result_key] = values + else: + if i >= len(parts): + raise ValueError(f"{spec.flags[0]} requires a value") + raw = parts[i] + result[spec.result_key] = spec.parser(raw) if spec.parser else raw + i += 1 + + return result + + def parse_run_arguments(*, args_string: str) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -371,92 +538,30 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: if not parts: raise ValueError("No scenario name provided") - result: dict[str, Any] = { - "scenario_name": parts[0], - "initializers": None, - "initialization_scripts": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - "target": None, - } - - i = 1 - while i < len(parts): - if parts[i] == "--initializers": - # Collect initializers until next flag, parsing name:key=val syntax - result["initializers"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(_parse_initializer_arg(parts[i])) - i += 1 - elif parts[i] == "--initialization-scripts": - # Collect script paths until next flag - result["initialization_scripts"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initialization_scripts"].append(parts[i]) - i += 1 - elif parts[i] in ("--strategies", "-s"): - # Collect strategies until next flag - result["scenario_strategies"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": - result["scenario_strategies"].append(parts[i]) - i += 1 - elif parts[i] == "--max-concurrency": - i += 1 - if i >= len(parts): - raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) - i += 1 - elif parts[i] == "--max-retries": - i += 1 - if i >= len(parts): - raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) - i += 1 - elif parts[i] == "--memory-labels": - i += 1 - if i >= len(parts): - raise ValueError("--memory-labels requires a value") - result["memory_labels"] = parse_memory_labels(parts[i]) - i += 1 - elif parts[i] == "--log-level": - i += 1 - if i >= len(parts): - raise ValueError("--log-level requires a value") - result["log_level"] = validate_log_level(log_level=parts[i]) - i += 1 - elif parts[i] == "--dataset-names": - # Collect dataset names until next flag - result["dataset_names"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["dataset_names"].append(parts[i]) - i += 1 - elif parts[i] == "--max-dataset-size": - i += 1 - if i >= len(parts): - raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) - i += 1 - elif parts[i] == "--target": - i += 1 - if i >= len(parts): - raise ValueError("--target requires a value") - result["target"] = parts[i] - i += 1 - else: - raise ValueError(f"Unknown argument: {parts[i]}") - + result = _parse_shell_arguments(parts=parts[1:], arg_specs=_RUN_ARG_SPECS) + result["scenario_name"] = parts[0] return result +def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse list-targets command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "--initializers target"). + + Returns: + Dictionary with parsed arguments: + - initializers: Optional[list[str | dict[str, Any]]] + - initialization_scripts: Optional[list[str]] + + Raises: + ValueError: If parsing or validation fails. + """ + parts = args_string.split() + return _parse_shell_arguments(parts=parts, arg_specs=_LIST_TARGETS_ARG_SPECS) + + # --------------------------------------------------------------------------- # Shared argparse builder # --------------------------------------------------------------------------- diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 3db5552011..bc9519052a 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -27,6 +27,7 @@ from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg from pyrit.cli._cli_args import add_common_arguments as add_common_arguments from pyrit.cli._cli_args import non_negative_int as non_negative_int +from pyrit.cli._cli_args import parse_list_targets_arguments as parse_list_targets_arguments from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments from pyrit.cli._cli_args import positive_int as positive_int @@ -135,9 +136,6 @@ def __init__( ) from e raise - # Store the merged configuration - self._config = config - # Extract values from config for internal use # Use canonical mapping from configuration_loader self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] @@ -187,6 +185,63 @@ async def initialize_async(self) -> None: self._initialized = True + def with_overrides( + self, + *, + initializer_names: Optional[list[Any]] = None, + initialization_scripts: Optional[list[Path]] = None, + log_level: Optional[int] = None, + ) -> FrontendCore: + """ + Create a derived FrontendCore with per-command overrides. + + Copies inherited state (database, env_files, operator, operation, config) + from this instance and applies the given overrides. Shares registries + with the parent to avoid redundant re-discovery and skips re-reading + config files. + + Args: + initializer_names (Optional[list[Any]]): Per-command initializer overrides. + Each entry can be a string name or a dict with 'name' and optional 'args'. + None keeps the parent's value. + initialization_scripts (Optional[list[Path]]): Per-command script overrides. + None keeps the parent's value. + log_level (Optional[int]): Per-command log level override. + None keeps the parent's value. + + Returns: + FrontendCore: A new context ready for use, without re-reading config files. + """ + derived = object.__new__(FrontendCore) + + # Inherit from parent + derived._database = self._database + derived._env_files = self._env_files + derived._operator = self._operator + derived._operation = self._operation + + # Apply overrides or inherit + derived._log_level = log_level if log_level is not None else self._log_level + + if initializer_names is not None: + loader = ConfigurationLoader.from_dict({"initializers": initializer_names}) + derived._initializer_configs = loader._initializer_configs + else: + derived._initializer_configs = self._initializer_configs + + if initialization_scripts is not None: + derived._initialization_scripts = initialization_scripts + else: + derived._initialization_scripts = self._initialization_scripts + + # Share registries (singletons, no need to re-discover) + derived._scenario_registry = self._scenario_registry + derived._initializer_registry = self._initializer_registry + derived._initialized = True + derived._silent_reinit = True + + return derived + @property def scenario_registry(self) -> ScenarioRegistry: """ @@ -254,18 +309,16 @@ async def list_initializers_async( async def list_targets_async( *, context: FrontendCore, - initializer_names: Optional[list[Any]] = None, ) -> list[str]: """ List available target names from the TargetRegistry. Since targets are registered by initializers, this function requires initializers - to have been run first. If initializer_names are provided, they will be resolved - and run before querying the registry. + to have been run first. Configure initializers on the FrontendCore context + (via initializer_names or initialization_scripts) before calling this function. Args: context: PyRIT context with loaded registries. - initializer_names: Optional list of initializer entries to run before listing. Returns: Sorted list of registered target names. @@ -273,25 +326,24 @@ async def list_targets_async( if not context._initialized: await context.initialize_async() - # If initializer names are provided, run them to populate the target registry - if initializer_names or context._initializer_configs: - configs = context._initializer_configs - if configs: - initializer_instances = [] - for config in configs: + # Run initializers and/or initialization scripts to populate the target registry + if context._initializer_configs or context._initialization_scripts: + initializer_instances = [] + if context._initializer_configs: + for config in context._initializer_configs: initializer_class = context.initializer_registry.get_class(config.name) instance = initializer_class() if config.args: instance.set_params_from_args(args=config.args) initializer_instances.append(instance) - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - silent=getattr(context, "_silent_reinit", False), - ) + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances or None, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) target_registry = TargetRegistry.get_registry_singleton() return target_registry.get_names() diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index aefdfa5f22..9a8ca771c4 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -196,18 +196,20 @@ def main(args: Optional[list[str]] = None) -> int: return asyncio.run(frontend_core.print_initializers_list_async(context=context)) if parsed_args.list_targets: - # Need initializers to populate target registry - context = frontend_core.FrontendCore( - config_file=parsed_args.config_file, - initializer_names=parsed_args.initializers, - log_level=parsed_args.log_level, - ) - return asyncio.run(frontend_core.print_targets_list_async(context=context)) + # Need initializers or initialization scripts to populate the target registry + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 - if parsed_args.list_targets: - # Need initializers to populate target registry context = frontend_core.FrontendCore( config_file=parsed_args.config_file, + initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, log_level=parsed_args.log_level, ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index f19602bee0..ae38edcde8 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -33,7 +33,7 @@ class PyRITShell(cmd.Cmd): Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers - list-targets - List all available targets from the registry + list-targets [opts] - List all available targets from the registry run [opts] - Run a scenario with optional parameters scenario-history - List all previous scenario runs print-scenario [N] - Print detailed results for scenario run(s) @@ -189,6 +189,9 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" + if arg.strip(): + print(f"Error: list-scenarios does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) @@ -197,6 +200,9 @@ def do_list_scenarios(self, arg: str) -> None: def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" + if arg.strip(): + print(f"Error: list-initializers does not accept arguments, got: {arg.strip()}") + return self._ensure_initialized() try: asyncio.run(self._fc.print_initializers_list_async(context=self.context)) @@ -204,10 +210,43 @@ def do_list_initializers(self, arg: str) -> None: print(f"Error listing initializers: {e}") def do_list_targets(self, arg: str) -> None: - """List all available targets from the TargetRegistry.""" + """ + List all available targets from the TargetRegistry. + + Usage: + list-targets + list-targets --initializers [ ...] + list-targets --initialization-scripts [ ...] + + Options: + --initializers ... Built-in initializers to run first + --initialization-scripts <...> Custom Python scripts to run first + + Examples: + list-targets --initializers target + list-targets --initializers target:tags=default,scorer + """ self._ensure_initialized() try: - asyncio.run(self._fc.print_targets_list_async(context=self.context)) + list_targets_context = self.context + if arg.strip(): + args = self._fc.parse_list_targets_arguments(args_string=arg) + + resolved_scripts = None + if args["initialization_scripts"]: + resolved_scripts = self._fc.resolve_initialization_scripts( + script_paths=args["initialization_scripts"] + ) + list_targets_context = self.context.with_overrides( + initialization_scripts=resolved_scripts, + initializer_names=args["initializers"], + ) + + asyncio.run(self._fc.print_targets_list_async(context=list_targets_context)) + except ValueError as e: + print(f"Error: {e}") + except FileNotFoundError as e: + print(f"Error: {e}") except Exception as e: print(f"Error listing targets: {e}") @@ -292,16 +331,13 @@ def do_run(self, line: str) -> None: print(f"Error: {e}") return - # Create a context for this run with overrides - run_context = self._fc.FrontendCore( - initialization_scripts=resolved_scripts, + # Create a context for this run with per-command overrides, + # inheriting config_file, database, and env_files from startup. + run_context = self.context.with_overrides( initializer_names=args["initializers"], - log_level=args["log_level"] if args["log_level"] else self.default_log_level, + initialization_scripts=resolved_scripts, + log_level=args["log_level"], ) - # Use the existing registries (don't reinitialize) - run_context._scenario_registry = self.context._scenario_registry - run_context._initializer_registry = self.context._initializer_registry - run_context._initialized = True try: result = asyncio.run( @@ -338,6 +374,9 @@ def do_scenario_history(self, arg: str) -> None: Shows a numbered list of all scenario runs with the commands used. """ + if arg.strip(): + print(f"Error: scenario-history does not accept arguments, got: {arg.strip()}") + return if not self._scenario_history: print("No scenario runs in history.") return @@ -467,8 +506,9 @@ def do_help(self, arg: str) -> None: print(" pyrit_shell") print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") else: - # Show help for specific command - super().do_help(arg) + # Convert hyphens to underscores (e.g. help list-targets -> help list_targets) for command lookup + normalized_arg = arg.replace("-", "_") + super().do_help(normalized_arg) def do_exit(self, arg: str) -> bool: """ diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 61b3c7bb50..2507f4eb4c 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -12,6 +12,7 @@ import pytest from pyrit.cli import frontend_core +from pyrit.cli._cli_args import _ArgSpec, _parse_shell_arguments from pyrit.registry import InitializerMetadata, ScenarioMetadata @@ -554,6 +555,93 @@ def test_colon_but_no_params_returns_string(self) -> None: assert result == "target" +class TestParseShellArguments: + """Tests for the generic _parse_shell_arguments function.""" + + def test_empty_parts_returns_none_defaults(self): + """Test that empty input returns None for all result keys.""" + spec = _ArgSpec(flags=["--foo"], result_key="foo") + result = _parse_shell_arguments(parts=[], arg_specs=[spec]) + assert result == {"foo": None} + + def test_single_value_arg(self): + """Test parsing a single-value argument.""" + spec = _ArgSpec(flags=["--name"], result_key="name") + result = _parse_shell_arguments(parts=["--name", "alice"], arg_specs=[spec]) + assert result["name"] == "alice" + + def test_single_value_with_parser(self): + """Test that single-value parser is applied.""" + spec = _ArgSpec(flags=["--count"], result_key="count", parser=int) + result = _parse_shell_arguments(parts=["--count", "42"], arg_specs=[spec]) + assert result["count"] == 42 + + def test_single_value_missing_raises(self): + """Test that missing value for single-value arg raises ValueError.""" + spec = _ArgSpec(flags=["--name"], result_key="name") + with pytest.raises(ValueError, match="--name requires a value"): + _parse_shell_arguments(parts=["--name"], arg_specs=[spec]) + + def test_multi_value_arg(self): + """Test collecting multiple values until next flag.""" + spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + result = _parse_shell_arguments(parts=["--items", "a", "b", "c"], arg_specs=[spec]) + assert result["items"] == ["a", "b", "c"] + + def test_multi_value_stops_at_next_flag(self): + """Test that multi-value collection stops at the next known flag.""" + items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + name_spec = _ArgSpec(flags=["--name"], result_key="name") + result = _parse_shell_arguments( + parts=["--items", "a", "b", "--name", "alice"], + arg_specs=[items_spec, name_spec], + ) + assert result["items"] == ["a", "b"] + assert result["name"] == "alice" + + def test_multi_value_stops_at_short_flag_alias(self): + """Test that multi-value collection stops at a short flag alias like -s.""" + long_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + short_spec = _ArgSpec(flags=["-s", "--short"], result_key="short", multi_value=True) + result = _parse_shell_arguments( + parts=["--items", "a", "b", "-s", "x"], + arg_specs=[long_spec, short_spec], + ) + assert result["items"] == ["a", "b"] + assert result["short"] == ["x"] + + def test_multi_value_with_parser(self): + """Test that parser transforms each collected value.""" + spec = _ArgSpec(flags=["--nums"], result_key="nums", multi_value=True, parser=int) + result = _parse_shell_arguments(parts=["--nums", "1", "2", "3"], arg_specs=[spec]) + assert result["nums"] == [1, 2, 3] + + def test_multi_value_no_values_raises(self): + """Test that multi-value arg with no values raises ValueError.""" + items_spec = _ArgSpec(flags=["--items"], result_key="items", multi_value=True) + name_spec = _ArgSpec(flags=["--name"], result_key="name") + with pytest.raises(ValueError, match="--items requires at least one value"): + _parse_shell_arguments( + parts=["--items", "--name", "alice"], + arg_specs=[items_spec, name_spec], + ) + + def test_unknown_flag_raises(self): + """Test that an unknown flag raises ValueError.""" + spec = _ArgSpec(flags=["--known"], result_key="known") + with pytest.raises(ValueError, match="Unknown argument: --unknown"): + _parse_shell_arguments(parts=["--unknown"], arg_specs=[spec]) + + def test_multiple_specs_all_none_when_unused(self): + """Test that unused specs default to None.""" + specs = [ + _ArgSpec(flags=["--a"], result_key="a"), + _ArgSpec(flags=["--b"], result_key="b", multi_value=True), + ] + result = _parse_shell_arguments(parts=[], arg_specs=specs) + assert result == {"a": None, "b": None} + + class TestParseRunArguments: """Tests for parse_run_arguments function.""" @@ -672,6 +760,46 @@ def test_parse_run_arguments_missing_value(self): frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") +class TestParseListTargetsArguments: + """Tests for parse_list_targets_arguments function.""" + + def test_parse_list_targets_arguments_empty(self): + """Test parsing empty string returns defaults.""" + result = frontend_core.parse_list_targets_arguments(args_string="") + assert result["initializers"] is None + assert result["initialization_scripts"] is None + + def test_parse_list_targets_arguments_with_initializers(self): + """Test parsing with initializers.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target init2") + assert result["initializers"] == ["target", "init2"] + + def test_parse_list_targets_arguments_with_initializer_params(self): + """Test parsing initializers with key=value params.""" + result = frontend_core.parse_list_targets_arguments(args_string="--initializers target:tags=default,scorer") + assert result["initializers"] == [{"name": "target", "args": {"tags": ["default", "scorer"]}}] + + def test_parse_list_targets_arguments_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initialization-scripts script1.py script2.py" + ) + assert result["initialization_scripts"] == ["script1.py", "script2.py"] + + def test_parse_list_targets_arguments_with_both(self): + """Test parsing with both initializers and scripts.""" + result = frontend_core.parse_list_targets_arguments( + args_string="--initializers target --initialization-scripts script1.py" + ) + assert result["initializers"] == ["target"] + assert result["initialization_scripts"] == ["script1.py"] + + def test_parse_list_targets_arguments_unknown_arg_raises(self): + """Test parsing with unknown argument raises ValueError.""" + with pytest.raises(ValueError, match="Unknown argument"): + frontend_core.parse_list_targets_arguments(args_string="--unknown-flag") + + @pytest.mark.asyncio @pytest.mark.usefixtures("patch_central_database") class TestRunScenarioAsync: @@ -933,9 +1061,125 @@ def test_parse_run_arguments_target_with_other_args(self): args_string="test_scenario --target my_target --initializers init1 --max-concurrency 5" ) - assert result["target"] == "my_target" - assert result["initializers"] == ["init1"] - assert result["max_concurrency"] == 5 + +class TestWithOverrides: + """Tests for FrontendCore.with_overrides method.""" + + def _make_initialized_parent(self) -> frontend_core.FrontendCore: + """Create a fully-initialized FrontendCore for testing with_overrides.""" + parent = frontend_core.FrontendCore( + database=frontend_core.IN_MEMORY, + initializer_names=["parent_init"], + log_level=logging.WARNING, + ) + parent._scenario_registry = MagicMock() + parent._initializer_registry = MagicMock() + parent._initialized = True + parent._silent_reinit = True + return parent + + def test_with_overrides_inherits_fields(self): + """Test that derived context inherits database, env_files, operator, operation.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._database == parent._database + assert derived._env_files == parent._env_files + assert derived._operator == parent._operator + assert derived._operation == parent._operation + + def test_with_overrides_shares_registries(self): + """Test that derived context shares scenario and initializer registries.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._scenario_registry is parent._scenario_registry + assert derived._initializer_registry is parent._initializer_registry + + def test_with_overrides_sets_initialized_and_silent(self): + """Test that derived context is marked initialized with silent reinit.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides() + + assert derived._initialized is True + assert derived._silent_reinit is True + + def test_with_overrides_none_keeps_parent_values(self): + """Test that passing None for all overrides keeps parent's values.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides( + initializer_names=None, + initialization_scripts=None, + log_level=None, + ) + + assert derived._initializer_configs == parent._initializer_configs + assert derived._initialization_scripts == parent._initialization_scripts + assert derived._log_level == parent._log_level + + def test_with_overrides_initializer_names(self): + """Test that initializer_names override normalizes to InitializerConfig objects.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=["target", "dataset"]) + + assert derived._initializer_configs is not None + names = [ic.name for ic in derived._initializer_configs] + assert names == ["target", "dataset"] + # Parent should still have original + assert [ic.name for ic in parent._initializer_configs] == ["parent_init"] + + def test_with_overrides_initializer_names_dict(self): + """Test initializer_names with dict entries (name + args).""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(initializer_names=[{"name": "target", "args": {"tags": "default"}}]) + + assert derived._initializer_configs is not None + assert len(derived._initializer_configs) == 1 + assert derived._initializer_configs[0].name == "target" + assert derived._initializer_configs[0].args == {"tags": "default"} + + def test_with_overrides_initialization_scripts(self): + """Test that initialization_scripts override replaces parent's scripts.""" + parent = self._make_initialized_parent() + new_scripts = [Path("/new/script.py")] + + derived = parent.with_overrides(initialization_scripts=new_scripts) + + assert derived._initialization_scripts == new_scripts + # Parent should be unchanged + assert parent._initialization_scripts != new_scripts + + def test_with_overrides_log_level(self): + """Test that log_level override replaces parent's log level.""" + parent = self._make_initialized_parent() + + derived = parent.with_overrides(log_level=logging.DEBUG) + + assert derived._log_level == logging.DEBUG + assert parent._log_level == logging.WARNING + + def test_with_overrides_does_not_mutate_parent(self): + """Test that with_overrides does not modify the parent context.""" + parent = self._make_initialized_parent() + original_configs = parent._initializer_configs + original_log_level = parent._log_level + original_scripts = parent._initialization_scripts + + parent.with_overrides( + initializer_names=["new_init"], + initialization_scripts=[Path("/new.py")], + log_level=logging.DEBUG, + ) + + assert parent._initializer_configs is original_configs + assert parent._log_level == original_log_level + assert parent._initialization_scripts is original_scripts def test_parse_run_arguments_target_missing_value(self): """Test parsing --target without a value raises ValueError.""" @@ -1141,3 +1385,31 @@ async def test_print_targets_list_empty( captured = capsys.readouterr() assert "No targets found" in captured.out assert "--initializers target" in captured.out + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_list_targets_with_initialization_scripts_calls_initialize( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test list_targets_async calls initialize_pyrit_async when only scripts are configured.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = ["script_target"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + context._initialization_scripts = ["/path/to/script.py"] + context._initializer_configs = None + + result = await frontend_core.list_targets_async(context=context) + + assert result == ["script_target"] + # Verify initialize_pyrit_async was called with the scripts + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + assert call_kwargs["initialization_scripts"] == ["/path/to/script.py"] + assert call_kwargs["initializers"] is None diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 34a8b8ad52..a4c3620ca7 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -214,6 +214,55 @@ def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: Mag assert result == 1 + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_initializers( + self, + mock_frontend_core: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initializers passes initializers to FrontendCore.""" + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initializers", "target"]) + + assert result == 0 + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initializer_names"] == ["target"] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_targets_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_print_targets: AsyncMock, + ): + """Test main with --list-targets and --initialization-scripts passes scripts to FrontendCore.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_print_targets.return_value = 0 + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initialization_scripts"] == [Path("/test/script.py")] + mock_print_targets.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_list_targets_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with --list-targets and missing script file.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["--list-targets", "--initialization-scripts", "missing.py"]) + + assert result == 1 + def test_main_no_scenario_specified(self, capsys): """Test main without scenario name.""" result = pyrit_scan.main([]) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 89a218644d..4f562f9917 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -151,6 +151,15 @@ def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, captured = capsys.readouterr() assert "Error listing scenarios" in captured.out + def test_do_list_scenarios_rejects_args(self, shell, capsys): + """Test do_list_scenarios rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_scenarios("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) def test_do_list_initializers(self, mock_print_initializers: AsyncMock, shell): """Test do_list_initializers command.""" @@ -171,6 +180,67 @@ def test_do_list_initializers_with_exception(self, mock_print_initializers: Asyn captured = capsys.readouterr() assert "Error listing initializers" in captured.out + def test_do_list_initializers_rejects_args(self, shell, capsys): + """Test do_list_initializers rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_list_initializers("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_no_args(self, mock_print_targets: AsyncMock, shell): + """Test do_list_targets with no arguments uses the default context.""" + s, ctx, _ = shell + + s.do_list_targets("") + + mock_print_targets.assert_called_once_with(context=ctx) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.parse_list_targets_arguments") + def test_do_list_targets_with_initializers( + self, + mock_parse: MagicMock, + mock_print_targets: AsyncMock, + shell, + ): + """Test do_list_targets with --initializers uses context.with_overrides.""" + s, ctx, _ = shell + mock_parse.return_value = {"initializers": ["target"], "initialization_scripts": None} + mock_derived = MagicMock() + ctx.with_overrides = MagicMock(return_value=mock_derived) + + s.do_list_targets("--initializers target") + + mock_parse.assert_called_once_with(args_string="--initializers target") + ctx.with_overrides.assert_called_once_with( + initialization_scripts=None, + initializer_names=["target"], + ) + mock_print_targets.assert_called_once_with(context=mock_derived) + + @patch("pyrit.cli.frontend_core.print_targets_list_async", new_callable=AsyncMock) + def test_do_list_targets_with_exception(self, mock_print_targets: AsyncMock, shell, capsys): + """Test do_list_targets handles exceptions.""" + s, ctx, _ = shell + mock_print_targets.side_effect = RuntimeError("Test error") + + s.do_list_targets("") + + captured = capsys.readouterr() + assert "Error listing targets" in captured.out + + def test_do_list_targets_parse_error(self, shell, capsys): + """Test do_list_targets shows error for invalid args.""" + s, ctx, _ = shell + + s.do_list_targets("--unknown-flag") + + captured = capsys.readouterr() + assert "Error" in captured.out + def test_do_run_empty_line(self, shell, capsys): """Test do_run with empty line.""" s, ctx, _ = shell @@ -380,6 +450,15 @@ def test_do_scenario_history_empty(self, shell, capsys): captured = capsys.readouterr() assert "No scenario runs in history" in captured.out + def test_do_scenario_history_rejects_args(self, shell, capsys): + """Test do_scenario_history rejects unexpected arguments.""" + s, ctx, _ = shell + + s.do_scenario_history("--unknown foo") + + captured = capsys.readouterr() + assert "does not accept arguments" in captured.out + def test_do_scenario_history_with_runs(self, shell, capsys): """Test do_scenario_history with scenario runs.""" s, ctx, _ = shell @@ -502,6 +581,14 @@ def test_do_help_with_arg(self, shell): s.do_help("run") mock_parent_help.assert_called_with("run") + def test_do_help_with_hyphenated_arg(self, shell): + """Test do_help converts hyphens to underscores for command lookup.""" + s, ctx, _ = shell + + with patch("cmd.Cmd.do_help") as mock_parent_help: + s.do_help("list-targets") + mock_parent_help.assert_called_with("list_targets") + @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell):