diff --git a/README.md b/README.md index 73d0731..025a3d7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # AzureFox

- AzureFox logo + AzureFox logo

Find attack paths, pivot opportunities, and movement across Azure before you drown in inventory. diff --git a/schemas/chains.schema.json b/schemas/chains.schema.json index 62d6d10..266d003 100644 --- a/schemas/chains.schema.json +++ b/schemas/chains.schema.json @@ -11,6 +11,8 @@ "current_gap", "artifact_preference_order", "backing_commands", + "reused_sources", + "live_sources", "source_artifacts", "paths", "issues" diff --git a/src/azurefox/chains/registry.py b/src/azurefox/chains/registry.py index d3582b3..6b10d23 100644 --- a/src/azurefox/chains/registry.py +++ b/src/azurefox/chains/registry.py @@ -2,9 +2,27 @@ from dataclasses import dataclass +from azurefox.config import GlobalOptions +from azurefox.models.commands import ( + AksOutput, + AppServicesOutput, + ArmDeploymentsOutput, + AutomationOutput, + DatabasesOutput, + DevopsOutput, + EnvVarsOutput, + FunctionsOutput, + KeyVaultOutput, + PermissionsOutput, + RbacOutput, + RoleTrustsOutput, + StorageOutput, + TokensCredentialsOutput, +) + GROUPED_COMMAND_NAME = "chains" -GROUPED_COMMAND_INPUT_MODES = ("live", "artifacts") -PREFERRED_ARTIFACT_ORDER = ("loot", "json") +GROUPED_COMMAND_INPUT_MODES = ("live", "artifacts", "mixed") +PREFERRED_ARTIFACT_ORDER = ("json",) SEMANTIC_LOOT_CHAIN_FAMILIES = ( "credential-path", "deployment-path", @@ -32,6 +50,24 @@ class ChainFamilySpec: source_commands: tuple[ChainSourceSpec, ...] +_CHAIN_SOURCE_MODELS = { + "devops": DevopsOutput, + "automation": AutomationOutput, + "permissions": PermissionsOutput, + "rbac": RbacOutput, + "role-trusts": RoleTrustsOutput, + "keyvault": KeyVaultOutput, + "arm-deployments": ArmDeploymentsOutput, + "app-services": AppServicesOutput, + "functions": FunctionsOutput, + "aks": AksOutput, + "env-vars": EnvVarsOutput, + "tokens-credentials": TokensCredentialsOutput, + "databases": DatabasesOutput, + "storage": StorageOutput, +} + + CHAIN_FAMILIES: tuple[ChainFamilySpec, ...] = ( ChainFamilySpec( name="credential-path", @@ -498,3 +534,37 @@ def get_chain_family_spec(name: str) -> ChainFamilySpec | None: if spec.name == name: return spec return None + + +def chain_source_model(command: str): + try: + return _CHAIN_SOURCE_MODELS[command] + except KeyError as exc: + raise ValueError(f"Missing empty grouped-source model for '{command}'") from exc + + +def empty_chain_source_fields( + command: str, + issue: dict[str, object], + options: GlobalOptions, +) -> dict[str, object]: + fields_by_command: dict[str, dict[str, object]] = { + "devops": {"pipelines": [], "issues": [issue]}, + "automation": {"automation_accounts": [], "issues": [issue]}, + "permissions": {"permissions": [], "issues": [issue]}, + "rbac": {"principals": [], "scopes": [], "role_assignments": [], "issues": [issue]}, + "role-trusts": {"mode": options.role_trusts_mode, "trusts": [], "issues": [issue]}, + "keyvault": {"key_vaults": [], "issues": [issue]}, + "arm-deployments": {"deployments": [], "issues": [issue]}, + "app-services": {"app_services": [], "issues": [issue]}, + "functions": {"function_apps": [], "issues": [issue]}, + "aks": {"aks_clusters": [], "issues": [issue]}, + "env-vars": {"env_vars": [], "issues": [issue]}, + "tokens-credentials": {"surfaces": [], "issues": [issue]}, + "databases": {"database_servers": [], "issues": [issue]}, + "storage": {"storage_assets": [], "issues": [issue]}, + } + try: + return fields_by_command[command] + except KeyError as exc: + raise ValueError(f"Missing empty grouped-source shape for '{command}'") from exc diff --git a/src/azurefox/chains/runner.py b/src/azurefox/chains/runner.py index 6a2db86..33c4bfb 100644 --- a/src/azurefox/chains/runner.py +++ b/src/azurefox/chains/runner.py @@ -1,8 +1,10 @@ from __future__ import annotations +import json import re from collections import defaultdict -from dataclasses import replace +from dataclasses import dataclass, replace +from pathlib import Path from azurefox.chains.compute_control import collect_compute_control_records from azurefox.chains.credential_path import collect_credential_path_records @@ -14,7 +16,10 @@ ) from azurefox.chains.registry import ( GROUPED_COMMAND_NAME, + PREFERRED_ARTIFACT_ORDER, ChainFamilySpec, + chain_source_model, + empty_chain_source_fields, get_chain_family_spec, implemented_chain_family_names, is_implemented_chain_family, @@ -34,26 +39,16 @@ ) from azurefox.models.chains import ( ChainPathRecord, + ChainSourceArtifact, ChainsOutput, ) -from azurefox.models.commands import ( - AksOutput, - AppServicesOutput, - ArmDeploymentsOutput, - AutomationOutput, - ChainsCommandOutput, - DatabasesOutput, - DevopsOutput, - EnvVarsOutput, - FunctionsOutput, - KeyVaultOutput, - PermissionsOutput, - RbacOutput, - RoleTrustsOutput, - StorageOutput, - TokensCredentialsOutput, +from azurefox.models.commands import ChainsCommandOutput +from azurefox.models.common import ( + SCHEMA_VERSION, + ArmDeploymentSummary, + CollectionIssue, + CommandMetadata, ) -from azurefox.models.common import ArmDeploymentSummary, CollectionIssue, CommandMetadata from azurefox.registry import get_command_specs from azurefox.scope_hints import permission_scope_description, permission_scope_phrase from azurefox.target_matching import ( @@ -140,6 +135,41 @@ } +@dataclass(slots=True) +class ChainSourceLoadState: + command: str + output: object + mode: str + reuse_issue: CollectionIssue | None = None + source_artifact: ChainSourceArtifact | None = None + + +@dataclass(slots=True) +class ChainFamilyLoadResult: + outputs: dict[str, object] + source_states: list[ChainSourceLoadState] + + @property + def reused_sources(self) -> list[str]: + return [state.command for state in self.source_states if state.mode == "artifacts"] + + @property + def live_sources(self) -> list[str]: + return [state.command for state in self.source_states if state.mode == "live"] + + @property + def source_artifacts(self) -> list[ChainSourceArtifact]: + return [ + state.source_artifact + for state in self.source_states + if state.source_artifact is not None + ] + + @property + def reuse_issues(self) -> list[CollectionIssue]: + return [state.reuse_issue for state in self.source_states if state.reuse_issue is not None] + + def implemented_chain_families() -> tuple[str, ...]: return implemented_chain_family_names() @@ -155,15 +185,15 @@ def run_chain_family( if not is_implemented_chain_family(family_name): raise ValueError(f"Chain family '{family_name}' is not implemented yet") - loaded = _collect_family_outputs(provider, options, family_name) + load_result = _collect_family_outputs(provider, options, family_name) if family_name == "credential-path": - return _build_credential_path_output(provider, options, family_name, loaded) + return _build_credential_path_output(provider, options, family_name, load_result) if family_name == "deployment-path": - return _build_deployment_path_output(options, family_name, loaded) + return _build_deployment_path_output(options, family_name, load_result) if family_name == "escalation-path": - return _build_escalation_path_output(options, family_name, loaded) + return _build_escalation_path_output(options, family_name, load_result) if family_name == "compute-control": - return _build_compute_control_output(options, family_name, loaded) + return _build_compute_control_output(options, family_name, load_result) raise ValueError(f"Unsupported chain family '{family_name}'") @@ -172,27 +202,293 @@ def _collect_family_outputs( provider: BaseProvider, options: GlobalOptions, family_name: str, -) -> dict[str, object]: +) -> ChainFamilyLoadResult: family = get_chain_family_spec(family_name) if family is None: raise ValueError(f"Unknown chain family '{family_name}'") collector_by_name = {spec.name: spec.collector for spec in get_command_specs()} - loaded: dict[str, object] = {} + outputs: dict[str, object] = {} + source_states: list[ChainSourceLoadState] = [] for source in family.source_commands: + reused_output, source_artifact, reuse_issue = _load_reusable_chain_source_output( + provider, + options, + source.command, + ) + if reused_output is not None and source_artifact is not None: + outputs[source.command] = reused_output + source_states.append( + ChainSourceLoadState( + command=source.command, + output=reused_output, + mode="artifacts", + reuse_issue=reuse_issue, + source_artifact=source_artifact, + ) + ) + continue + collector = collector_by_name[source.command] try: - loaded[source.command] = collector(provider, options) + outputs[source.command] = collector(provider, options) except Exception as exc: - loaded[source.command] = _empty_chain_source_output( + outputs[source.command] = _empty_chain_source_output( command=source.command, provider=provider, options=options, exc=exc, ) + source_states.append( + ChainSourceLoadState( + command=source.command, + output=outputs[source.command], + mode="live", + reuse_issue=reuse_issue, + ) + ) + + return ChainFamilyLoadResult( + outputs=outputs, + source_states=source_states, + ) + + +def _load_reusable_chain_source_output( + provider: BaseProvider, + options: GlobalOptions, + command: str, +) -> tuple[object | None, ChainSourceArtifact | None, CollectionIssue | None]: + if options.live_only: + return None, None, None - return loaded + for artifact_type in PREFERRED_ARTIFACT_ORDER: + artifact_path = _chain_source_artifact_path(options, command, artifact_type) + if artifact_path is None or not artifact_path.exists(): + continue + + try: + payload = json.loads(artifact_path.read_text(encoding="utf-8")) + except OSError as exc: + return None, None, _artifact_reuse_issue( + command=command, + artifact_type=artifact_type, + artifact_path=artifact_path, + reason="read_error", + detail=str(exc), + ) + except json.JSONDecodeError as exc: + return None, None, _artifact_reuse_issue( + command=command, + artifact_type=artifact_type, + artifact_path=artifact_path, + reason="invalid_json", + detail=str(exc), + ) + + mismatch_reason = _chain_source_artifact_mismatch_reason( + payload, + provider, + options, + command, + ) + if mismatch_reason is not None: + return None, None, _artifact_reuse_issue( + command=command, + artifact_type=artifact_type, + artifact_path=artifact_path, + reason="metadata_mismatch", + detail=mismatch_reason, + ) + + model_class = chain_source_model(command) + try: + output = model_class.model_validate(payload) + except Exception as exc: + return None, None, _artifact_reuse_issue( + command=command, + artifact_type=artifact_type, + artifact_path=artifact_path, + reason="model_validation_failed", + detail=str(exc), + ) + + return ( + output, + ChainSourceArtifact( + command=command, + artifact_type=artifact_type, + path=str(artifact_path), + ), + None, + ) + + return None, None, None + + +def _chain_source_artifact_path( + options: GlobalOptions, + command: str, + artifact_type: str, +) -> Path | None: + directories = { + "json": options.json_dir, + } + directory = directories.get(artifact_type) + if directory is None: + return None + extension = "json" if artifact_type == "json" else artifact_type + return directory / f"{command}.{extension}" + + +def _chain_source_artifact_mismatch_reason( + payload: dict[str, object], + provider: BaseProvider, + options: GlobalOptions, + command: str, +) -> str | None: + metadata = payload.get("metadata") + if not isinstance(metadata, dict): + return "metadata is missing or not an object" + if metadata.get("command") != command: + return f"command mismatch: expected {command}, got {metadata.get('command')}" + if metadata.get("schema_version") != SCHEMA_VERSION: + return ( + "schema_version mismatch: " + f"expected {SCHEMA_VERSION}, got {metadata.get('schema_version')}" + ) + + context = provider.metadata_context() + current_subscription = options.subscription or context.get("subscription_id") + current_tenant = options.tenant or context.get("tenant_id") + current_token_source = context.get("token_source") + current_auth_mode = context.get("auth_mode") + + mismatch = _metadata_value_mismatch( + field_name="subscription_id", + artifact_value=metadata.get("subscription_id"), + current_value=current_subscription, + ) + if mismatch is not None: + return mismatch + mismatch = _metadata_value_mismatch( + field_name="tenant_id", + artifact_value=metadata.get("tenant_id"), + current_value=current_tenant, + ) + if mismatch is not None: + return mismatch + mismatch = _metadata_value_mismatch( + field_name="token_source", + artifact_value=metadata.get("token_source"), + current_value=current_token_source, + ) + if mismatch is not None: + return mismatch + mismatch = _metadata_value_mismatch( + field_name="auth_mode", + artifact_value=metadata.get("auth_mode"), + current_value=current_auth_mode, + ) + if mismatch is not None: + return mismatch + + if command == "devops": + mismatch = _metadata_value_mismatch( + field_name="devops_organization", + artifact_value=metadata.get("devops_organization"), + current_value=options.devops_organization, + ) + if mismatch is not None: + return mismatch + + if command == "role-trusts" and payload.get("mode") != options.role_trusts_mode: + return ( + "role-trusts mode mismatch: " + f"expected {options.role_trusts_mode}, got {payload.get('mode')}" + ) + + if command == "permissions": + current_principal_id = _current_principal_id(provider) + if current_principal_id is None: + return "current principal id could not be determined" + artifact_current_ids = sorted( + { + str(item.get("principal_id")) + for item in payload.get("permissions", []) + if isinstance(item, dict) + and item.get("is_current_identity") + and item.get("principal_id") + } + ) + if artifact_current_ids != [current_principal_id]: + return ( + "current principal mismatch: " + f"expected [{current_principal_id}], got {artifact_current_ids}" + ) + + return None + + +def _metadata_value_mismatch( + *, + field_name: str, + artifact_value: object, + current_value: str | None, +) -> str | None: + if current_value is None: + return None + if artifact_value == current_value: + return None + return f"{field_name} mismatch: expected {current_value}, got {artifact_value}" + + +def _artifact_reuse_issue( + *, + command: str, + artifact_type: str, + artifact_path: str, + reason: str, + detail: str, +) -> CollectionIssue: + reason_labels = { + "read_error": "artifact read error", + "invalid_json": "invalid JSON artifact", + "metadata_mismatch": "artifact metadata mismatch", + "model_validation_failed": "artifact model validation failed", + } + reason_label = reason_labels.get(reason, reason.replace("_", " ")) + return CollectionIssue( + kind="artifact_reuse_skipped", + message=( + f"Skipped local {artifact_type} artifact reuse for {command}: " + f"{reason_label}: {detail}" + ), + scope=command, + context={ + "collector": command, + "artifact_type": artifact_type, + "artifact_path": str(artifact_path), + "reason": reason, + }, + ) + + +def _current_principal_id(provider: BaseProvider) -> str | None: + try: + whoami = provider.whoami() + except Exception: + return None + + principal = whoami.get("principal") + if not isinstance(principal, dict): + return None + + principal_id = principal.get("id") + if not principal_id: + return None + return str(principal_id) def _empty_chain_source_output( @@ -204,7 +500,7 @@ def _empty_chain_source_output( ): issue = _chain_source_issue(command, exc) context = provider.metadata_context() - model_class = _empty_chain_source_model(command) + model_class = chain_source_model(command) payload = { "metadata": CommandMetadata( command=command, @@ -214,61 +510,11 @@ def _empty_chain_source_output( token_source=context.get("token_source"), auth_mode=context.get("auth_mode"), ), - **_empty_chain_source_fields(command, issue, options), + **empty_chain_source_fields(command, issue, options), } return model_class.model_validate(payload) -def _empty_chain_source_fields( - command: str, - issue: dict[str, object], - options: GlobalOptions, -) -> dict[str, object]: - fields_by_command: dict[str, dict[str, object]] = { - "devops": {"pipelines": [], "issues": [issue]}, - "automation": {"automation_accounts": [], "issues": [issue]}, - "permissions": {"permissions": [], "issues": [issue]}, - "rbac": {"principals": [], "scopes": [], "role_assignments": [], "issues": [issue]}, - "role-trusts": {"mode": options.role_trusts_mode, "trusts": [], "issues": [issue]}, - "keyvault": {"key_vaults": [], "issues": [issue]}, - "arm-deployments": {"deployments": [], "issues": [issue]}, - "app-services": {"app_services": [], "issues": [issue]}, - "functions": {"function_apps": [], "issues": [issue]}, - "aks": {"aks_clusters": [], "issues": [issue]}, - "env-vars": {"env_vars": [], "issues": [issue]}, - "tokens-credentials": {"surfaces": [], "issues": [issue]}, - "databases": {"database_servers": [], "issues": [issue]}, - "storage": {"storage_assets": [], "issues": [issue]}, - } - try: - return fields_by_command[command] - except KeyError as exc: - raise ValueError(f"Missing empty grouped-source shape for '{command}'") from exc - - -def _empty_chain_source_model(command: str): - models_by_command = { - "devops": DevopsOutput, - "automation": AutomationOutput, - "permissions": PermissionsOutput, - "rbac": RbacOutput, - "role-trusts": RoleTrustsOutput, - "keyvault": KeyVaultOutput, - "arm-deployments": ArmDeploymentsOutput, - "app-services": AppServicesOutput, - "functions": FunctionsOutput, - "aks": AksOutput, - "env-vars": EnvVarsOutput, - "tokens-credentials": TokensCredentialsOutput, - "databases": DatabasesOutput, - "storage": StorageOutput, - } - try: - return models_by_command[command] - except KeyError as exc: - raise ValueError(f"Missing empty grouped-source model for '{command}'") from exc - - def _chain_source_issue(command: str, exc: Exception) -> dict[str, object]: return { "kind": str(getattr(exc, "kind", "unknown")), @@ -282,12 +528,12 @@ def _build_credential_path_output( provider: BaseProvider, options: GlobalOptions, family_name: str, - loaded: dict[str, object], + load_result: ChainFamilyLoadResult, ) -> ChainsCommandOutput: family = get_chain_family_spec(family_name) assert family is not None # pragma: no cover - guarded above - paths, issues = collect_credential_path_records(provider, family_name, loaded) + paths, issues = collect_credential_path_records(provider, family_name, load_result.outputs) paths.sort( key=lambda item: ( @@ -303,6 +549,7 @@ def _build_credential_path_output( options=options, family=family, family_name=family_name, + load_result=load_result, paths=paths, issues=issues, ) @@ -311,11 +558,13 @@ def _build_credential_path_output( def _build_deployment_path_output( options: GlobalOptions, family_name: str, - loaded: dict[str, object], + load_result: ChainFamilyLoadResult, ) -> ChainsCommandOutput: family = get_chain_family_spec(family_name) assert family is not None # pragma: no cover - guarded above + loaded = load_result.outputs + devops_output = loaded["devops"] automation_output = loaded["automation"] permissions_output = loaded["permissions"] @@ -523,6 +772,7 @@ def _build_deployment_path_output( options=options, family=family, family_name=family_name, + load_result=load_result, paths=paths, issues=issues, ) @@ -531,11 +781,23 @@ def _build_deployment_path_output( def _build_escalation_path_output( options: GlobalOptions, family_name: str, - loaded: dict[str, object], + load_result: ChainFamilyLoadResult | dict[str, object], ) -> ChainsCommandOutput: family = get_chain_family_spec(family_name) assert family is not None # pragma: no cover - guarded above + if isinstance(load_result, ChainFamilyLoadResult): + loaded = load_result.outputs + else: + loaded = load_result + load_result = ChainFamilyLoadResult( + outputs=loaded, + source_states=[ + ChainSourceLoadState(command=command, output=output, mode="live") + for command, output in loaded.items() + ], + ) + permissions_output = loaded["permissions"] role_trusts_output = loaded["role-trusts"] @@ -583,6 +845,7 @@ def _build_escalation_path_output( options=options, family=family, family_name=family_name, + load_result=load_result, paths=paths, issues=issues, ) @@ -628,17 +891,18 @@ def _current_foothold_contexts_from_permissions(permissions: list[object]) -> li def _build_compute_control_output( options: GlobalOptions, family_name: str, - loaded: dict[str, object], + load_result: ChainFamilyLoadResult, ) -> ChainsCommandOutput: family = get_chain_family_spec(family_name) assert family is not None # pragma: no cover - guarded above - paths, issues = collect_compute_control_records(family_name, loaded) + paths, issues = collect_compute_control_records(family_name, load_result.outputs) return _build_chains_command_output( options=options, family=family, family_name=family_name, + load_result=load_result, paths=paths, issues=issues, ) @@ -649,9 +913,11 @@ def _build_chains_command_output( options: GlobalOptions, family: ChainFamilySpec, family_name: str, + load_result: ChainFamilyLoadResult, paths: list[ChainPathRecord], issues: list[CollectionIssue], ) -> ChainsCommandOutput: + all_issues = [*load_result.reuse_issues, *issues] return ChainsCommandOutput( metadata=CommandMetadata( command=GROUPED_COMMAND_NAME, @@ -662,19 +928,29 @@ def _build_chains_command_output( ), grouped_command_name=GROUPED_COMMAND_NAME, family=family_name, - input_mode="live", + input_mode=_chain_input_mode(load_result), command_state="extraction-only", summary=family.summary, claim_boundary=family.allowed_claim, current_gap=family.current_gap, - artifact_preference_order=[], + artifact_preference_order=[] if options.live_only else list(PREFERRED_ARTIFACT_ORDER), backing_commands=[source.command for source in family.source_commands], - source_artifacts=[], + reused_sources=list(load_result.reused_sources), + live_sources=list(load_result.live_sources), + source_artifacts=list(load_result.source_artifacts), paths=paths, - issues=issues, + issues=all_issues, ) +def _chain_input_mode(load_result: ChainFamilyLoadResult) -> str: + if load_result.reused_sources and load_result.live_sources: + return "mixed" + if load_result.reused_sources: + return "artifacts" + return "live" + + def _build_deployment_source_record( family_name: str, *, diff --git a/src/azurefox/cli.py b/src/azurefox/cli.py index e47b5cc..59c30e0 100644 --- a/src/azurefox/cli.py +++ b/src/azurefox/cli.py @@ -277,11 +277,18 @@ def vmss(ctx: typer.Context) -> None: @app.command("chains") def chains( ctx: typer.Context, + live_only: bool = typer.Option( + False, + "--live-only", + help="Collect grouped chain sources live only and skip local artifact reuse.", + ), family: str | None = typer.Argument( None, help="Chain family name, or 'help' to list the available chain families." ), ) -> None: options: GlobalOptions = ctx.obj + if live_only: + options = replace(options, live_only=True) if family in {None, "help"}: try: if options.output != OutputMode.JSON: diff --git a/src/azurefox/clients/graph.py b/src/azurefox/clients/graph.py index 5cebf3d..7f70b09 100644 --- a/src/azurefox/clients/graph.py +++ b/src/azurefox/clients/graph.py @@ -142,54 +142,48 @@ def batch_list_objects_by_key( self, requests: list[GraphBatchRequest], ) -> tuple[dict[str, list[dict[str, Any]]], dict[str, AzureFoxError]]: - pending = list(requests) partial: dict[str, list[dict[str, Any]]] = {} - results: dict[str, list[dict[str, Any]]] = {} - errors: dict[str, AzureFoxError] = {} - - while pending: - chunk = pending[:GRAPH_BATCH_MAX_REQUESTS] - pending = pending[GRAPH_BATCH_MAX_REQUESTS:] - - bodies, body_errors = self._batch_execute(chunk) - for request in chunk: - if request.key in body_errors: - errors[request.key] = body_errors[request.key] - partial.pop(request.key, None) - continue - - body = bodies.get(request.key) - if body is None: - errors[request.key] = AzureFoxError( - ErrorKind.UNKNOWN, - f"Graph batch request missing response for {request.path}", - ) - partial.pop(request.key, None) - continue - values = body.get("value", []) - if isinstance(values, list): - partial.setdefault(request.key, []).extend( - item for item in values if isinstance(item, dict) - ) + def _consume_list_body( + request: GraphBatchRequest, + body: dict[str, Any], + ) -> tuple[bool, list[dict[str, Any]] | None]: + values = body.get("value", []) + if isinstance(values, list): + partial.setdefault(request.key, []).extend( + item for item in values if isinstance(item, dict) + ) - next_url = body.get("@odata.nextLink") - if isinstance(next_url, str) and next_url: - pending.append(GraphBatchRequest(key=request.key, path=next_url)) - continue + next_url = body.get("@odata.nextLink") + if isinstance(next_url, str) and next_url: + return False, None - results[request.key] = list(partial.get(request.key, [])) - partial.pop(request.key, None) + final_items = list(partial.get(request.key, [])) + partial.pop(request.key, None) + return True, final_items - return results, errors + return self._batch_collect_by_key(requests, _consume_list_body) def batch_get_objects_by_key( self, requests: list[GraphBatchRequest], ) -> tuple[dict[str, dict[str, Any]], dict[str, AzureFoxError]]: - results: dict[str, dict[str, Any]] = {} - errors: dict[str, AzureFoxError] = {} + def _consume_get_body( + _request: GraphBatchRequest, + body: dict[str, Any], + ) -> tuple[bool, dict[str, Any] | None]: + return True, body + + return self._batch_collect_by_key(requests, _consume_get_body) + + def _batch_collect_by_key( + self, + requests: list[GraphBatchRequest], + consume_body, + ) -> tuple[dict[str, Any], dict[str, AzureFoxError]]: pending = list(requests) + results: dict[str, Any] = {} + errors: dict[str, AzureFoxError] = {} while pending: chunk = pending[:GRAPH_BATCH_MAX_REQUESTS] @@ -209,7 +203,16 @@ def batch_get_objects_by_key( ) continue - results[request.key] = body + completed, final_value = consume_body(request, body) + if not completed: + pending.append( + GraphBatchRequest( + key=request.key, + path=str(body["@odata.nextLink"]), + ) + ) + continue + results[request.key] = final_value return results, errors @@ -369,13 +372,24 @@ def _graph_batch_request_error( if message: pieces.append(message) formatted = " ".join(pieces) + kind = _graph_batch_error_kind(status=status, code=code, message=message) return AzureFoxError( - classify_exception(Exception(formatted)), + kind, formatted, details={"body": json.dumps(body)[:500]}, ) +def _graph_batch_error_kind(*, status: int, code: str, message: str) -> ErrorKind: + if status == 401: + return ErrorKind.AUTH_FAILURE + if status == 403: + return ErrorKind.PERMISSION_DENIED + if status == 429: + return ErrorKind.THROTTLING + return classify_exception(Exception(" ".join(part for part in (code, message) if part))) + + def _graph_ssl_context() -> ssl.SSLContext: if certifi is not None: return ssl.create_default_context(cafile=certifi.where()) diff --git a/src/azurefox/collectors/provider.py b/src/azurefox/collectors/provider.py index bc7a1d9..277b6ca 100644 --- a/src/azurefox/collectors/provider.py +++ b/src/azurefox/collectors/provider.py @@ -1858,102 +1858,32 @@ def principals(self) -> dict: rbac_data = self.rbac() whoami_data = self.whoami() identity_data = self.managed_identities() - - records: dict[str, dict] = {} - issues = [ - *rbac_data.get("issues", []), - *whoami_data.get("issues", []), - *identity_data.get("issues", []), - ] - - def ensure_record(principal_id: str) -> dict: - if principal_id not in records: - records[principal_id] = { - "id": principal_id, - "principal_type": "unknown", - "display_name": None, - "tenant_id": None, - "sources": [], - "scope_ids": [], - "assignment_scope_ids": [], - "role_names": [], - "role_assignment_count": 0, - "identity_names": [], - "identity_types": [], - "attached_to": [], - "is_current_identity": False, - } - return records[principal_id] - - for principal in rbac_data.get("principals", []): - principal_id = principal.get("id") - if not principal_id: - continue - record = ensure_record(principal_id) - _merge_principal_attributes(record, principal) - _append_unique(record["sources"], "rbac") - - for assignment in rbac_data.get("role_assignments", []): - principal_id = assignment.get("principal_id") - if not principal_id: - continue - record = ensure_record(principal_id) - role_name = assignment.get("role_name") - scope_id = assignment.get("scope_id") - if role_name: - _append_unique(record["role_names"], role_name) - if scope_id: - _append_unique(record["scope_ids"], scope_id) - _append_unique(record["assignment_scope_ids"], scope_id) - record["role_assignment_count"] += 1 - principal_type = assignment.get("principal_type") - if principal_type: - record["principal_type"] = _normalize_principal_type( - record["principal_type"], - principal_type, - ) - _append_unique(record["sources"], "rbac") - - principal = whoami_data.get("principal") - if principal and principal.get("id"): - record = ensure_record(principal["id"]) - _merge_principal_attributes(record, principal) - record["is_current_identity"] = True - for scope in whoami_data.get("effective_scopes", []): - scope_id = scope.get("id") - if scope_id: - _append_unique(record["scope_ids"], scope_id) - _append_unique(record["sources"], "whoami") - - for identity in identity_data.get("identities", []): - principal_id = identity.get("principal_id") - if not principal_id: - continue - record = ensure_record(principal_id) - if record["principal_type"] == "unknown": - record["principal_type"] = "ServicePrincipal" - _append_unique(record["identity_names"], identity.get("name")) - _append_unique(record["identity_types"], identity.get("identity_type")) - for scope_id in identity.get("scope_ids", []): - _append_unique(record["scope_ids"], scope_id) - for attachment in identity.get("attached_to", []): - _append_unique(record["attached_to"], attachment) - _append_unique(record["sources"], "managed-identities") - - principals = sorted( - records.values(), - key=_principal_sort_key, + principals, issues, _assignment_scope_ids_by_principal = _principal_records_from_sources( + rbac_data=rbac_data, + whoami_data=whoami_data, + identity_data=identity_data, ) return {"principals": principals, "issues": issues} def permissions(self) -> dict: - principal_data = self.principals() + rbac_data = self.rbac() + whoami_data = self.whoami() + identity_data = self.managed_identities() + principals, issues, assignment_scope_ids_by_principal = _principal_records_from_sources( + rbac_data=rbac_data, + whoami_data=whoami_data, + identity_data=identity_data, + ) permission_rows: list[dict] = [] - for principal in principal_data.get("principals", []): + for principal in principals: role_names = sorted(set(principal.get("role_names", []))) + principal_id = str(principal.get("id") or "") scope_ids = sorted( - set(principal.get("assignment_scope_ids") or principal.get("scope_ids", [])) + set( + assignment_scope_ids_by_principal.get(principal_id) + or principal.get("scope_ids", []) + ) ) high_impact_roles = sorted( { @@ -1986,7 +1916,7 @@ def permissions(self) -> dict: item["principal_id"] or "", ) ) - return {"permissions": permission_rows, "issues": principal_data.get("issues", [])} + return {"permissions": permission_rows, "issues": issues} def privesc(self) -> dict: permissions_data = self.permissions() @@ -4458,6 +4388,98 @@ def _append_unique(items: list[str], value: str | None) -> None: items.append(value) +def _principal_records_from_sources( + *, + rbac_data: dict, + whoami_data: dict, + identity_data: dict, +) -> tuple[list[dict], list[dict], dict[str, list[str]]]: + records: dict[str, dict] = {} + assignment_scope_ids_by_principal: dict[str, list[str]] = {} + issues = [ + *rbac_data.get("issues", []), + *whoami_data.get("issues", []), + *identity_data.get("issues", []), + ] + + def ensure_record(principal_id: str) -> dict: + if principal_id not in records: + records[principal_id] = { + "id": principal_id, + "principal_type": "unknown", + "display_name": None, + "tenant_id": None, + "sources": [], + "scope_ids": [], + "role_names": [], + "role_assignment_count": 0, + "identity_names": [], + "identity_types": [], + "attached_to": [], + "is_current_identity": False, + } + return records[principal_id] + + for principal in rbac_data.get("principals", []): + principal_id = principal.get("id") + if not principal_id: + continue + record = ensure_record(principal_id) + _merge_principal_attributes(record, principal) + _append_unique(record["sources"], "rbac") + + for assignment in rbac_data.get("role_assignments", []): + principal_id = assignment.get("principal_id") + if not principal_id: + continue + record = ensure_record(principal_id) + role_name = assignment.get("role_name") + scope_id = assignment.get("scope_id") + if role_name: + _append_unique(record["role_names"], role_name) + if scope_id: + _append_unique(record["scope_ids"], scope_id) + assignment_scope_ids_by_principal.setdefault(principal_id, []) + _append_unique(assignment_scope_ids_by_principal[principal_id], scope_id) + record["role_assignment_count"] += 1 + principal_type = assignment.get("principal_type") + if principal_type: + record["principal_type"] = _normalize_principal_type( + record["principal_type"], + principal_type, + ) + _append_unique(record["sources"], "rbac") + + principal = whoami_data.get("principal") + if principal and principal.get("id"): + record = ensure_record(principal["id"]) + _merge_principal_attributes(record, principal) + record["is_current_identity"] = True + for scope in whoami_data.get("effective_scopes", []): + scope_id = scope.get("id") + if scope_id: + _append_unique(record["scope_ids"], scope_id) + _append_unique(record["sources"], "whoami") + + for identity in identity_data.get("identities", []): + principal_id = identity.get("principal_id") + if not principal_id: + continue + record = ensure_record(principal_id) + if record["principal_type"] == "unknown": + record["principal_type"] = "ServicePrincipal" + _append_unique(record["identity_names"], identity.get("name")) + _append_unique(record["identity_types"], identity.get("identity_type")) + for scope_id in identity.get("scope_ids", []): + _append_unique(record["scope_ids"], scope_id) + for attachment in identity.get("attached_to", []): + _append_unique(record["attached_to"], attachment) + _append_unique(record["sources"], "managed-identities") + + principals = sorted(records.values(), key=_principal_sort_key) + return principals, issues, assignment_scope_ids_by_principal + + def _normalize_principal_type(existing: str | None, candidate: str | None) -> str: normalized_existing = existing or "unknown" if not candidate: @@ -4509,21 +4531,12 @@ def _graph_batch_list_with_fallback( requests: list[GraphBatchRequest], serial_fetch, ) -> tuple[dict[str, list[dict]], dict[str, Exception]]: - if not requests: - return {}, {} - - batch_fetch = getattr(graph, "batch_list_objects_by_key", None) - if callable(batch_fetch): - return batch_fetch(requests) - - results: dict[str, list[dict]] = {} - errors: dict[str, Exception] = {} - for request in requests: - try: - results[request.key] = serial_fetch(request) - except Exception as exc: - errors[request.key] = exc - return results, errors + return _graph_batch_with_fallback( + graph=graph, + requests=requests, + batch_method_name="batch_list_objects_by_key", + serial_fetch=serial_fetch, + ) def _graph_batch_get_with_fallback( @@ -4531,14 +4544,29 @@ def _graph_batch_get_with_fallback( requests: list[GraphBatchRequest], serial_fetch, ) -> tuple[dict[str, dict], dict[str, Exception]]: + return _graph_batch_with_fallback( + graph=graph, + requests=requests, + batch_method_name="batch_get_objects_by_key", + serial_fetch=serial_fetch, + ) + + +def _graph_batch_with_fallback( + *, + graph: object, + requests: list[GraphBatchRequest], + batch_method_name: str, + serial_fetch, +) -> tuple[dict, dict[str, Exception]]: if not requests: return {}, {} - batch_fetch = getattr(graph, "batch_get_objects_by_key", None) + batch_fetch = getattr(graph, batch_method_name, None) if callable(batch_fetch): return batch_fetch(requests) - results: dict[str, dict] = {} + results: dict = {} errors: dict[str, Exception] = {} for request in requests: try: diff --git a/src/azurefox/config.py b/src/azurefox/config.py index 484a37a..f8fb282 100644 --- a/src/azurefox/config.py +++ b/src/azurefox/config.py @@ -13,6 +13,7 @@ class GlobalOptions: output: OutputMode outdir: Path debug: bool + live_only: bool = False devops_organization: str | None = None role_trusts_mode: RoleTrustsMode = RoleTrustsMode.FAST diff --git a/src/azurefox/models/chains.py b/src/azurefox/models/chains.py index c88eb46..5c9a5ce 100644 --- a/src/azurefox/models/chains.py +++ b/src/azurefox/models/chains.py @@ -91,6 +91,8 @@ class ChainsOutput(BaseModel): current_gap: str | None = None artifact_preference_order: list[str] = Field(default_factory=list) backing_commands: list[str] = Field(default_factory=list) + reused_sources: list[str] = Field(default_factory=list) + live_sources: list[str] = Field(default_factory=list) source_artifacts: list[ChainSourceArtifact] = Field(default_factory=list) paths: list[ChainPathRecord] = Field(default_factory=list) issues: list[CollectionIssue] = Field(default_factory=list) diff --git a/tests/test_chain_scaffold.py b/tests/test_chain_scaffold.py index 94b33bf..06c1c34 100644 --- a/tests/test_chain_scaffold.py +++ b/tests/test_chain_scaffold.py @@ -53,8 +53,8 @@ def test_chain_registry_uses_expected_grouped_command_shape() -> None: assert GROUPED_COMMAND_NAME == "chains" - assert GROUPED_COMMAND_INPUT_MODES == ("live", "artifacts") - assert PREFERRED_ARTIFACT_ORDER == ("loot", "json") + assert GROUPED_COMMAND_INPUT_MODES == ("live", "artifacts", "mixed") + assert PREFERRED_ARTIFACT_ORDER == ("json",) def test_chain_registry_keeps_first_family_order() -> None: diff --git a/tests/test_cli_smoke.py b/tests/test_cli_smoke.py index ecbef66..ea2ca6c 100644 --- a/tests/test_cli_smoke.py +++ b/tests/test_cli_smoke.py @@ -31,6 +31,43 @@ def _write_fixture_json(path: Path, payload: dict) -> None: path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8") +def _seed_command_json_artifacts(tmp_path: Path, fixture_dir: Path, commands: list[str]) -> None: + for command in commands: + argv = ["--outdir", str(tmp_path), "--output", "json", command] + if command == "role-trusts": + argv.append("--mode") + argv.append("fast") + result = runner.invoke( + app, + argv, + env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)}, + ) + assert result.exit_code == 0, result.stdout + + +def _patch_failing_collectors( + monkeypatch: pytest.MonkeyPatch, + failing_commands: set[str], +) -> None: + def _collector_for(command: str): + def _failing_collector(_provider, _options): + raise AzureFoxError( + kind=ErrorKind.PERMISSION_DENIED, + message="collector should not have run", + command=command, + ) + + return _failing_collector + + patched_specs = tuple( + CommandSpec(spec.name, spec.section, _collector_for(spec.name)) + if spec.name in failing_commands + else spec + for spec in get_command_specs() + ) + monkeypatch.setattr("azurefox.chains.runner.get_command_specs", lambda: patched_specs) + + def _contributor_app_permission_trust() -> dict: return { "trust_type": "app-to-service-principal", @@ -350,7 +387,15 @@ def test_cli_smoke_chains_credential_path_json(tmp_path: Path) -> None: assert payload["command_state"] == "extraction-only" assert payload["claim_boundary"].startswith("Can claim that the visible evidence suggests") assert payload["current_gap"].startswith("The live family now joins backing evidence") - assert payload["artifact_preference_order"] == [] + assert payload["artifact_preference_order"] == ["json"] + assert payload["reused_sources"] == [] + assert payload["live_sources"] == [ + "env-vars", + "tokens-credentials", + "databases", + "storage", + "keyvault", + ] assert payload["source_artifacts"] == [] assert payload["backing_commands"] == [ "env-vars", @@ -387,6 +432,147 @@ def test_cli_smoke_chains_credential_path_json(tmp_path: Path) -> None: } == {"narrowed candidates"} +def test_cli_smoke_chains_reuses_matching_source_artifacts_by_default( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant" + commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"] + _seed_command_json_artifacts(tmp_path, fixture_dir, commands) + _patch_failing_collectors(monkeypatch, set(commands)) + + result = runner.invoke( + app, + ["--outdir", str(tmp_path), "--output", "json", "chains", "credential-path"], + env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)}, + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["family"] == "credential-path" + assert payload["input_mode"] == "artifacts" + assert payload["reused_sources"] == commands + assert payload["live_sources"] == [] + assert [item["command"] for item in payload["source_artifacts"]] == commands + assert len(payload["paths"]) == 3 + assert payload["issues"] == [] + + +def test_cli_smoke_chains_live_only_bypasses_matching_source_artifacts( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant" + commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"] + _seed_command_json_artifacts(tmp_path, fixture_dir, commands) + _patch_failing_collectors(monkeypatch, set(commands)) + + result = runner.invoke( + app, + [ + "--outdir", + str(tmp_path), + "--output", + "json", + "chains", + "--live-only", + "credential-path", + ], + env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)}, + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["family"] == "credential-path" + assert payload["input_mode"] == "live" + assert payload["reused_sources"] == [] + assert payload["source_artifacts"] == [] + assert payload["paths"] == [] + assert {item["scope"] for item in payload["issues"]} == set(commands) + + +def test_cli_smoke_chains_falls_back_live_when_source_artifact_mismatches( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant" + commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"] + _seed_command_json_artifacts(tmp_path, fixture_dir, commands) + + storage_path = tmp_path / "json" / "storage.json" + storage_payload = _read_fixture_json(storage_path) + storage_payload["metadata"]["schema_version"] = "0.0.0" + _write_fixture_json(storage_path, storage_payload) + + _patch_failing_collectors(monkeypatch, {"storage"}) + + result = runner.invoke( + app, + [ + "--outdir", + str(tmp_path), + "--output", + "json", + "chains", + "credential-path", + ], + env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)}, + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["family"] == "credential-path" + assert payload["input_mode"] == "mixed" + assert payload["reused_sources"] == [ + "env-vars", + "tokens-credentials", + "databases", + "keyvault", + ] + assert payload["live_sources"] == ["storage"] + assert [item["command"] for item in payload["source_artifacts"]] == [ + "env-vars", + "tokens-credentials", + "databases", + "keyvault", + ] + skipped_issue = next( + item + for item in payload["issues"] + if item["scope"] == "storage" and item["kind"] == "artifact_reuse_skipped" + ) + assert "schema_version mismatch" in skipped_issue["message"] + assert any(issue["scope"] == "storage" for issue in payload["issues"]) + + +def test_cli_smoke_chains_reports_invalid_artifact_reuse_rejection_reason(tmp_path: Path) -> None: + fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant" + commands = ["env-vars", "tokens-credentials", "databases", "storage", "keyvault"] + _seed_command_json_artifacts(tmp_path, fixture_dir, commands) + + storage_path = tmp_path / "json" / "storage.json" + storage_path.write_text("{not valid json", encoding="utf-8") + + result = runner.invoke( + app, + ["--outdir", str(tmp_path), "--output", "json", "chains", "credential-path"], + env={"AZUREFOX_FIXTURE_DIR": str(fixture_dir)}, + ) + + assert result.exit_code == 0 + payload = json.loads(result.stdout) + assert payload["family"] == "credential-path" + assert payload["input_mode"] == "mixed" + skipped_issue = next( + item + for item in payload["issues"] + if item["scope"] == "storage" and item["kind"] == "artifact_reuse_skipped" + ) + assert "invalid json artifact" in skipped_issue["message"].lower() + assert skipped_issue["context"]["artifact_type"] == "json" + assert skipped_issue["context"]["reason"] == "invalid_json" + + def test_cli_smoke_chains_credential_path_table_output(tmp_path: Path) -> None: fixture_dir = Path(__file__).resolve().parent / "fixtures" / "lab_tenant" @@ -470,7 +656,20 @@ def test_cli_smoke_chains_deployment_path_json(tmp_path: Path) -> None: assert payload["metadata"]["command"] == "chains" assert payload["family"] == "deployment-path" assert payload["command_state"] == "extraction-only" - assert payload["artifact_preference_order"] == [] + assert payload["artifact_preference_order"] == ["json"] + assert payload["reused_sources"] == [] + assert payload["live_sources"] == [ + "devops", + "automation", + "permissions", + "rbac", + "role-trusts", + "keyvault", + "arm-deployments", + "aks", + "functions", + "app-services", + ] assert payload["source_artifacts"] == [] assert payload["backing_commands"] == [ "devops", @@ -608,7 +807,15 @@ def test_cli_smoke_chains_compute_control_json(tmp_path: Path) -> None: assert payload["metadata"]["command"] == "chains" assert payload["family"] == "compute-control" assert payload["command_state"] == "extraction-only" - assert payload["artifact_preference_order"] == [] + assert payload["artifact_preference_order"] == ["json"] + assert payload["reused_sources"] == [] + assert payload["live_sources"] == [ + "tokens-credentials", + "env-vars", + "workloads", + "managed-identities", + "permissions", + ] assert payload["source_artifacts"] == [] assert payload["backing_commands"] == [ "tokens-credentials", @@ -759,7 +966,9 @@ def test_cli_smoke_chains_escalation_path_json(tmp_path: Path) -> None: assert payload["metadata"]["command"] == "chains" assert payload["family"] == "escalation-path" assert payload["command_state"] == "extraction-only" - assert payload["artifact_preference_order"] == [] + assert payload["artifact_preference_order"] == ["json"] + assert payload["reused_sources"] == [] + assert payload["live_sources"] == ["permissions", "role-trusts"] assert payload["source_artifacts"] == [] assert payload["backing_commands"] == [ "permissions", diff --git a/tests/test_collectors.py b/tests/test_collectors.py index 5d5eaf4..5c8e228 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -3881,9 +3881,11 @@ def test_permissions_current_identity_uses_assignment_scopes_not_whoami_scope() } provider.managed_identities = lambda: {"identities": [], "issues": []} + principals = AzureProvider.principals(provider) permissions = AzureProvider.permissions(provider) current_row = permissions["permissions"][0] + assert "assignment_scope_ids" not in principals["principals"][0] assert current_row["is_current_identity"] is True assert current_row["role_assignment_count"] == 1 assert current_row["scope_count"] == 1 diff --git a/tests/test_graph_client.py b/tests/test_graph_client.py new file mode 100644 index 0000000..26a5efc --- /dev/null +++ b/tests/test_graph_client.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from azurefox.clients.graph import GraphBatchRequest, GraphClient +from azurefox.errors import ErrorKind + + +def test_graph_batch_execute_preserves_status_aware_error_kinds(monkeypatch) -> None: + client = GraphClient(SimpleNamespace()) + + def _fake_post(self, _url: str, _payload: dict) -> dict: + return { + "responses": [ + { + "id": "0", + "status": 401, + "body": { + "error": { + "code": "InvalidAuthenticationToken", + "message": "Access token is invalid", + } + }, + }, + { + "id": "1", + "status": 403, + "body": { + "error": { + "code": "Authorization_RequestDenied", + "message": "Forbidden", + } + }, + }, + { + "id": "2", + "status": 429, + "body": { + "error": { + "code": "TooManyRequests", + "message": "Rate limit exceeded", + } + }, + }, + ] + } + + monkeypatch.setattr(GraphClient, "_post", _fake_post) + + _results, errors = client.batch_get_objects_by_key( + [ + GraphBatchRequest(key="auth", path="/applications/app-a"), + GraphBatchRequest(key="forbidden", path="/servicePrincipals/sp-a"), + GraphBatchRequest(key="throttle", path="/servicePrincipals/sp-b"), + ] + ) + + assert errors["auth"].kind == ErrorKind.AUTH_FAILURE + assert errors["forbidden"].kind == ErrorKind.PERMISSION_DENIED + assert errors["throttle"].kind == ErrorKind.THROTTLING