Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/cli/_banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo
"Commands:",
" • list-scenarios - See all available scenarios",
" • list-initializers - See all available initializers",
" • list-targets - See all available targets in the registry",
" • list-targets [opts] - See all available targets in the registry",
" • run <scenario> [opts] - Execute a security scenario",
" • scenario-history - View your session history",
" • print-scenario [N] - Display detailed results",
Expand Down
271 changes: 188 additions & 83 deletions pyrit/cli/_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import argparse
import dataclasses
import inspect
import json
import logging
Expand Down Expand Up @@ -342,6 +343,172 @@ def _parse_initializer_arg(arg: str) -> str | dict[str, Any]:
return name


# ---------------------------------------------------------------------------
# Shell argument specification
# ---------------------------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class _ArgSpec:
"""
Declarative specification for a single shell-mode CLI argument.

Each instance describes one CLI flag (or set of aliases) and how its
value(s) should be collected and validated. A list of ``_ArgSpec`` objects
is passed to ``_parse_shell_arguments`` which handles the actual parsing
loop. Adding a new flag only requires defining a new ``_ArgSpec``
constant, not editing any parsing logic.

Attributes:
flags: CLI flag strings that trigger this argument (e.g., ``["--strategies", "-s"]``).
result_key: Key name in the returned dict (e.g., ``"scenario_strategies"``).
multi_value: If True, collect values until the next flag.
If False, consume exactly one value.
parser: Optional callable to transform each raw string value.
Applied per-item for multi-value args, or to the single value otherwise.
"""

flags: list[str]
result_key: str
multi_value: bool = False
parser: Callable[[str], Any] | None = None


_INITIALIZERS_ARG = _ArgSpec(
flags=["--initializers"],
result_key="initializers",
multi_value=True,
parser=_parse_initializer_arg,
)
_INIT_SCRIPTS_ARG = _ArgSpec(
flags=["--initialization-scripts"],
result_key="initialization_scripts",
multi_value=True,
)

_STRATEGIES_ARG = _ArgSpec(
flags=["--strategies", "-s"],
result_key="scenario_strategies",
multi_value=True,
)
_MAX_CONCURRENCY_ARG = _ArgSpec(
flags=["--max-concurrency"],
result_key="max_concurrency",
parser=lambda v: validate_integer(v, name="--max-concurrency", min_value=1),
)
_MAX_RETRIES_ARG = _ArgSpec(
flags=["--max-retries"],
result_key="max_retries",
parser=lambda v: validate_integer(v, name="--max-retries", min_value=0),
)
_MEMORY_LABELS_ARG = _ArgSpec(
flags=["--memory-labels"],
result_key="memory_labels",
parser=parse_memory_labels,
)
_LOG_LEVEL_ARG = _ArgSpec(
flags=["--log-level"],
result_key="log_level",
parser=lambda v: validate_log_level(log_level=v),
)
_DATASET_NAMES_ARG = _ArgSpec(
flags=["--dataset-names"],
result_key="dataset_names",
multi_value=True,
)
_MAX_DATASET_SIZE_ARG = _ArgSpec(
flags=["--max-dataset-size"],
result_key="max_dataset_size",
parser=lambda v: validate_integer(v, name="--max-dataset-size", min_value=1),
)
_TARGET_ARG = _ArgSpec(
flags=["--target"],
result_key="target",
)

_RUN_ARG_SPECS: list[_ArgSpec] = [
_INITIALIZERS_ARG,
_INIT_SCRIPTS_ARG,
_STRATEGIES_ARG,
_MAX_CONCURRENCY_ARG,
_MAX_RETRIES_ARG,
_MEMORY_LABELS_ARG,
_LOG_LEVEL_ARG,
_DATASET_NAMES_ARG,
_MAX_DATASET_SIZE_ARG,
_TARGET_ARG,
]

_LIST_TARGETS_ARG_SPECS: list[_ArgSpec] = [
_INITIALIZERS_ARG,
_INIT_SCRIPTS_ARG,
]


# ---------------------------------------------------------------------------
# Generic shell argument parser
# ---------------------------------------------------------------------------


def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> dict[str, Any]:
"""
Parse a list of shell tokens against a set of argument specifications.

Each ``_ArgSpec`` in *arg_specs* declares how its flag(s) should be handled
(multi-value collection vs. single-value consumption) and what validation
or transformation to apply.

Args:
parts: Token list (already split on whitespace, positional args removed).
arg_specs: Argument specifications that this command accepts.

Returns:
Dictionary mapping each spec's ``result_key`` to its parsed value,
defaulting to ``None`` for arguments not present in *parts*.

Raises:
ValueError: On unknown flags or missing values.
"""
# Build lookup: flag string → spec
flag_to_spec: dict[str, _ArgSpec] = {}
for spec in arg_specs:
for flag in spec.flags:
flag_to_spec[flag] = spec

# Initialise result with None defaults
result: dict[str, Any] = {spec.result_key: None for spec in arg_specs}

i = 0
while i < len(parts):
token = parts[i]
spec = flag_to_spec.get(token)

if spec is None:
valid = sorted(flag_to_spec.keys())
raise ValueError(f"Unknown argument: {token}. Valid arguments: {', '.join(valid)}")

i += 1

if spec.multi_value:
values: list[Any] = []
# Collect values until the next flag (whether valid or invalid)
while i < len(parts) and not (parts[i].startswith("--") or parts[i] in flag_to_spec):
item = spec.parser(parts[i]) if spec.parser else parts[i]
values.append(item)
i += 1
if len(values) == 0:
raise ValueError(f"{spec.flags[0]} requires at least one value")
result[spec.result_key] = values
else:
if i >= len(parts):
raise ValueError(f"{spec.flags[0]} requires a value")
raw = parts[i]
result[spec.result_key] = spec.parser(raw) if spec.parser else raw
i += 1

return result


def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
"""
Parse run command arguments from a string (for shell mode).
Expand Down Expand Up @@ -371,92 +538,30 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
if not parts:
raise ValueError("No scenario name provided")

result: dict[str, Any] = {
"scenario_name": parts[0],
"initializers": None,
"initialization_scripts": None,
"scenario_strategies": None,
"max_concurrency": None,
"max_retries": None,
"memory_labels": None,
"log_level": None,
"dataset_names": None,
"max_dataset_size": None,
"target": None,
}

i = 1
while i < len(parts):
if parts[i] == "--initializers":
# Collect initializers until next flag, parsing name:key=val syntax
result["initializers"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["initializers"].append(_parse_initializer_arg(parts[i]))
i += 1
elif parts[i] == "--initialization-scripts":
# Collect script paths until next flag
result["initialization_scripts"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["initialization_scripts"].append(parts[i])
i += 1
elif parts[i] in ("--strategies", "-s"):
# Collect strategies until next flag
result["scenario_strategies"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s":
result["scenario_strategies"].append(parts[i])
i += 1
elif parts[i] == "--max-concurrency":
i += 1
if i >= len(parts):
raise ValueError("--max-concurrency requires a value")
result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1)
i += 1
elif parts[i] == "--max-retries":
i += 1
if i >= len(parts):
raise ValueError("--max-retries requires a value")
result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0)
i += 1
elif parts[i] == "--memory-labels":
i += 1
if i >= len(parts):
raise ValueError("--memory-labels requires a value")
result["memory_labels"] = parse_memory_labels(parts[i])
i += 1
elif parts[i] == "--log-level":
i += 1
if i >= len(parts):
raise ValueError("--log-level requires a value")
result["log_level"] = validate_log_level(log_level=parts[i])
i += 1
elif parts[i] == "--dataset-names":
# Collect dataset names until next flag
result["dataset_names"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["dataset_names"].append(parts[i])
i += 1
elif parts[i] == "--max-dataset-size":
i += 1
if i >= len(parts):
raise ValueError("--max-dataset-size requires a value")
result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1)
i += 1
elif parts[i] == "--target":
i += 1
if i >= len(parts):
raise ValueError("--target requires a value")
result["target"] = parts[i]
i += 1
else:
raise ValueError(f"Unknown argument: {parts[i]}")

result = _parse_shell_arguments(parts=parts[1:], arg_specs=_RUN_ARG_SPECS)
result["scenario_name"] = parts[0]
return result


def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]:
"""
Parse list-targets command arguments from a string (for shell mode).

Args:
args_string: Space-separated argument string (e.g., "--initializers target").

Returns:
Dictionary with parsed arguments:
- initializers: Optional[list[str | dict[str, Any]]]
- initialization_scripts: Optional[list[str]]

Raises:
ValueError: If parsing or validation fails.
"""
parts = args_string.split()
return _parse_shell_arguments(parts=parts, arg_specs=_LIST_TARGETS_ARG_SPECS)


# ---------------------------------------------------------------------------
# Shared argparse builder
# ---------------------------------------------------------------------------
Expand Down
Loading
Loading