diff --git a/README.md b/README.md
index 73d0731..025a3d7 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# AzureFox
-
+
Find attack paths, pivot opportunities, and movement across Azure before you drown in inventory.
diff --git a/schemas/chains.schema.json b/schemas/chains.schema.json
index 62d6d10..266d003 100644
--- a/schemas/chains.schema.json
+++ b/schemas/chains.schema.json
@@ -11,6 +11,8 @@
"current_gap",
"artifact_preference_order",
"backing_commands",
+ "reused_sources",
+ "live_sources",
"source_artifacts",
"paths",
"issues"
diff --git a/src/azurefox/chains/registry.py b/src/azurefox/chains/registry.py
index d3582b3..6b10d23 100644
--- a/src/azurefox/chains/registry.py
+++ b/src/azurefox/chains/registry.py
@@ -2,9 +2,27 @@
from dataclasses import dataclass
+from azurefox.config import GlobalOptions
+from azurefox.models.commands import (
+ AksOutput,
+ AppServicesOutput,
+ ArmDeploymentsOutput,
+ AutomationOutput,
+ DatabasesOutput,
+ DevopsOutput,
+ EnvVarsOutput,
+ FunctionsOutput,
+ KeyVaultOutput,
+ PermissionsOutput,
+ RbacOutput,
+ RoleTrustsOutput,
+ StorageOutput,
+ TokensCredentialsOutput,
+)
+
GROUPED_COMMAND_NAME = "chains"
-GROUPED_COMMAND_INPUT_MODES = ("live", "artifacts")
-PREFERRED_ARTIFACT_ORDER = ("loot", "json")
+GROUPED_COMMAND_INPUT_MODES = ("live", "artifacts", "mixed")
+PREFERRED_ARTIFACT_ORDER = ("json",)
SEMANTIC_LOOT_CHAIN_FAMILIES = (
"credential-path",
"deployment-path",
@@ -32,6 +50,24 @@ class ChainFamilySpec:
source_commands: tuple[ChainSourceSpec, ...]
+# Output model class for each grouped-chain source command, keyed by command name.
+# chain_source_model() resolves entries from this table.
+_CHAIN_SOURCE_MODELS = {
+ "devops": DevopsOutput,
+ "automation": AutomationOutput,
+ "permissions": PermissionsOutput,
+ "rbac": RbacOutput,
+ "role-trusts": RoleTrustsOutput,
+ "keyvault": KeyVaultOutput,
+ "arm-deployments": ArmDeploymentsOutput,
+ "app-services": AppServicesOutput,
+ "functions": FunctionsOutput,
+ "aks": AksOutput,
+ "env-vars": EnvVarsOutput,
+ "tokens-credentials": TokensCredentialsOutput,
+ "databases": DatabasesOutput,
+ "storage": StorageOutput,
+}
+
+
CHAIN_FAMILIES: tuple[ChainFamilySpec, ...] = (
ChainFamilySpec(
name="credential-path",
@@ -498,3 +534,37 @@ def get_chain_family_spec(name: str) -> ChainFamilySpec | None:
if spec.name == name:
return spec
return None
+
+
+def chain_source_model(command: str):
+    """Return the output model class registered for the chain source ``command``.
+
+    Raises:
+        ValueError: if ``command`` has no entry in ``_CHAIN_SOURCE_MODELS``.
+    """
+    try:
+        model_class = _CHAIN_SOURCE_MODELS[command]
+    except KeyError as lookup_error:
+        message = f"Missing empty grouped-source model for '{command}'"
+        raise ValueError(message) from lookup_error
+    return model_class
+
+
+def empty_chain_source_fields(
+    command: str,
+    issue: dict[str, object],
+    options: GlobalOptions,
+) -> dict[str, object]:
+    """Return the placeholder payload fields for a chain source that yielded no data.
+
+    Each shape consists of the command's collection fields set to empty lists followed
+    by an ``issues`` list holding ``issue``. ``role-trusts`` additionally carries
+    ``options.role_trusts_mode`` under ``mode`` as its first field.
+
+    Raises:
+        ValueError: if ``command`` has no empty shape defined.
+    """
+    role_trusts_mode = options.role_trusts_mode
+    collection_fields: dict[str, tuple[str, ...]] = {
+        "devops": ("pipelines",),
+        "automation": ("automation_accounts",),
+        "permissions": ("permissions",),
+        "rbac": ("principals", "scopes", "role_assignments"),
+        "role-trusts": ("trusts",),
+        "keyvault": ("key_vaults",),
+        "arm-deployments": ("deployments",),
+        "app-services": ("app_services",),
+        "functions": ("function_apps",),
+        "aks": ("aks_clusters",),
+        "env-vars": ("env_vars",),
+        "tokens-credentials": ("surfaces",),
+        "databases": ("database_servers",),
+        "storage": ("storage_assets",),
+    }
+    try:
+        field_names = collection_fields[command]
+    except KeyError as lookup_error:
+        message = f"Missing empty grouped-source shape for '{command}'"
+        raise ValueError(message) from lookup_error
+
+    fields: dict[str, object] = {}
+    if command == "role-trusts":
+        fields["mode"] = role_trusts_mode
+    for field_name in field_names:
+        fields[field_name] = []
+    fields["issues"] = [issue]
+    return fields
diff --git a/src/azurefox/chains/runner.py b/src/azurefox/chains/runner.py
index 6a2db86..33c4bfb 100644
--- a/src/azurefox/chains/runner.py
+++ b/src/azurefox/chains/runner.py
@@ -1,8 +1,10 @@
from __future__ import annotations
+import json
import re
from collections import defaultdict
-from dataclasses import replace
+from dataclasses import dataclass, replace
+from pathlib import Path
from azurefox.chains.compute_control import collect_compute_control_records
from azurefox.chains.credential_path import collect_credential_path_records
@@ -14,7 +16,10 @@
)
from azurefox.chains.registry import (
GROUPED_COMMAND_NAME,
+ PREFERRED_ARTIFACT_ORDER,
ChainFamilySpec,
+ chain_source_model,
+ empty_chain_source_fields,
get_chain_family_spec,
implemented_chain_family_names,
is_implemented_chain_family,
@@ -34,26 +39,16 @@
)
from azurefox.models.chains import (
ChainPathRecord,
+ ChainSourceArtifact,
ChainsOutput,
)
-from azurefox.models.commands import (
- AksOutput,
- AppServicesOutput,
- ArmDeploymentsOutput,
- AutomationOutput,
- ChainsCommandOutput,
- DatabasesOutput,
- DevopsOutput,
- EnvVarsOutput,
- FunctionsOutput,
- KeyVaultOutput,
- PermissionsOutput,
- RbacOutput,
- RoleTrustsOutput,
- StorageOutput,
- TokensCredentialsOutput,
+from azurefox.models.commands import ChainsCommandOutput
+from azurefox.models.common import (
+ SCHEMA_VERSION,
+ ArmDeploymentSummary,
+ CollectionIssue,
+ CommandMetadata,
)
-from azurefox.models.common import ArmDeploymentSummary, CollectionIssue, CommandMetadata
from azurefox.registry import get_command_specs
from azurefox.scope_hints import permission_scope_description, permission_scope_phrase
from azurefox.target_matching import (
@@ -140,6 +135,41 @@
}
+@dataclass(slots=True)
+class ChainSourceLoadState:
+ """How one chain source command's output was obtained for a grouped chain run."""
+
+ # Name of the backing source command.
+ command: str
+ # The command's output object: a reused artifact, a live result, or an empty placeholder.
+ output: object
+ # "artifacts" when the output came from a saved artifact, "live" when collected in this run.
+ mode: str
+ # Issue explaining why an existing artifact could not be reused, if that happened.
+ reuse_issue: CollectionIssue | None = None
+ # Descriptor of the artifact that supplied the output; only set for reused sources.
+ source_artifact: ChainSourceArtifact | None = None
+
+
+@dataclass(slots=True)
+class ChainFamilyLoadResult:
+    """Outputs loaded for one chain family together with per-source load bookkeeping."""
+
+    outputs: dict[str, object]
+    source_states: list[ChainSourceLoadState]
+
+    def _commands_loaded_as(self, mode: str) -> list[str]:
+        """Return the source commands whose load state has the given ``mode``."""
+        return [state.command for state in self.source_states if state.mode == mode]
+
+    @property
+    def reused_sources(self) -> list[str]:
+        """Commands whose output was taken from a saved artifact."""
+        return self._commands_loaded_as("artifacts")
+
+    @property
+    def live_sources(self) -> list[str]:
+        """Commands whose output was collected live."""
+        return self._commands_loaded_as("live")
+
+    @property
+    def source_artifacts(self) -> list[ChainSourceArtifact]:
+        """Artifact descriptors of every source that recorded one, in load order."""
+        recorded = (state.source_artifact for state in self.source_states)
+        return [artifact for artifact in recorded if artifact is not None]
+
+    @property
+    def reuse_issues(self) -> list[CollectionIssue]:
+        """Artifact-reuse issues of every source that recorded one, in load order."""
+        recorded = (state.reuse_issue for state in self.source_states)
+        return [issue for issue in recorded if issue is not None]
+
+
def implemented_chain_families() -> tuple[str, ...]:
return implemented_chain_family_names()
@@ -155,15 +185,15 @@ def run_chain_family(
if not is_implemented_chain_family(family_name):
raise ValueError(f"Chain family '{family_name}' is not implemented yet")
- loaded = _collect_family_outputs(provider, options, family_name)
+ load_result = _collect_family_outputs(provider, options, family_name)
if family_name == "credential-path":
- return _build_credential_path_output(provider, options, family_name, loaded)
+ return _build_credential_path_output(provider, options, family_name, load_result)
if family_name == "deployment-path":
- return _build_deployment_path_output(options, family_name, loaded)
+ return _build_deployment_path_output(options, family_name, load_result)
if family_name == "escalation-path":
- return _build_escalation_path_output(options, family_name, loaded)
+ return _build_escalation_path_output(options, family_name, load_result)
if family_name == "compute-control":
- return _build_compute_control_output(options, family_name, loaded)
+ return _build_compute_control_output(options, family_name, load_result)
raise ValueError(f"Unsupported chain family '{family_name}'")
@@ -172,27 +202,293 @@ def _collect_family_outputs(
provider: BaseProvider,
options: GlobalOptions,
family_name: str,
-) -> dict[str, object]:
+) -> ChainFamilyLoadResult:
family = get_chain_family_spec(family_name)
if family is None:
raise ValueError(f"Unknown chain family '{family_name}'")
collector_by_name = {spec.name: spec.collector for spec in get_command_specs()}
- loaded: dict[str, object] = {}
+ outputs: dict[str, object] = {}
+ source_states: list[ChainSourceLoadState] = []
for source in family.source_commands:
+ reused_output, source_artifact, reuse_issue = _load_reusable_chain_source_output(
+ provider,
+ options,
+ source.command,
+ )
+ if reused_output is not None and source_artifact is not None:
+ outputs[source.command] = reused_output
+ source_states.append(
+ ChainSourceLoadState(
+ command=source.command,
+ output=reused_output,
+ mode="artifacts",
+ reuse_issue=reuse_issue,
+ source_artifact=source_artifact,
+ )
+ )
+ continue
+
collector = collector_by_name[source.command]
try:
- loaded[source.command] = collector(provider, options)
+ outputs[source.command] = collector(provider, options)
except Exception as exc:
- loaded[source.command] = _empty_chain_source_output(
+ outputs[source.command] = _empty_chain_source_output(
command=source.command,
provider=provider,
options=options,
exc=exc,
)
+ source_states.append(
+ ChainSourceLoadState(
+ command=source.command,
+ output=outputs[source.command],
+ mode="live",
+ reuse_issue=reuse_issue,
+ )
+ )
+
+ return ChainFamilyLoadResult(
+ outputs=outputs,
+ source_states=source_states,
+ )
+
+
+def _load_reusable_chain_source_output(
+    provider: BaseProvider,
+    options: GlobalOptions,
+    command: str,
+) -> tuple[object | None, ChainSourceArtifact | None, CollectionIssue | None]:
+    """Try to satisfy one chain source from a saved local artifact instead of collecting live.
+
+    Returns an ``(output, source_artifact, reuse_issue)`` triple:
+
+    * ``(output, artifact, None)`` when an artifact file was read, matched the current
+      context, and validated against the command's output model;
+    * ``(None, None, issue)`` when an artifact file exists but could not be reused, with
+      ``issue`` recording why (the caller then collects the source live);
+    * ``(None, None, None)`` when ``options.live_only`` is set or no artifact file exists.
+
+    The first artifact file that exists decides the outcome: a failed reuse attempt
+    returns immediately and does not fall through to later artifact types.
+    """
+    if options.live_only:
+        return None, None, None
-    return loaded
+    for artifact_type in PREFERRED_ARTIFACT_ORDER:
+        artifact_path = _chain_source_artifact_path(options, command, artifact_type)
+        if artifact_path is None or not artifact_path.exists():
+            continue
+
+        try:
+            payload = json.loads(artifact_path.read_text(encoding="utf-8"))
+        except (OSError, UnicodeDecodeError) as exc:
+            # read_text() raises UnicodeDecodeError (a ValueError, not an OSError) when the
+            # file is not valid UTF-8; report it as a skipped artifact instead of letting it
+            # abort the whole chain run.
+            return None, None, _artifact_reuse_issue(
+                command=command,
+                artifact_type=artifact_type,
+                artifact_path=artifact_path,
+                reason="read_error",
+                detail=str(exc),
+            )
+        except json.JSONDecodeError as exc:
+            return None, None, _artifact_reuse_issue(
+                command=command,
+                artifact_type=artifact_type,
+                artifact_path=artifact_path,
+                reason="invalid_json",
+                detail=str(exc),
+            )
+
+        # Reject artifacts written for a different command, schema version or context
+        # before attempting model validation.
+        mismatch_reason = _chain_source_artifact_mismatch_reason(
+            payload,
+            provider,
+            options,
+            command,
+        )
+        if mismatch_reason is not None:
+            return None, None, _artifact_reuse_issue(
+                command=command,
+                artifact_type=artifact_type,
+                artifact_path=artifact_path,
+                reason="metadata_mismatch",
+                detail=mismatch_reason,
+            )
+
+        model_class = chain_source_model(command)
+        try:
+            output = model_class.model_validate(payload)
+        except Exception as exc:
+            # Deliberately broad: any validation failure means "do not reuse", never a crash.
+            return None, None, _artifact_reuse_issue(
+                command=command,
+                artifact_type=artifact_type,
+                artifact_path=artifact_path,
+                reason="model_validation_failed",
+                detail=str(exc),
+            )
+
+        return (
+            output,
+            ChainSourceArtifact(
+                command=command,
+                artifact_type=artifact_type,
+                path=str(artifact_path),
+            ),
+            None,
+        )
+
+    return None, None, None
+
+
+def _chain_source_artifact_path(
+    options: GlobalOptions,
+    command: str,
+    artifact_type: str,
+) -> Path | None:
+    """Return the expected artifact file for ``command``, or ``None`` if there is none.
+
+    Only the ``json`` artifact type has a directory (``options.json_dir``); any other
+    type, or an unset directory, yields ``None``. The file is named
+    ``<command>.<artifact_type>`` inside that directory.
+    """
+    json_dir = options.json_dir
+    directory = json_dir if artifact_type == "json" else None
+    if directory is None:
+        return None
+    return directory / f"{command}.{artifact_type}"
+
+
+def _chain_source_artifact_mismatch_reason(
+    payload: dict[str, object],
+    provider: BaseProvider,
+    options: GlobalOptions,
+    command: str,
+) -> str | None:
+    """Explain why a saved artifact cannot stand in for a live run, or return ``None``.
+
+    ``payload`` is the decoded JSON of the artifact file and is not trusted to have the
+    expected shape. Checks run in this order and the first failure is returned:
+
+    1. the payload is an object with an object-valued ``metadata`` field;
+    2. ``metadata.command`` and ``metadata.schema_version`` match ``command`` and
+       ``SCHEMA_VERSION``;
+    3. subscription, tenant, token source and auth mode match the current context
+       (values the current context does not define are not compared);
+    4. for ``devops``, the DevOps organization matches; for ``role-trusts``, the payload
+       ``mode`` matches ``options.role_trusts_mode``; for ``permissions``, the rows flagged
+       as the current identity name exactly the current principal id.
+    """
+    # The file only has to be valid JSON to get here, so the root may be a list or scalar.
+    if not isinstance(payload, dict):
+        return "artifact payload is not a JSON object"
+    metadata = payload.get("metadata")
+    if not isinstance(metadata, dict):
+        return "metadata is missing or not an object"
+    if metadata.get("command") != command:
+        return f"command mismatch: expected {command}, got {metadata.get('command')}"
+    if metadata.get("schema_version") != SCHEMA_VERSION:
+        return (
+            "schema_version mismatch: "
+            f"expected {SCHEMA_VERSION}, got {metadata.get('schema_version')}"
+        )
+
+    context = provider.metadata_context()
+    # Explicit CLI options win over whatever the provider reports for the session.
+    expected_metadata: tuple[tuple[str, str | None], ...] = (
+        ("subscription_id", options.subscription or context.get("subscription_id")),
+        ("tenant_id", options.tenant or context.get("tenant_id")),
+        ("token_source", context.get("token_source")),
+        ("auth_mode", context.get("auth_mode")),
+    )
+    if command == "devops":
+        expected_metadata += (("devops_organization", options.devops_organization),)
+
+    for field_name, current_value in expected_metadata:
+        mismatch = _metadata_value_mismatch(
+            field_name=field_name,
+            artifact_value=metadata.get(field_name),
+            current_value=current_value,
+        )
+        if mismatch is not None:
+            return mismatch
+
+    if command == "role-trusts" and payload.get("mode") != options.role_trusts_mode:
+        return (
+            "role-trusts mode mismatch: "
+            f"expected {options.role_trusts_mode}, got {payload.get('mode')}"
+        )
+
+    if command == "permissions":
+        current_principal_id = _current_principal_id(provider)
+        if current_principal_id is None:
+            return "current principal id could not be determined"
+        permission_rows = payload.get("permissions")
+        if not isinstance(permission_rows, list):
+            # e.g. ``"permissions": null`` - treat as "no rows" rather than failing to iterate.
+            permission_rows = []
+        artifact_current_ids = sorted(
+            {
+                str(item.get("principal_id"))
+                for item in permission_rows
+                if isinstance(item, dict)
+                and item.get("is_current_identity")
+                and item.get("principal_id")
+            }
+        )
+        if artifact_current_ids != [current_principal_id]:
+            return (
+                "current principal mismatch: "
+                f"expected [{current_principal_id}], got {artifact_current_ids}"
+            )
+
+    return None
+
+
+def _metadata_value_mismatch(
+    *,
+    field_name: str,
+    artifact_value: object,
+    current_value: str | None,
+) -> str | None:
+    """Describe how an artifact metadata value differs from the current one.
+
+    Returns ``None`` when ``current_value`` is ``None`` (nothing to compare against) or
+    when the two values are equal; otherwise returns a message naming ``field_name``.
+    """
+    nothing_to_compare = current_value is None
+    if nothing_to_compare or artifact_value == current_value:
+        return None
+    return f"{field_name} mismatch: expected {current_value}, got {artifact_value}"
+
+
+def _artifact_reuse_issue(
+    *,
+    command: str,
+    artifact_type: str,
+    artifact_path: Path | str,
+    reason: str,
+    detail: str,
+) -> CollectionIssue:
+    """Build the ``artifact_reuse_skipped`` issue recorded when an artifact is not reused.
+
+    Args:
+        command: Source command the artifact belongs to; also used as the issue scope.
+        artifact_type: Artifact kind that was attempted (for example ``json``).
+        artifact_path: Location of the artifact file; stored as a string in the context.
+        reason: Machine-readable reason code. Known codes get a fixed label in the
+            message; any other code is shown with underscores replaced by spaces.
+        detail: Free-form explanation appended to the message.
+    """
+    reason_labels = {
+        "read_error": "artifact read error",
+        "invalid_json": "invalid JSON artifact",
+        "metadata_mismatch": "artifact metadata mismatch",
+        "model_validation_failed": "artifact model validation failed",
+    }
+    reason_label = reason_labels.get(reason, reason.replace("_", " "))
+    return CollectionIssue(
+        kind="artifact_reuse_skipped",
+        message=(
+            f"Skipped local {artifact_type} artifact reuse for {command}: "
+            f"{reason_label}: {detail}"
+        ),
+        scope=command,
+        context={
+            "collector": command,
+            "artifact_type": artifact_type,
+            "artifact_path": str(artifact_path),
+            "reason": reason,
+        },
+    )
+
+
+def _current_principal_id(provider: BaseProvider) -> str | None:
+    """Return the id of the provider's current principal as a string, if available.
+
+    ``None`` is returned when ``provider.whoami()`` raises, when its ``principal`` entry
+    is not a dict, or when that dict has no truthy ``id``.
+    """
+    try:
+        identity = provider.whoami()
+    except Exception:
+        return None
+
+    principal = identity.get("principal")
+    principal_id = principal.get("id") if isinstance(principal, dict) else None
+    return str(principal_id) if principal_id else None
def _empty_chain_source_output(
@@ -204,7 +500,7 @@ def _empty_chain_source_output(
):
issue = _chain_source_issue(command, exc)
context = provider.metadata_context()
- model_class = _empty_chain_source_model(command)
+ model_class = chain_source_model(command)
payload = {
"metadata": CommandMetadata(
command=command,
@@ -214,61 +510,11 @@ def _empty_chain_source_output(
token_source=context.get("token_source"),
auth_mode=context.get("auth_mode"),
),
- **_empty_chain_source_fields(command, issue, options),
+ **empty_chain_source_fields(command, issue, options),
}
return model_class.model_validate(payload)
-def _empty_chain_source_fields(
- command: str,
- issue: dict[str, object],
- options: GlobalOptions,
-) -> dict[str, object]:
- fields_by_command: dict[str, dict[str, object]] = {
- "devops": {"pipelines": [], "issues": [issue]},
- "automation": {"automation_accounts": [], "issues": [issue]},
- "permissions": {"permissions": [], "issues": [issue]},
- "rbac": {"principals": [], "scopes": [], "role_assignments": [], "issues": [issue]},
- "role-trusts": {"mode": options.role_trusts_mode, "trusts": [], "issues": [issue]},
- "keyvault": {"key_vaults": [], "issues": [issue]},
- "arm-deployments": {"deployments": [], "issues": [issue]},
- "app-services": {"app_services": [], "issues": [issue]},
- "functions": {"function_apps": [], "issues": [issue]},
- "aks": {"aks_clusters": [], "issues": [issue]},
- "env-vars": {"env_vars": [], "issues": [issue]},
- "tokens-credentials": {"surfaces": [], "issues": [issue]},
- "databases": {"database_servers": [], "issues": [issue]},
- "storage": {"storage_assets": [], "issues": [issue]},
- }
- try:
- return fields_by_command[command]
- except KeyError as exc:
- raise ValueError(f"Missing empty grouped-source shape for '{command}'") from exc
-
-
-def _empty_chain_source_model(command: str):
- models_by_command = {
- "devops": DevopsOutput,
- "automation": AutomationOutput,
- "permissions": PermissionsOutput,
- "rbac": RbacOutput,
- "role-trusts": RoleTrustsOutput,
- "keyvault": KeyVaultOutput,
- "arm-deployments": ArmDeploymentsOutput,
- "app-services": AppServicesOutput,
- "functions": FunctionsOutput,
- "aks": AksOutput,
- "env-vars": EnvVarsOutput,
- "tokens-credentials": TokensCredentialsOutput,
- "databases": DatabasesOutput,
- "storage": StorageOutput,
- }
- try:
- return models_by_command[command]
- except KeyError as exc:
- raise ValueError(f"Missing empty grouped-source model for '{command}'") from exc
-
-
def _chain_source_issue(command: str, exc: Exception) -> dict[str, object]:
return {
"kind": str(getattr(exc, "kind", "unknown")),
@@ -282,12 +528,12 @@ def _build_credential_path_output(
provider: BaseProvider,
options: GlobalOptions,
family_name: str,
- loaded: dict[str, object],
+ load_result: ChainFamilyLoadResult,
) -> ChainsCommandOutput:
family = get_chain_family_spec(family_name)
assert family is not None # pragma: no cover - guarded above
- paths, issues = collect_credential_path_records(provider, family_name, loaded)
+ paths, issues = collect_credential_path_records(provider, family_name, load_result.outputs)
paths.sort(
key=lambda item: (
@@ -303,6 +549,7 @@ def _build_credential_path_output(
options=options,
family=family,
family_name=family_name,
+ load_result=load_result,
paths=paths,
issues=issues,
)
@@ -311,11 +558,13 @@ def _build_credential_path_output(
def _build_deployment_path_output(
options: GlobalOptions,
family_name: str,
- loaded: dict[str, object],
+ load_result: ChainFamilyLoadResult,
) -> ChainsCommandOutput:
family = get_chain_family_spec(family_name)
assert family is not None # pragma: no cover - guarded above
+ loaded = load_result.outputs
+
devops_output = loaded["devops"]
automation_output = loaded["automation"]
permissions_output = loaded["permissions"]
@@ -523,6 +772,7 @@ def _build_deployment_path_output(
options=options,
family=family,
family_name=family_name,
+ load_result=load_result,
paths=paths,
issues=issues,
)
@@ -531,11 +781,23 @@ def _build_deployment_path_output(
def _build_escalation_path_output(
options: GlobalOptions,
family_name: str,
- loaded: dict[str, object],
+ load_result: ChainFamilyLoadResult | dict[str, object],
) -> ChainsCommandOutput:
family = get_chain_family_spec(family_name)
assert family is not None # pragma: no cover - guarded above
+ if isinstance(load_result, ChainFamilyLoadResult):
+ loaded = load_result.outputs
+ else:
+ loaded = load_result
+ load_result = ChainFamilyLoadResult(
+ outputs=loaded,
+ source_states=[
+ ChainSourceLoadState(command=command, output=output, mode="live")
+ for command, output in loaded.items()
+ ],
+ )
+
permissions_output = loaded["permissions"]
role_trusts_output = loaded["role-trusts"]
@@ -583,6 +845,7 @@ def _build_escalation_path_output(
options=options,
family=family,
family_name=family_name,
+ load_result=load_result,
paths=paths,
issues=issues,
)
@@ -628,17 +891,18 @@ def _current_foothold_contexts_from_permissions(permissions: list[object]) -> li
def _build_compute_control_output(
options: GlobalOptions,
family_name: str,
- loaded: dict[str, object],
+ load_result: ChainFamilyLoadResult,
) -> ChainsCommandOutput:
family = get_chain_family_spec(family_name)
assert family is not None # pragma: no cover - guarded above
- paths, issues = collect_compute_control_records(family_name, loaded)
+ paths, issues = collect_compute_control_records(family_name, load_result.outputs)
return _build_chains_command_output(
options=options,
family=family,
family_name=family_name,
+ load_result=load_result,
paths=paths,
issues=issues,
)
@@ -649,9 +913,11 @@ def _build_chains_command_output(
options: GlobalOptions,
family: ChainFamilySpec,
family_name: str,
+ load_result: ChainFamilyLoadResult,
paths: list[ChainPathRecord],
issues: list[CollectionIssue],
) -> ChainsCommandOutput:
+ all_issues = [*load_result.reuse_issues, *issues]
return ChainsCommandOutput(
metadata=CommandMetadata(
command=GROUPED_COMMAND_NAME,
@@ -662,19 +928,29 @@ def _build_chains_command_output(
),
grouped_command_name=GROUPED_COMMAND_NAME,
family=family_name,
- input_mode="live",
+ input_mode=_chain_input_mode(load_result),
command_state="extraction-only",
summary=family.summary,
claim_boundary=family.allowed_claim,
current_gap=family.current_gap,
- artifact_preference_order=[],
+ artifact_preference_order=[] if options.live_only else list(PREFERRED_ARTIFACT_ORDER),
backing_commands=[source.command for source in family.source_commands],
- source_artifacts=[],
+ reused_sources=list(load_result.reused_sources),
+ live_sources=list(load_result.live_sources),
+ source_artifacts=list(load_result.source_artifacts),
paths=paths,
- issues=issues,
+ issues=all_issues,
)
+def _chain_input_mode(load_result: ChainFamilyLoadResult) -> str:
+    """Classify how a family's sources were loaded.
+
+    Returns ``"mixed"`` when some sources were reused from artifacts and some collected
+    live, ``"artifacts"`` when at least one was reused and none were live, and ``"live"``
+    otherwise (including when no sources were loaded at all).
+    """
+    any_reused = bool(load_result.reused_sources)
+    any_live = bool(load_result.live_sources)
+    if any_reused and any_live:
+        return "mixed"
+    return "artifacts" if any_reused else "live"
+
+
def _build_deployment_source_record(
family_name: str,
*,
diff --git a/src/azurefox/cli.py b/src/azurefox/cli.py
index e47b5cc..59c30e0 100644
--- a/src/azurefox/cli.py
+++ b/src/azurefox/cli.py
@@ -277,11 +277,18 @@ def vmss(ctx: typer.Context) -> None:
@app.command("chains")
def chains(
ctx: typer.Context,
+ live_only: bool = typer.Option(
+ False,
+ "--live-only",
+ help="Collect grouped chain sources live only and skip local artifact reuse.",
+ ),
family: str | None = typer.Argument(
None, help="Chain family name, or 'help' to list the available chain families."
),
) -> None:
options: GlobalOptions = ctx.obj
+ if live_only:
+ options = replace(options, live_only=True)
if family in {None, "help"}:
try:
if options.output != OutputMode.JSON:
diff --git a/src/azurefox/clients/graph.py b/src/azurefox/clients/graph.py
index 5cebf3d..7f70b09 100644
--- a/src/azurefox/clients/graph.py
+++ b/src/azurefox/clients/graph.py
@@ -142,54 +142,48 @@ def batch_list_objects_by_key(
self,
requests: list[GraphBatchRequest],
) -> tuple[dict[str, list[dict[str, Any]]], dict[str, AzureFoxError]]:
- pending = list(requests)
partial: dict[str, list[dict[str, Any]]] = {}
- results: dict[str, list[dict[str, Any]]] = {}
- errors: dict[str, AzureFoxError] = {}
-
- while pending:
- chunk = pending[:GRAPH_BATCH_MAX_REQUESTS]
- pending = pending[GRAPH_BATCH_MAX_REQUESTS:]
-
- bodies, body_errors = self._batch_execute(chunk)
- for request in chunk:
- if request.key in body_errors:
- errors[request.key] = body_errors[request.key]
- partial.pop(request.key, None)
- continue
-
- body = bodies.get(request.key)
- if body is None:
- errors[request.key] = AzureFoxError(
- ErrorKind.UNKNOWN,
- f"Graph batch request missing response for {request.path}",
- )
- partial.pop(request.key, None)
- continue
- values = body.get("value", [])
- if isinstance(values, list):
- partial.setdefault(request.key, []).extend(
- item for item in values if isinstance(item, dict)
- )
+ def _consume_list_body(
+ request: GraphBatchRequest,
+ body: dict[str, Any],
+ ) -> tuple[bool, list[dict[str, Any]] | None]:
+ values = body.get("value", [])
+ if isinstance(values, list):
+ partial.setdefault(request.key, []).extend(
+ item for item in values if isinstance(item, dict)
+ )
- next_url = body.get("@odata.nextLink")
- if isinstance(next_url, str) and next_url:
- pending.append(GraphBatchRequest(key=request.key, path=next_url))
- continue
+ next_url = body.get("@odata.nextLink")
+ if isinstance(next_url, str) and next_url:
+ return False, None
- results[request.key] = list(partial.get(request.key, []))
- partial.pop(request.key, None)
+ final_items = list(partial.get(request.key, []))
+ partial.pop(request.key, None)
+ return True, final_items
- return results, errors
+ return self._batch_collect_by_key(requests, _consume_list_body)
def batch_get_objects_by_key(
self,
requests: list[GraphBatchRequest],
) -> tuple[dict[str, dict[str, Any]], dict[str, AzureFoxError]]:
- results: dict[str, dict[str, Any]] = {}
- errors: dict[str, AzureFoxError] = {}
+ def _consume_get_body(
+ _request: GraphBatchRequest,
+ body: dict[str, Any],
+ ) -> tuple[bool, dict[str, Any] | None]:
+ return True, body
+
+ return self._batch_collect_by_key(requests, _consume_get_body)
+
+ def _batch_collect_by_key(
+ self,
+ requests: list[GraphBatchRequest],
+ consume_body,
+ ) -> tuple[dict[str, Any], dict[str, AzureFoxError]]:
pending = list(requests)
+ results: dict[str, Any] = {}
+ errors: dict[str, AzureFoxError] = {}
while pending:
chunk = pending[:GRAPH_BATCH_MAX_REQUESTS]
@@ -209,7 +203,16 @@ def batch_get_objects_by_key(
)
continue
- results[request.key] = body
+ completed, final_value = consume_body(request, body)
+ if not completed:
+ pending.append(
+ GraphBatchRequest(
+ key=request.key,
+ path=str(body["@odata.nextLink"]),
+ )
+ )
+ continue
+ results[request.key] = final_value
return results, errors
@@ -369,13 +372,24 @@ def _graph_batch_request_error(
if message:
pieces.append(message)
formatted = " ".join(pieces)
+ kind = _graph_batch_error_kind(status=status, code=code, message=message)
return AzureFoxError(
- classify_exception(Exception(formatted)),
+ kind,
formatted,
details={"body": json.dumps(body)[:500]},
)
+def _graph_batch_error_kind(*, status: int, code: str, message: str) -> ErrorKind:
+    """Map one failed Graph batch sub-response to an ``ErrorKind``.
+
+    HTTP 401, 403 and 429 map directly to auth failure, permission denied and
+    throttling. Any other status is classified by ``classify_exception`` from the
+    non-empty ``code`` and ``message`` parts joined by a space.
+    """
+    kinds_by_status = {
+        401: ErrorKind.AUTH_FAILURE,
+        403: ErrorKind.PERMISSION_DENIED,
+        429: ErrorKind.THROTTLING,
+    }
+    status_kind = kinds_by_status.get(status)
+    if status_kind is not None:
+        return status_kind
+    summary = " ".join(part for part in (code, message) if part)
+    return classify_exception(Exception(summary))
+
+
def _graph_ssl_context() -> ssl.SSLContext:
if certifi is not None:
return ssl.create_default_context(cafile=certifi.where())
diff --git a/src/azurefox/collectors/provider.py b/src/azurefox/collectors/provider.py
index bc7a1d9..277b6ca 100644
--- a/src/azurefox/collectors/provider.py
+++ b/src/azurefox/collectors/provider.py
@@ -1858,102 +1858,32 @@ def principals(self) -> dict:
rbac_data = self.rbac()
whoami_data = self.whoami()
identity_data = self.managed_identities()
-
- records: dict[str, dict] = {}
- issues = [
- *rbac_data.get("issues", []),
- *whoami_data.get("issues", []),
- *identity_data.get("issues", []),
- ]
-
- def ensure_record(principal_id: str) -> dict:
- if principal_id not in records:
- records[principal_id] = {
- "id": principal_id,
- "principal_type": "unknown",
- "display_name": None,
- "tenant_id": None,
- "sources": [],
- "scope_ids": [],
- "assignment_scope_ids": [],
- "role_names": [],
- "role_assignment_count": 0,
- "identity_names": [],
- "identity_types": [],
- "attached_to": [],
- "is_current_identity": False,
- }
- return records[principal_id]
-
- for principal in rbac_data.get("principals", []):
- principal_id = principal.get("id")
- if not principal_id:
- continue
- record = ensure_record(principal_id)
- _merge_principal_attributes(record, principal)
- _append_unique(record["sources"], "rbac")
-
- for assignment in rbac_data.get("role_assignments", []):
- principal_id = assignment.get("principal_id")
- if not principal_id:
- continue
- record = ensure_record(principal_id)
- role_name = assignment.get("role_name")
- scope_id = assignment.get("scope_id")
- if role_name:
- _append_unique(record["role_names"], role_name)
- if scope_id:
- _append_unique(record["scope_ids"], scope_id)
- _append_unique(record["assignment_scope_ids"], scope_id)
- record["role_assignment_count"] += 1
- principal_type = assignment.get("principal_type")
- if principal_type:
- record["principal_type"] = _normalize_principal_type(
- record["principal_type"],
- principal_type,
- )
- _append_unique(record["sources"], "rbac")
-
- principal = whoami_data.get("principal")
- if principal and principal.get("id"):
- record = ensure_record(principal["id"])
- _merge_principal_attributes(record, principal)
- record["is_current_identity"] = True
- for scope in whoami_data.get("effective_scopes", []):
- scope_id = scope.get("id")
- if scope_id:
- _append_unique(record["scope_ids"], scope_id)
- _append_unique(record["sources"], "whoami")
-
- for identity in identity_data.get("identities", []):
- principal_id = identity.get("principal_id")
- if not principal_id:
- continue
- record = ensure_record(principal_id)
- if record["principal_type"] == "unknown":
- record["principal_type"] = "ServicePrincipal"
- _append_unique(record["identity_names"], identity.get("name"))
- _append_unique(record["identity_types"], identity.get("identity_type"))
- for scope_id in identity.get("scope_ids", []):
- _append_unique(record["scope_ids"], scope_id)
- for attachment in identity.get("attached_to", []):
- _append_unique(record["attached_to"], attachment)
- _append_unique(record["sources"], "managed-identities")
-
- principals = sorted(
- records.values(),
- key=_principal_sort_key,
+ principals, issues, _assignment_scope_ids_by_principal = _principal_records_from_sources(
+ rbac_data=rbac_data,
+ whoami_data=whoami_data,
+ identity_data=identity_data,
)
return {"principals": principals, "issues": issues}
def permissions(self) -> dict:
- principal_data = self.principals()
+ rbac_data = self.rbac()
+ whoami_data = self.whoami()
+ identity_data = self.managed_identities()
+ principals, issues, assignment_scope_ids_by_principal = _principal_records_from_sources(
+ rbac_data=rbac_data,
+ whoami_data=whoami_data,
+ identity_data=identity_data,
+ )
permission_rows: list[dict] = []
- for principal in principal_data.get("principals", []):
+ for principal in principals:
role_names = sorted(set(principal.get("role_names", [])))
+ principal_id = str(principal.get("id") or "")
scope_ids = sorted(
- set(principal.get("assignment_scope_ids") or principal.get("scope_ids", []))
+ set(
+ assignment_scope_ids_by_principal.get(principal_id)
+ or principal.get("scope_ids", [])
+ )
)
high_impact_roles = sorted(
{
@@ -1986,7 +1916,7 @@ def permissions(self) -> dict:
item["principal_id"] or "",
)
)
- return {"permissions": permission_rows, "issues": principal_data.get("issues", [])}
+ return {"permissions": permission_rows, "issues": issues}
def privesc(self) -> dict:
permissions_data = self.permissions()
@@ -4458,6 +4388,98 @@ def _append_unique(items: list[str], value: str | None) -> None:
items.append(value)
+def _principal_records_from_sources(
+ *,
+ rbac_data: dict,
+ whoami_data: dict,
+ identity_data: dict,
+) -> tuple[list[dict], list[dict], dict[str, list[str]]]:
+ """Merge RBAC, whoami, and managed-identity payloads into per-principal records.
+
+ Returns a 3-tuple of:
+
+ * the merged principal records, sorted with ``_principal_sort_key``;
+ * the ``issues`` lists of the three payloads concatenated in the order
+ rbac, whoami, managed identities;
+ * a mapping from principal id to the scope ids seen on that principal's
+ RBAC role assignments.
+ """
+ # Merged records, keyed by principal id.
+ records: dict[str, dict] = {}
+ # Scope ids taken from role assignments only, kept outside the records themselves.
+ assignment_scope_ids_by_principal: dict[str, list[str]] = {}
+ issues = [
+ *rbac_data.get("issues", []),
+ *whoami_data.get("issues", []),
+ *identity_data.get("issues", []),
+ ]
+
+ # Return the record for ``principal_id``, creating it with default fields on first use.
+ def ensure_record(principal_id: str) -> dict:
+ if principal_id not in records:
+ records[principal_id] = {
+ "id": principal_id,
+ "principal_type": "unknown",
+ "display_name": None,
+ "tenant_id": None,
+ "sources": [],
+ "scope_ids": [],
+ "role_names": [],
+ "role_assignment_count": 0,
+ "identity_names": [],
+ "identity_types": [],
+ "attached_to": [],
+ "is_current_identity": False,
+ }
+ return records[principal_id]
+
+ # Pass 1: principals listed by the RBAC payload; entries without an id are skipped.
+ for principal in rbac_data.get("principals", []):
+ principal_id = principal.get("id")
+ if not principal_id:
+ continue
+ record = ensure_record(principal_id)
+ _merge_principal_attributes(record, principal)
+ _append_unique(record["sources"], "rbac")
+
+ # Pass 2: role assignments contribute role names, scope ids, an assignment count,
+ # and a principal type passed through ``_normalize_principal_type``.
+ for assignment in rbac_data.get("role_assignments", []):
+ principal_id = assignment.get("principal_id")
+ if not principal_id:
+ continue
+ record = ensure_record(principal_id)
+ role_name = assignment.get("role_name")
+ scope_id = assignment.get("scope_id")
+ if role_name:
+ _append_unique(record["role_names"], role_name)
+ if scope_id:
+ _append_unique(record["scope_ids"], scope_id)
+ # The assignment scope is also tracked in the separate per-principal mapping.
+ assignment_scope_ids_by_principal.setdefault(principal_id, [])
+ _append_unique(assignment_scope_ids_by_principal[principal_id], scope_id)
+ record["role_assignment_count"] += 1
+ principal_type = assignment.get("principal_type")
+ if principal_type:
+ record["principal_type"] = _normalize_principal_type(
+ record["principal_type"],
+ principal_type,
+ )
+ _append_unique(record["sources"], "rbac")
+
+ # Pass 3: the whoami principal is flagged as the current identity and gains the
+ # ids of whoami's ``effective_scopes`` in its ``scope_ids``.
+ principal = whoami_data.get("principal")
+ if principal and principal.get("id"):
+ record = ensure_record(principal["id"])
+ _merge_principal_attributes(record, principal)
+ record["is_current_identity"] = True
+ for scope in whoami_data.get("effective_scopes", []):
+ scope_id = scope.get("id")
+ if scope_id:
+ _append_unique(record["scope_ids"], scope_id)
+ _append_unique(record["sources"], "whoami")
+
+ # Pass 4: managed identities; entries without a principal id are skipped.
+ for identity in identity_data.get("identities", []):
+ principal_id = identity.get("principal_id")
+ if not principal_id:
+ continue
+ record = ensure_record(principal_id)
+ # Only fill in a type when no earlier pass has set one.
+ if record["principal_type"] == "unknown":
+ record["principal_type"] = "ServicePrincipal"
+ _append_unique(record["identity_names"], identity.get("name"))
+ _append_unique(record["identity_types"], identity.get("identity_type"))
+ for scope_id in identity.get("scope_ids", []):
+ _append_unique(record["scope_ids"], scope_id)
+ for attachment in identity.get("attached_to", []):
+ _append_unique(record["attached_to"], attachment)
+ _append_unique(record["sources"], "managed-identities")
+
+ principals = sorted(records.values(), key=_principal_sort_key)
+ return principals, issues, assignment_scope_ids_by_principal
+
+
def _normalize_principal_type(existing: str | None, candidate: str | None) -> str:
normalized_existing = existing or "unknown"
if not candidate:
@@ -4509,21 +4531,12 @@ def _graph_batch_list_with_fallback(
requests: list[GraphBatchRequest],
serial_fetch,
) -> tuple[dict[str, list[dict]], dict[str, Exception]]:
- if not requests:
- return {}, {}
-
- batch_fetch = getattr(graph, "batch_list_objects_by_key", None)
- if callable(batch_fetch):
- return batch_fetch(requests)
-
- results: dict[str, list[dict]] = {}
- errors: dict[str, Exception] = {}
- for request in requests:
- try:
- results[request.key] = serial_fetch(request)
- except Exception as exc:
- errors[request.key] = exc
- return results, errors
+ return _graph_batch_with_fallback(
+ graph=graph,
+ requests=requests,
+ batch_method_name="batch_list_objects_by_key",
+ serial_fetch=serial_fetch,
+ )
def _graph_batch_get_with_fallback(
@@ -4531,14 +4544,29 @@ def _graph_batch_get_with_fallback(
requests: list[GraphBatchRequest],
serial_fetch,
) -> tuple[dict[str, dict], dict[str, Exception]]:
+ return _graph_batch_with_fallback(
+ graph=graph,
+ requests=requests,
+ batch_method_name="batch_get_objects_by_key",
+ serial_fetch=serial_fetch,
+ )
+
+
+def _graph_batch_with_fallback(
+ *,
+ graph: object,
+ requests: list[GraphBatchRequest],
+ batch_method_name: str,
+ serial_fetch,
+) -> tuple[dict, dict[str, Exception]]:
if not requests:
return {}, {}
- batch_fetch = getattr(graph, "batch_get_objects_by_key", None)
+ batch_fetch = getattr(graph, batch_method_name, None)
if callable(batch_fetch):
return batch_fetch(requests)
- results: dict[str, dict] = {}
+ results: dict = {}
errors: dict[str, Exception] = {}
for request in requests:
try:
diff --git a/src/azurefox/config.py b/src/azurefox/config.py
index 484a37a..f8fb282 100644
--- a/src/azurefox/config.py
+++ b/src/azurefox/config.py
@@ -13,6 +13,7 @@ class GlobalOptions:
output: OutputMode
outdir: Path
debug: bool
+ live_only: bool = False
devops_organization: str | None = None
role_trusts_mode: RoleTrustsMode = RoleTrustsMode.FAST
diff --git a/src/azurefox/models/chains.py b/src/azurefox/models/chains.py
index c88eb46..5c9a5ce 100644
--- a/src/azurefox/models/chains.py
+++ b/src/azurefox/models/chains.py
@@ -91,6 +91,8 @@ class ChainsOutput(BaseModel):
current_gap: str | None = None
artifact_preference_order: list[str] = Field(default_factory=list)
backing_commands: list[str] = Field(default_factory=list)
+ reused_sources: list[str] = Field(default_factory=list)
+ live_sources: list[str] = Field(default_factory=list)
source_artifacts: list[ChainSourceArtifact] = Field(default_factory=list)
paths: list[ChainPathRecord] = Field(default_factory=list)
issues: list[CollectionIssue] = Field(default_factory=list)
diff --git a/tests/test_chain_scaffold.py b/tests/test_chain_scaffold.py
index 94b33bf..06c1c34 100644
--- a/tests/test_chain_scaffold.py
+++ b/tests/test_chain_scaffold.py
@@ -53,8 +53,8 @@
def test_chain_registry_uses_expected_grouped_command_shape() -> None:
assert GROUPED_COMMAND_NAME == "chains"
- assert GROUPED_COMMAND_INPUT_MODES == ("live", "artifacts")
- assert PREFERRED_ARTIFACT_ORDER == ("loot", "json")
+ assert GROUPED_COMMAND_INPUT_MODES == ("live", "artifacts", "mixed")
+ assert PREFERRED_ARTIFACT_ORDER == ("json",)
def test_chain_registry_keeps_first_family_order() -> None:
diff --git a/tests/test_cli_smoke.py b/tests/test_cli_smoke.py
index ecbef66..ea2ca6c 100644
--- a/tests/test_cli_smoke.py
+++ b/tests/test_cli_smoke.py
@@ -31,6 +31,43 @@ def _write_fixture_json(path: Path, payload: dict) -> None:
path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
+def _seed_command_json_artifacts(tmp_path: Path, fixture_dir: Path, commands: list[str]) -> None:
+    """Invoke each command through the CLI with ``--outdir tmp_path --output json``.
+
+    ``role-trusts`` is additionally passed ``--mode fast``. Every invocation must
+    exit with code 0; the captured stdout is used as the assertion message.
+    """
+    fixture_root = str(fixture_dir)
+    leading_args = ["--outdir", str(tmp_path), "--output", "json"]
+    for name in commands:
+        extra_args = ["--mode", "fast"] if name == "role-trusts" else []
+        outcome = runner.invoke(
+            app,
+            [*leading_args, name, *extra_args],
+            env={"AZUREFOX_FIXTURE_DIR": fixture_root},
+        )
+        assert outcome.exit_code == 0, outcome.stdout
+
+
+def _patch_failing_collectors(
+    monkeypatch: pytest.MonkeyPatch,
+    failing_commands: set[str],
+) -> None:
+    """Swap the collectors of ``failing_commands`` for ones that always raise.
+
+    ``azurefox.chains.runner.get_command_specs`` is monkeypatched to return a tuple
+    in which every spec named in ``failing_commands`` keeps its name and section but
+    gets a collector raising ``AzureFoxError`` (``PERMISSION_DENIED``); all other
+    specs are passed through unchanged.
+    """
+
+    def _make_raiser(name: str):
+        def _raise_permission_denied(_provider, _options):
+            raise AzureFoxError(
+                kind=ErrorKind.PERMISSION_DENIED,
+                message="collector should not have run",
+                command=name,
+            )
+
+        return _raise_permission_denied
+
+    rebuilt_specs = []
+    for spec in get_command_specs():
+        if spec.name in failing_commands:
+            rebuilt_specs.append(CommandSpec(spec.name, spec.section, _make_raiser(spec.name)))
+        else:
+            rebuilt_specs.append(spec)
+    frozen_specs = tuple(rebuilt_specs)
+
+    def _patched_get_command_specs():
+        return frozen_specs
+
+    monkeypatch.setattr("azurefox.chains.runner.get_command_specs", _patched_get_command_specs)
+
+
def _contributor_app_permission_trust() -> dict:
return {
"trust_type": "app-to-service-principal",
@@ -350,7 +387,15 @@ def test_cli_smoke_chains_credential_path_json(tmp_path: Path) -> None:
assert payload["command_state"] == "extraction-only"
assert payload["claim_boundary"].startswith("Can claim that the visible evidence suggests")
assert payload["current_gap"].startswith("The live family now joins backing evidence")
- assert payload["artifact_preference_order"] == []
+ assert payload["artifact_preference_order"] == ["json"]
+ assert payload["reused_sources"] == []
+ assert payload["live_sources"] == [
+ "env-vars",
+ "tokens-credentials",
+ "databases",
+ "storage",
+ "keyvault",
+ ]
assert payload["source_artifacts"] == []
assert payload["backing_commands"] == [
"env-vars",
@@ -387,6 +432,147 @@ def test_cli_smoke_chains_credential_path_json(tmp_path: Path) -> None:
} == {"narrowed candidates"}
+def test_cli_smoke_chains_reuses_matching_source_artifacts_by_default(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """With artifacts seeded and every source collector patched to raise, input is artifacts-only."""
+    lab_fixtures = Path(__file__).resolve().parent / "fixtures" / "lab_tenant"
+    source_commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"]
+    _seed_command_json_artifacts(tmp_path, lab_fixtures, source_commands)
+    _patch_failing_collectors(monkeypatch, set(source_commands))
+
+    cli_args = ["--outdir", str(tmp_path), "--output", "json", "chains", "credential-path"]
+    outcome = runner.invoke(app, cli_args, env={"AZUREFOX_FIXTURE_DIR": str(lab_fixtures)})
+
+    assert outcome.exit_code == 0
+    report = json.loads(outcome.stdout)
+    assert report["family"] == "credential-path"
+    assert report["input_mode"] == "artifacts"
+    assert report["reused_sources"] == source_commands
+    assert report["live_sources"] == []
+    artifact_commands = [entry["command"] for entry in report["source_artifacts"]]
+    assert artifact_commands == source_commands
+    assert len(report["paths"]) == 3
+    assert report["issues"] == []
+
+
+def test_cli_smoke_chains_live_only_bypasses_matching_source_artifacts(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """``--live-only`` must not reuse seeded artifacts; each patched collector yields an issue."""
+    lab_fixtures = Path(__file__).resolve().parent / "fixtures" / "lab_tenant"
+    source_commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"]
+    _seed_command_json_artifacts(tmp_path, lab_fixtures, source_commands)
+    _patch_failing_collectors(monkeypatch, set(source_commands))
+
+    cli_args = ["--outdir", str(tmp_path), "--output", "json"]
+    cli_args += ["chains", "--live-only", "credential-path"]
+    outcome = runner.invoke(app, cli_args, env={"AZUREFOX_FIXTURE_DIR": str(lab_fixtures)})
+
+    assert outcome.exit_code == 0
+    report = json.loads(outcome.stdout)
+    assert report["family"] == "credential-path"
+    assert report["input_mode"] == "live"
+    assert report["reused_sources"] == []
+    assert report["source_artifacts"] == []
+    assert report["paths"] == []
+    issue_scopes = {entry["scope"] for entry in report["issues"]}
+    assert issue_scopes == set(source_commands)
+
+
+def test_cli_smoke_chains_falls_back_live_when_source_artifact_mismatches(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """A source artifact with a mismatched schema version is skipped and collected live.
+
+    Only the storage artifact is altered and only the storage collector is patched
+    to fail, so the run must report ``mixed`` input, reuse the other four artifacts,
+    and carry two storage-scoped issues: the reuse-skipped notice and the failure of
+    the live storage collector.
+    """
+    fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant"
+    commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"]
+    _seed_command_json_artifacts(tmp_path, fixture_dir, commands)
+
+    # Rewrite the seeded storage artifact with a schema version the run should reject.
+    storage_path = tmp_path / "json" / "storage.json"
+    storage_payload = _read_fixture_json(storage_path)
+    storage_payload["metadata"]["schema_version"] = "0.0.0"
+    _write_fixture_json(storage_path, storage_payload)
+
+    _patch_failing_collectors(monkeypatch, {"storage"})
+
+    result = runner.invoke(
+        app,
+        [
+            "--outdir",
+            str(tmp_path),
+            "--output",
+            "json",
+            "chains",
+            "credential-path",
+        ],
+        env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)},
+    )
+
+    assert result.exit_code == 0
+    payload = json.loads(result.stdout)
+    assert payload["family"] == "credential-path"
+    assert payload["input_mode"] == "mixed"
+    assert payload["reused_sources"] == [
+        "env-vars",
+        "tokens-credentials",
+        "databases",
+        "keyvault",
+    ]
+    assert payload["live_sources"] == ["storage"]
+    assert [item["command"] for item in payload["source_artifacts"]] == [
+        "env-vars",
+        "tokens-credentials",
+        "databases",
+        "keyvault",
+    ]
+    storage_issues = [item for item in payload["issues"] if item["scope"] == "storage"]
+    skipped_issue = next(
+        item for item in storage_issues if item["kind"] == "artifact_reuse_skipped"
+    )
+    assert "schema_version mismatch" in skipped_issue["message"]
+    # The previous check (`any(scope == "storage")`) was already implied by the
+    # lookup above. Require a storage issue other than the skip notice, which is
+    # what the patched live collector is expected to produce.
+    assert any(item["kind"] != "artifact_reuse_skipped" for item in storage_issues)
+
+
+def test_cli_smoke_chains_reports_invalid_artifact_reuse_rejection_reason(tmp_path: Path) -> None:
+    """A storage artifact that is not valid JSON is reported as a skipped-reuse issue."""
+    lab_fixtures = Path(__file__).resolve().parent / "fixtures" / "lab_tenant"
+    source_commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"]
+    _seed_command_json_artifacts(tmp_path, lab_fixtures, source_commands)
+
+    # Overwrite the seeded storage artifact with text that cannot be parsed as JSON.
+    (tmp_path / "json" / "storage.json").write_text("{not valid json", encoding="utf-8")
+
+    cli_args = ["--outdir", str(tmp_path), "--output", "json", "chains", "credential-path"]
+    outcome = runner.invoke(app, cli_args, env={"AZUREFOX_FIXTURE_DIR": str(lab_fixtures)})
+
+    assert outcome.exit_code == 0
+    report = json.loads(outcome.stdout)
+    assert report["family"] == "credential-path"
+    assert report["input_mode"] == "mixed"
+    storage_skips = (
+        entry
+        for entry in report["issues"]
+        if entry["scope"] == "storage" and entry["kind"] == "artifact_reuse_skipped"
+    )
+    skipped_issue = next(storage_skips)
+    assert "invalid json artifact" in skipped_issue["message"].lower()
+    skip_context = skipped_issue["context"]
+    assert skip_context["artifact_type"] == "json"
+    assert skip_context["reason"] == "invalid_json"
+
+
def test_cli_smoke_chains_credential_path_table_output(tmp_path: Path) -> None:
fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant"
@@ -470,7 +656,20 @@ def test_cli_smoke_chains_deployment_path_json(tmp_path: Path) -> None:
assert payload["metadata"]["command"] == "chains"
assert payload["family"] == "deployment-path"
assert payload["command_state"] == "extraction-only"
- assert payload["artifact_preference_order"] == []
+ assert payload["artifact_preference_order"] == ["json"]
+ assert payload["reused_sources"] == []
+ assert payload["live_sources"] == [
+ "devops",
+ "automation",
+ "permissions",
+ "rbac",
+ "role-trusts",
+ "keyvault",
+ "arm-deployments",
+ "aks",
+ "functions",
+ "app-services",
+ ]
assert payload["source_artifacts"] == []
assert payload["backing_commands"] == [
"devops",
@@ -608,7 +807,15 @@ def test_cli_smoke_chains_compute_control_json(tmp_path: Path) -> None:
assert payload["metadata"]["command"] == "chains"
assert payload["family"] == "compute-control"
assert payload["command_state"] == "extraction-only"
- assert payload["artifact_preference_order"] == []
+ assert payload["artifact_preference_order"] == ["json"]
+ assert payload["reused_sources"] == []
+ assert payload["live_sources"] == [
+ "tokens-credentials",
+ "env-vars",
+ "workloads",
+ "managed-identities",
+ "permissions",
+ ]
assert payload["source_artifacts"] == []
assert payload["backing_commands"] == [
"tokens-credentials",
@@ -759,7 +966,9 @@ def test_cli_smoke_chains_escalation_path_json(tmp_path: Path) -> None:
assert payload["metadata"]["command"] == "chains"
assert payload["family"] == "escalation-path"
assert payload["command_state"] == "extraction-only"
- assert payload["artifact_preference_order"] == []
+ assert payload["artifact_preference_order"] == ["json"]
+ assert payload["reused_sources"] == []
+ assert payload["live_sources"] == ["permissions", "role-trusts"]
assert payload["source_artifacts"] == []
assert payload["backing_commands"] == [
"permissions",
diff --git a/tests/test_collectors.py b/tests/test_collectors.py
index 5d5eaf4..5c8e228 100644
--- a/tests/test_collectors.py
+++ b/tests/test_collectors.py
@@ -3881,9 +3881,11 @@ def test_permissions_current_identity_uses_assignment_scopes_not_whoami_scope()
}
provider.managed_identities = lambda: {"identities": [], "issues": []}
+ principals = AzureProvider.principals(provider)
permissions = AzureProvider.permissions(provider)
current_row = permissions["permissions"][0]
+ assert "assignment_scope_ids" not in principals["principals"][0]
assert current_row["is_current_identity"] is True
assert current_row["role_assignment_count"] == 1
assert current_row["scope_count"] == 1
diff --git a/tests/test_graph_client.py b/tests/test_graph_client.py
new file mode 100644
index 0000000..26a5efc
--- /dev/null
+++ b/tests/test_graph_client.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+from azurefox.clients.graph import GraphBatchRequest, GraphClient
+from azurefox.errors import ErrorKind
+
+
+def test_graph_batch_execute_preserves_status_aware_error_kinds(monkeypatch) -> None:
+    """Failed batch sub-responses with status 401/403/429 map to distinct error kinds."""
+    client = GraphClient(SimpleNamespace())
+
+    # (response id, status, error code, error message) for each stubbed sub-response.
+    canned_failures = (
+        ("0", 401, "InvalidAuthenticationToken", "Access token is invalid"),
+        ("1", 403, "Authorization_RequestDenied", "Forbidden"),
+        ("2", 429, "TooManyRequests", "Rate limit exceeded"),
+    )
+
+    def _fake_post(self, _url: str, _payload: dict) -> dict:
+        responses = [
+            {
+                "id": response_id,
+                "status": status,
+                "body": {"error": {"code": code, "message": message}},
+            }
+            for response_id, status, code, message in canned_failures
+        ]
+        return {"responses": responses}
+
+    monkeypatch.setattr(GraphClient, "_post", _fake_post)
+
+    batch = [
+        GraphBatchRequest(key="auth", path="/applications/app-a"),
+        GraphBatchRequest(key="forbidden", path="/servicePrincipals/sp-a"),
+        GraphBatchRequest(key="throttle", path="/servicePrincipals/sp-b"),
+    ]
+    _results, errors = client.batch_get_objects_by_key(batch)
+
+    expected_kinds = (
+        ("auth", ErrorKind.AUTH_FAILURE),
+        ("forbidden", ErrorKind.PERMISSION_DENIED),
+        ("throttle", ErrorKind.THROTTLING),
+    )
+    for key, expected_kind in expected_kinds:
+        assert errors[key].kind == expected_kind