diff --git a/hackagent/attacks/orchestrator.py b/hackagent/attacks/orchestrator.py index 1c87d16..6495064 100644 --- a/hackagent/attacks/orchestrator.py +++ b/hackagent/attacks/orchestrator.py @@ -24,6 +24,7 @@ import json import logging +import copy import re import shutil import subprocess @@ -55,6 +56,14 @@ logger.addHandler(logging.NullHandler()) logger.propagate = False +_REMOTE_ROLE_ENDPOINT = "https://api.hackagent.dev/v1" +_REMOTE_ATTACKER_IDENTIFIER = "hackagent-attacker" +_REMOTE_JUDGE_IDENTIFIER = "hackagent-judge" + +_LOCAL_ROLE_ENDPOINT = "http://localhost:11434" +_LOCAL_ROLE_IDENTIFIER = "gemma3:4b" +_LOCAL_ROLE_AGENT_TYPE = "OLLAMA" + class _BatchContextFilter(logging.Filter): """ @@ -104,62 +113,63 @@ class AdvPrefix(AttackOrchestrator): attack_impl_class: type = None # Must be overridden by subclass # Model-role extraction map used by pre-run availability preflight. - # Tuple format: (role_name, path_tuple, is_list) + # Tuple format: (role_name, path_tuple, is_list, role_family_for_defaults) + # role_family_for_defaults drives remote/local auto-default injection. _ATTACK_MODEL_ROLE_PATHS: Dict[ - str, Tuple[Tuple[str, Tuple[str, ...], bool], ...] + str, Tuple[Tuple[str, Tuple[str, ...], bool, Optional[str]], ...] ] = { "advprefix": ( - ("generator", ("generator",), False), - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("generator", ("generator",), False, "attacker"), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "baseline": ( - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "flipattack": ( - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "tap": ( - ("attacker", ("attacker",), False), - ("judge", ("judge",), False), - ("judge", ("judges",), True), - ("on_topic_judge", ("on_topic_judge",), False), + ("attacker", ("attacker",), False, "attacker"), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), + ("on_topic_judge", ("on_topic_judge",), False, None), ), "pair": ( - ("attacker", ("attacker",), False), - ("scorer", ("scorer",), False), + ("attacker", ("attacker",), False, "attacker"), + ("scorer", ("scorer",), False, "judge"), ), "autodan_turbo": ( - ("attacker", ("attacker",), False), - ("scorer", ("scorer",), False), - ("summarizer", ("summarizer",), False), - ("embedder", ("embedder",), False), + ("attacker", ("attacker",), False, "attacker"), + ("scorer", ("scorer",), False, "judge"), + ("summarizer", ("summarizer",), False, "attacker"), + ("embedder", ("embedder",), False, None), ), "bon": ( - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "cipherchat": ( - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "h4rm3l": ( - ("judge", ("judge",), False), - ("judge", ("judges",), True), - ("decorator_llm", ("decorator_llm",), False), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), + ("decorator_llm", ("decorator_llm",), False, "attacker"), ), "pap": ( - ("attacker", ("attacker",), False), - ("judge", ("judge",), False), - ("judge", ("judges",), True), + ("attacker", ("attacker",), False, "attacker"), + ("judge", ("judge",), False, "judge"), + ("judge", ("judges",), True, "judge"), ), "indirect_prompt_injection": ( - ("attacker", ("attacker",), False), - ("judge", ("judge",), False), - ("judge", ("judges",), True), - ("embedder", ("rag_injection_params", "embedder"), False), + ("attacker", ("attacker",), False, None), + ("judge", ("judge",), False, None), + ("judge", ("judges",), True, None), + ("embedder", ("rag_injection_params", "embedder"), False, None), ), } @@ -213,6 +223,128 @@ def _create_server_attack_record( ) raise HackAgentError(f"Failed to create Attack record: {e}") from e + def _backend_api_key_for_role_defaults(self) -> Optional[str]: + """Return backend API key only when it is a non-empty string.""" + backend = getattr(self.hackagent_agent, "backend", None) + if backend is None: + return None + + getter = getattr(backend, "get_api_key", None) + if not callable(getter): + return None + + try: + api_key = getter() + except Exception: + return None + + if isinstance(api_key, str) and api_key.strip(): + return api_key + return None + + @staticmethod + def _remote_role_defaults(api_key: str) -> Dict[str, Dict[str, Any]]: + """Build remote role defaults with backend-key fallback semantics.""" + return { + "attacker": { + "identifier": _REMOTE_ATTACKER_IDENTIFIER, + "endpoint": _REMOTE_ROLE_ENDPOINT, + "agent_type": "OPENAI_SDK", + "api_key": api_key, + }, + "judge": { + "identifier": _REMOTE_JUDGE_IDENTIFIER, + "endpoint": _REMOTE_ROLE_ENDPOINT, + "agent_type": "OPENAI_SDK", + "type": "harmbench_variant", + "api_key": api_key, + }, + } + + @staticmethod + def _local_role_defaults() -> Dict[str, Dict[str, Any]]: + """Build local role defaults from the attack_roles profile.""" + return { + "attacker": { + "identifier": _LOCAL_ROLE_IDENTIFIER, + "endpoint": _LOCAL_ROLE_ENDPOINT, + "agent_type": _LOCAL_ROLE_AGENT_TYPE, + "api_key": None, + }, + "judge": { + "identifier": _LOCAL_ROLE_IDENTIFIER, + "endpoint": _LOCAL_ROLE_ENDPOINT, + "agent_type": _LOCAL_ROLE_AGENT_TYPE, + # Keep evaluator compatibility when defaults are auto-injected. + "type": "harmbench", + "api_key": None, + }, + } + + @staticmethod + def _merge_missing_keys(target: Dict[str, Any], defaults: Dict[str, Any]) -> None: + """Fill only missing keys from defaults, preserving explicit overrides.""" + for key, value in defaults.items(): + if key not in target: + target[key] = value + + @classmethod + def _role_defaults_mapping_for_attack(cls, attack_type: str) -> Dict[str, str]: + """Build role->family map from _ATTACK_MODEL_ROLE_PATHS metadata.""" + role_specs = cls._ATTACK_MODEL_ROLE_PATHS.get(attack_type) or () + mapping: Dict[str, str] = {} + for role_name, _, _, role_family in role_specs: + if role_family in {"attacker", "judge"} and role_name not in mapping: + mapping[role_name] = role_family + return mapping + + def _apply_mode_based_role_defaults( + self, attack_config: Dict[str, Any] + ) -> Dict[str, Any]: + """Apply local/remote role defaults before preflight and execution.""" + api_key = self._backend_api_key_for_role_defaults() + is_remote_mode = bool(api_key) + + selected_attack_type = ( + attack_config.get("attack_type") + if isinstance(attack_config, dict) + else None + ) or self.attack_type + normalized_attack_type = self._normalize_attack_type_for_preflight( + selected_attack_type + ) + + role_mapping = self._role_defaults_mapping_for_attack(normalized_attack_type) + if not role_mapping: + return attack_config + + resolved = copy.deepcopy(attack_config) + defaults_by_family = ( + self._remote_role_defaults(api_key) + if is_remote_mode + else self._local_role_defaults() + ) + + for role_name, role_family in role_mapping.items(): + role_defaults = dict(defaults_by_family[role_family]) + role_cfg = resolved.get(role_name) + if isinstance(role_cfg, dict): + self._merge_missing_keys(role_cfg, role_defaults) + else: + resolved[role_name] = role_defaults + + # Judge-based attacks usually consume list-style judge configs. + if role_name == "judge": + judges_cfg = resolved.get("judges") + if isinstance(judges_cfg, list) and judges_cfg: + for item in judges_cfg: + if isinstance(item, dict): + self._merge_missing_keys(item, role_defaults) + else: + resolved["judges"] = [dict(role_defaults)] + + return resolved + def _create_server_run_record( self, attack_id: str, @@ -729,7 +861,7 @@ def _register_target( role_specs = self._ATTACK_MODEL_ROLE_PATHS.get(attack_type) if role_specs: - for role, path, is_list in role_specs: + for role, path, is_list, _ in role_specs: value = self._get_nested_config_value(attack_config, path) if value is None: continue @@ -1246,6 +1378,8 @@ def execute( ValueError: If configuration is invalid HackAgentError: If server record creation fails """ + attack_config = self._apply_mode_based_role_defaults(attack_config) + # 1. Validate parameters attack_params = self._prepare_attack_params(attack_config) goal_labels_by_index = attack_params.pop("_goal_labels_by_index", None) diff --git a/tests/unit/attacks/test_orchestrator_extended.py b/tests/unit/attacks/test_orchestrator_extended.py index 1ac4ed2..8068255 100644 --- a/tests/unit/attacks/test_orchestrator_extended.py +++ b/tests/unit/attacks/test_orchestrator_extended.py @@ -196,6 +196,95 @@ def test_execute_continues_when_status_update_fails( self.assertIsNotNone(results) +class TestModeBasedRoleDefaults(unittest.TestCase): + """Test remote/local role defaults injected before attack execution.""" + + def test_remote_mode_injects_baseline_judge_defaults(self): + """Baseline judge defaults should switch to remote profile in remote mode.""" + orch, hack_agent, _ = _make_orchestrator() + orch.attack_type = "baseline" + hack_agent.backend.get_api_key.return_value = "hk_test_remote_key" + + resolved = orch._apply_mode_based_role_defaults( + {"attack_type": "baseline", "goals": ["test"]} + ) + + self.assertEqual(resolved["judge"]["identifier"], "hackagent-judge") + self.assertEqual(resolved["judge"]["endpoint"], "https://api.hackagent.dev/v1") + self.assertEqual(resolved["judge"]["agent_type"], "OPENAI_SDK") + self.assertEqual(resolved["judge"]["type"], "harmbench_variant") + self.assertEqual(resolved["judge"]["api_key"], "hk_test_remote_key") + self.assertEqual(resolved["judges"][0]["identifier"], "hackagent-judge") + self.assertEqual(resolved["judges"][0]["type"], "harmbench_variant") + + def test_remote_mode_preserves_explicit_judge_overrides(self): + """Explicit judge fields must not be overwritten by remote defaults.""" + orch, hack_agent, _ = _make_orchestrator() + orch.attack_type = "baseline" + hack_agent.backend.get_api_key.return_value = "hk_test_remote_key" + + resolved = orch._apply_mode_based_role_defaults( + { + "attack_type": "baseline", + "goals": ["test"], + "judges": [ + { + "identifier": "custom-judge", + "endpoint": "https://custom.endpoint/v1", + "agent_type": "OPENAI_SDK", + "api_key": "custom-key", + } + ], + } + ) + + self.assertEqual(resolved["judges"][0]["identifier"], "custom-judge") + self.assertEqual( + resolved["judges"][0]["endpoint"], "https://custom.endpoint/v1" + ) + self.assertEqual(resolved["judges"][0]["api_key"], "custom-key") + + def test_pair_remote_mode_fills_missing_role_fields(self): + """Partial attacker config should receive remote defaults, scorer should be added.""" + orch, hack_agent, _ = _make_orchestrator() + orch.attack_type = "pair" + hack_agent.backend.get_api_key.return_value = "hk_test_remote_key" + + resolved = orch._apply_mode_based_role_defaults( + { + "attack_type": "pair", + "goals": ["test"], + "attacker": {"identifier": "my-attacker"}, + } + ) + + self.assertEqual(resolved["attacker"]["identifier"], "my-attacker") + self.assertEqual( + resolved["attacker"]["endpoint"], "https://api.hackagent.dev/v1" + ) + self.assertEqual(resolved["attacker"]["agent_type"], "OPENAI_SDK") + self.assertEqual(resolved["attacker"]["api_key"], "hk_test_remote_key") + + self.assertEqual(resolved["scorer"]["identifier"], "hackagent-judge") + self.assertEqual(resolved["scorer"]["api_key"], "hk_test_remote_key") + + def test_local_mode_injects_baseline_judge_defaults(self): + """Without backend API key, baseline judge defaults should use local profile.""" + orch, hack_agent, _ = _make_orchestrator() + orch.attack_type = "baseline" + hack_agent.backend.get_api_key.return_value = None + + attack_config = {"attack_type": "baseline", "goals": ["test"]} + resolved = orch._apply_mode_based_role_defaults(attack_config) + + self.assertEqual(resolved["judge"]["identifier"], "gemma3:4b") + self.assertEqual(resolved["judge"]["endpoint"], "http://localhost:11434") + self.assertEqual(resolved["judge"]["agent_type"], "OLLAMA") + self.assertEqual(resolved["judge"]["type"], "harmbench") + self.assertIsNone(resolved["judge"]["api_key"]) + self.assertEqual(resolved["judges"][0]["identifier"], "gemma3:4b") + + class TestDefaultCategoryClassifierPreflight(unittest.TestCase): """Test abort behavior when default category classifier dependencies are missing."""