diff --git a/python/packages/autogen-ext/src/autogen_ext/governance/README.md b/python/packages/autogen-ext/src/autogen_ext/governance/README.md new file mode 100644 index 000000000000..e5a59e5e4729 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/governance/README.md @@ -0,0 +1,123 @@ +# Agent-OS Governance Extension for AutoGen + +This extension provides kernel-level governance for AutoGen multi-agent conversations using [Agent-OS](https://github.com/imran-siddique/agent-os). + +## Features + +- **Policy Enforcement**: Define rules for agent behavior +- **Tool Filtering**: Control which tools agents can use +- **Content Filtering**: Block dangerous patterns (SQL injection, shell commands) +- **Rate Limiting**: Limit messages and tool calls +- **Audit Trail**: Full logging of all agent interactions + +## Installation + +```bash +pip install autogen-ext[governance] +# or +pip install agent-os-kernel +``` + +## Quick Start + +```python +from autogen_ext.governance import GovernedTeam, GovernancePolicy +from autogen_agentchat.agents import AssistantAgent +from autogen_ext.models.openai import OpenAIChatCompletionClient + +# Create policy +policy = GovernancePolicy( + max_tool_calls=10, + max_messages=50, + blocked_patterns=["DROP TABLE", "rm -rf", "DELETE FROM"], + blocked_tools=["shell_execute"], + require_human_approval=False, +) + +# Create agents +model = OpenAIChatCompletionClient(model="gpt-4o") +analyst = AssistantAgent("analyst", model_client=model) +reviewer = AssistantAgent("reviewer", model_client=model) + +# Create governed team +team = GovernedTeam( + agents=[analyst, reviewer], + policy=policy, +) + +# Run with governance +result = await team.run("Analyze Q4 sales data") + +# Get audit log +audit = team.get_audit_log() +print(f"Total events: {len(audit)}") +``` + +## Policy Options + +```python +GovernancePolicy( + # Limits + max_messages=100, # Max messages per session + max_tool_calls=50, # Max tool invocations + timeout_seconds=300, # Session timeout + + # Tool Control + allowed_tools=["code_executor", "web_search"], # Whitelist + blocked_tools=["shell_execute"], # Blacklist + + # Content Filtering + blocked_patterns=["DROP TABLE", "rm -rf"], + max_message_length=50000, + + # Approval + require_human_approval=False, + approval_tools=["database_write"], # Tools needing approval + + # Audit + log_all_messages=True, +) +``` + +## Handling Violations + +```python +def on_violation(error): + print(f"BLOCKED: {error.policy_name} - {error.description}") + # Send alert, log to SIEM, etc. + +team = GovernedTeam( + agents=[agent1, agent2], + policy=policy, + on_violation=on_violation, +) +``` + +## Integration with Agent-OS Kernel + +For full kernel-level governance with signals, checkpoints, and policy languages: + +```python +from agent_os import KernelSpace +from agent_os.policies import SQLPolicy, CostControlPolicy + +# Create kernel with policies +kernel = KernelSpace(policy=[ + SQLPolicy(allow=["SELECT"], deny=["DROP", "DELETE"]), + CostControlPolicy(max_cost_usd=100), +]) + +# Wrap AutoGen team in kernel +@kernel.register +async def run_team(task: str): + return await team.run(task) + +# Execute with full governance +result = await kernel.execute(run_team, "Analyze data") +``` + +## Links + +- [Agent-OS GitHub](https://github.com/imran-siddique/agent-os) +- [AutoGen Documentation](https://microsoft.github.io/autogen/) +- [Governance Best Practices](https://github.com/imran-siddique/agent-os/blob/main/docs/governance.md) diff --git a/python/packages/autogen-ext/src/autogen_ext/governance/__init__.py b/python/packages/autogen-ext/src/autogen_ext/governance/__init__.py new file mode 100644 index 000000000000..3964c1a0c63a --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/governance/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Agent-OS Governance Extension for AutoGen +========================================== + +Provides kernel-level governance for AutoGen multi-agent conversations. + +Features: +- Policy enforcement for agent messages +- Tool call filtering and limits +- Content pattern blocking +- Human approval workflows +- Full audit trail + +Example: + >>> from autogen_ext.governance import GovernedTeam, GovernancePolicy + >>> from autogen_agentchat.agents import AssistantAgent + >>> + >>> policy = GovernancePolicy( + ... max_tool_calls=10, + ... blocked_patterns=["DROP TABLE", "rm -rf"], + ... require_human_approval=False, + ... ) + >>> + >>> team = GovernedTeam( + ... agents=[agent1, agent2], + ... policy=policy, + ... ) + >>> result = await team.run("Analyze this data") +""" + +from ._governance import ( + GovernancePolicy, + GovernedAgent, + GovernedTeam, + PolicyViolationError, + ExecutionContext, +) + +__all__ = [ + "GovernancePolicy", + "GovernedAgent", + "GovernedTeam", + "PolicyViolationError", + "ExecutionContext", +] diff --git a/python/packages/autogen-ext/src/autogen_ext/governance/_governance.py b/python/packages/autogen-ext/src/autogen_ext/governance/_governance.py new file mode 100644 index 000000000000..8cbcce5cf968 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/governance/_governance.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Agent-OS Governance Implementation for AutoGen +=============================================== + +Kernel-level governance for AutoGen multi-agent conversations. +""" + +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence + +logger = logging.getLogger(__name__) + + +@dataclass +class GovernancePolicy: + """Policy configuration for governed agents.""" + + # Message limits + max_messages: int = 100 + max_tool_calls: int = 50 + timeout_seconds: int = 300 + + # Tool filtering + allowed_tools: List[str] = field(default_factory=list) + blocked_tools: List[str] = field(default_factory=list) + + # Content filtering + blocked_patterns: List[str] = field(default_factory=list) + max_message_length: int = 50000 + + # Approval flows + require_human_approval: bool = False + approval_tools: List[str] = field(default_factory=list) + + # Audit + log_all_messages: bool = True + + +@dataclass +class ExecutionContext: + """Runtime context for governed execution.""" + + session_id: str + policy: GovernancePolicy + started_at: datetime = field(default_factory=datetime.utcnow) + + # Counters + message_count: int = 0 + tool_calls: int = 0 + + # Audit trail + events: List[Dict[str, Any]] = field(default_factory=list) + + def record_event(self, event_type: str, data: Dict[str, Any]) -> None: + """Record an audit event.""" + self.events.append( + { + "type": event_type, + "timestamp": datetime.utcnow().isoformat(), + "data": data, + } + ) + + +class PolicyViolationError(Exception): + """Raised when a policy violation is detected.""" + + def __init__(self, policy_name: str, description: str, severity: str = "high"): + self.policy_name = policy_name + self.description = description + self.severity = severity + super().__init__(f"Policy violation ({policy_name}): {description}") + + +class GovernedAgent: + """ + Wrapper that adds governance to any AutoGen agent. + + Intercepts messages and tool calls to enforce policies. + """ + + def __init__( + self, + agent: Any, + policy: GovernancePolicy, + on_violation: Optional[Callable[[PolicyViolationError], None]] = None, + ): + self._agent = agent + self._policy = policy + self._on_violation = on_violation or self._default_violation_handler + self._context = ExecutionContext( + session_id=str(uuid.uuid4())[:8], + policy=policy, + ) + + def _default_violation_handler(self, error: PolicyViolationError) -> None: + """Default handler logs violations.""" + logger.error(f"Policy violation: {error}") + + @property + def name(self) -> str: + """Get agent name.""" + return getattr(self._agent, "name", "unknown") + + @property + def original(self) -> Any: + """Get original unwrapped agent.""" + return self._agent + + def _check_content(self, content: str) -> tuple[bool, str]: + """Check content against blocked patterns.""" + if len(content) > self._policy.max_message_length: + return False, f"Message exceeds max length ({len(content)} > {self._policy.max_message_length})" + + content_lower = content.lower() + for pattern in self._policy.blocked_patterns: + if pattern.lower() in content_lower: + return False, f"Content matches blocked pattern: {pattern}" + + return True, "" + + def _check_tool(self, tool_name: str) -> tuple[bool, str]: + """Check if tool is allowed.""" + if tool_name in self._policy.blocked_tools: + return False, f"Tool '{tool_name}' is blocked" + + if self._policy.allowed_tools and tool_name not in self._policy.allowed_tools: + return False, f"Tool '{tool_name}' not in allowed list" + + if self._context.tool_calls >= self._policy.max_tool_calls: + return False, f"Tool call limit ({self._policy.max_tool_calls}) exceeded" + + return True, "" + + async def on_messages( + self, + messages: Sequence[Any], + cancellation_token: Optional[Any] = None, + ) -> Any: + """Handle incoming messages with governance.""" + # Check message count + if self._context.message_count >= self._policy.max_messages: + error = PolicyViolationError( + "message_limit", + f"Message limit ({self._policy.max_messages}) exceeded", + ) + self._on_violation(error) + raise error + + # Check each message content + for msg in messages: + content = getattr(msg, "content", str(msg)) + if isinstance(content, str): + ok, reason = self._check_content(content) + if not ok: + error = PolicyViolationError("content_filter", reason) + self._on_violation(error) + raise error + + self._context.message_count += len(messages) + self._context.record_event( + "messages_received", + {"count": len(messages)}, + ) + + # Forward to original agent + if hasattr(self._agent, "on_messages"): + return await self._agent.on_messages(messages, cancellation_token) + + return None + + async def on_messages_stream( + self, + messages: Sequence[Any], + cancellation_token: Optional[Any] = None, + ) -> AsyncGenerator[Any, None]: + """Handle streaming messages with governance.""" + # Pre-check + for msg in messages: + content = getattr(msg, "content", str(msg)) + if isinstance(content, str): + ok, reason = self._check_content(content) + if not ok: + error = PolicyViolationError("content_filter", reason) + self._on_violation(error) + raise error + + self._context.message_count += len(messages) + + # Stream from original + if hasattr(self._agent, "on_messages_stream"): + async for chunk in self._agent.on_messages_stream(messages, cancellation_token): + yield chunk + + def __getattr__(self, name: str) -> Any: + """Forward unknown attributes to original agent.""" + return getattr(self._agent, name) + + +class GovernedTeam: + """ + Governed team of AutoGen agents. + + Wraps a team to enforce policies across all agent interactions. + """ + + def __init__( + self, + agents: List[Any], + policy: Optional[GovernancePolicy] = None, + termination_condition: Optional[Any] = None, + on_violation: Optional[Callable[[PolicyViolationError], None]] = None, + ): + self._policy = policy or GovernancePolicy() + self._on_violation = on_violation + + # Wrap all agents + self._governed_agents = [ + GovernedAgent(agent, self._policy, on_violation) for agent in agents + ] + + self._termination_condition = termination_condition + self._context = ExecutionContext( + session_id=str(uuid.uuid4())[:8], + policy=self._policy, + ) + + @property + def agents(self) -> List[GovernedAgent]: + """Get governed agents.""" + return self._governed_agents + + async def run( + self, + task: str, + cancellation_token: Optional[Any] = None, + ) -> Any: + """Run team with governance.""" + # Check task content + ok, reason = self._check_content(task) + if not ok: + error = PolicyViolationError("content_filter", reason) + if self._on_violation: + self._on_violation(error) + raise error + + self._context.record_event("team_run_start", {"task_length": len(task)}) + + try: + # Import RoundRobinGroupChat dynamically to avoid hard dependency + from autogen_agentchat.teams import RoundRobinGroupChat + + # Create team with governed agents + original_agents = [ga.original for ga in self._governed_agents] + team = RoundRobinGroupChat( + original_agents, + termination_condition=self._termination_condition, + ) + + result = await team.run(task=task, cancellation_token=cancellation_token) + + self._context.record_event("team_run_complete", {"success": True}) + return result + + except ImportError: + # Fallback: just run first agent + logger.warning("autogen_agentchat not available, running first agent only") + if self._governed_agents: + return await self._governed_agents[0].on_messages([task], cancellation_token) + return None + + async def run_stream( + self, + task: str, + cancellation_token: Optional[Any] = None, + ) -> AsyncGenerator[Any, None]: + """Run team with streaming and governance.""" + ok, reason = self._check_content(task) + if not ok: + error = PolicyViolationError("content_filter", reason) + if self._on_violation: + self._on_violation(error) + raise error + + try: + from autogen_agentchat.teams import RoundRobinGroupChat + + original_agents = [ga.original for ga in self._governed_agents] + team = RoundRobinGroupChat( + original_agents, + termination_condition=self._termination_condition, + ) + + async for chunk in team.run_stream(task=task, cancellation_token=cancellation_token): + yield chunk + + except ImportError: + logger.warning("autogen_agentchat not available") + + def _check_content(self, content: str) -> tuple[bool, str]: + """Check content against policy.""" + if len(content) > self._policy.max_message_length: + return False, f"Content exceeds max length" + + content_lower = content.lower() + for pattern in self._policy.blocked_patterns: + if pattern.lower() in content_lower: + return False, f"Content matches blocked pattern: {pattern}" + + return True, "" + + def get_audit_log(self) -> List[Dict[str, Any]]: + """Get combined audit log from team and all agents.""" + events = list(self._context.events) + for agent in self._governed_agents: + events.extend(agent._context.events) + return sorted(events, key=lambda e: e["timestamp"]) + + def get_stats(self) -> Dict[str, Any]: + """Get governance statistics.""" + total_messages = sum(a._context.message_count for a in self._governed_agents) + total_tool_calls = sum(a._context.tool_calls for a in self._governed_agents) + + return { + "session_id": self._context.session_id, + "agent_count": len(self._governed_agents), + "total_messages": total_messages, + "total_tool_calls": total_tool_calls, + "policy": { + "max_messages": self._policy.max_messages, + "max_tool_calls": self._policy.max_tool_calls, + "blocked_patterns_count": len(self._policy.blocked_patterns), + }, + } diff --git a/python/packages/autogen-ext/tests/governance/__init__.py b/python/packages/autogen-ext/tests/governance/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/tests/governance/test_governance.py b/python/packages/autogen-ext/tests/governance/test_governance.py new file mode 100644 index 000000000000..f530babf1071 --- /dev/null +++ b/python/packages/autogen-ext/tests/governance/test_governance.py @@ -0,0 +1,496 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Tests for Agent-OS Governance Extension +======================================== + +Covers: GovernancePolicy, ExecutionContext, PolicyViolationError, + GovernedAgent, GovernedTeam. +""" + +import asyncio +import logging +from dataclasses import fields +from types import SimpleNamespace +from typing import Any, Optional, Sequence +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from autogen_ext.governance import ( + ExecutionContext, + GovernancePolicy, + GovernedAgent, + GovernedTeam, + PolicyViolationError, +) + + +# ── Helpers ─────────────────────────────────────────────────────── + + +class FakeAgent: + """Minimal mock agent with name and on_messages.""" + + def __init__(self, name: str = "fake-agent"): + self.name = name + self.on_messages = AsyncMock(return_value="ok") + self.custom_attr = 42 + + +class FakeStreamAgent: + """Agent that supports on_messages_stream.""" + + def __init__(self, name: str = "stream-agent", chunks: list | None = None): + self.name = name + self._chunks = chunks or ["chunk1", "chunk2"] + + async def on_messages_stream( + self, messages: Sequence[Any], cancellation_token: Optional[Any] = None + ): + for c in self._chunks: + yield c + + +# ── GovernancePolicy ────────────────────────────────────────────── + + +class TestGovernancePolicy: + def test_defaults(self): + p = GovernancePolicy() + assert p.max_messages == 100 + assert p.max_tool_calls == 50 + assert p.timeout_seconds == 300 + assert p.blocked_patterns == [] + assert p.blocked_tools == [] + assert p.allowed_tools == [] + assert p.max_message_length == 50000 + assert p.require_human_approval is False + assert p.approval_tools == [] + assert p.log_all_messages is True + + def test_custom_values(self): + p = GovernancePolicy( + max_messages=10, + max_tool_calls=5, + blocked_patterns=["DROP TABLE"], + blocked_tools=["shell"], + allowed_tools=["search"], + max_message_length=1000, + ) + assert p.max_messages == 10 + assert p.max_tool_calls == 5 + assert p.blocked_patterns == ["DROP TABLE"] + assert p.blocked_tools == ["shell"] + assert p.allowed_tools == ["search"] + assert p.max_message_length == 1000 + + def test_is_dataclass(self): + names = {f.name for f in fields(GovernancePolicy)} + expected = { + "max_messages", + "max_tool_calls", + "timeout_seconds", + "allowed_tools", + "blocked_tools", + "blocked_patterns", + "max_message_length", + "require_human_approval", + "approval_tools", + "log_all_messages", + } + assert names == expected + + +# ── ExecutionContext ─────────────────────────────────────────────── + + +class TestExecutionContext: + def test_creation(self): + p = GovernancePolicy() + ctx = ExecutionContext(session_id="test-123", policy=p) + assert ctx.session_id == "test-123" + assert ctx.message_count == 0 + assert ctx.tool_calls == 0 + assert ctx.events == [] + + def test_record_event(self): + p = GovernancePolicy() + ctx = ExecutionContext(session_id="s1", policy=p) + ctx.record_event("test_event", {"key": "value"}) + assert len(ctx.events) == 1 + assert ctx.events[0]["type"] == "test_event" + assert ctx.events[0]["data"] == {"key": "value"} + assert "timestamp" in ctx.events[0] + + def test_multiple_events(self): + p = GovernancePolicy() + ctx = ExecutionContext(session_id="s1", policy=p) + for i in range(5): + ctx.record_event(f"event_{i}", {"i": i}) + assert len(ctx.events) == 5 + assert ctx.events[4]["type"] == "event_4" + + +# ── PolicyViolationError ────────────────────────────────────────── + + +class TestPolicyViolationError: + def test_basic(self): + e = PolicyViolationError("content_filter", "blocked pattern found") + assert e.policy_name == "content_filter" + assert e.description == "blocked pattern found" + assert e.severity == "high" + assert "content_filter" in str(e) + assert "blocked pattern found" in str(e) + + def test_custom_severity(self): + e = PolicyViolationError("rate_limit", "too many calls", severity="medium") + assert e.severity == "medium" + + def test_is_exception(self): + with pytest.raises(PolicyViolationError): + raise PolicyViolationError("test", "test error") + + +# ── GovernedAgent ───────────────────────────────────────────────── + + +class TestGovernedAgent: + def test_wraps_agent(self): + agent = FakeAgent("my-agent") + ga = GovernedAgent(agent, GovernancePolicy()) + assert ga.name == "my-agent" + assert ga.original is agent + + def test_name_fallback(self): + ga = GovernedAgent(object(), GovernancePolicy()) + assert ga.name == "unknown" + + def test_getattr_forwarding(self): + agent = FakeAgent() + ga = GovernedAgent(agent, GovernancePolicy()) + assert ga.custom_attr == 42 + + # ── Content checks ──────────────────────────────────────────── + + def test_check_content_passes(self): + ga = GovernedAgent(FakeAgent(), GovernancePolicy()) + ok, reason = ga._check_content("hello world") + assert ok is True + assert reason == "" + + def test_check_content_blocks_pattern(self): + policy = GovernancePolicy(blocked_patterns=["DROP TABLE", "rm -rf"]) + ga = GovernedAgent(FakeAgent(), policy) + ok, reason = ga._check_content("please DROP TABLE users") + assert ok is False + assert "DROP TABLE" in reason + + def test_check_content_case_insensitive(self): + policy = GovernancePolicy(blocked_patterns=["DROP TABLE"]) + ga = GovernedAgent(FakeAgent(), policy) + ok, reason = ga._check_content("drop table users") + assert ok is False + + def test_check_content_max_length(self): + policy = GovernancePolicy(max_message_length=10) + ga = GovernedAgent(FakeAgent(), policy) + ok, reason = ga._check_content("a" * 11) + assert ok is False + assert "max length" in reason + + # ── Tool checks ─────────────────────────────────────────────── + + def test_check_tool_allowed(self): + ga = GovernedAgent(FakeAgent(), GovernancePolicy()) + ok, reason = ga._check_tool("web_search") + assert ok is True + + def test_check_tool_blocked(self): + policy = GovernancePolicy(blocked_tools=["shell_execute"]) + ga = GovernedAgent(FakeAgent(), policy) + ok, reason = ga._check_tool("shell_execute") + assert ok is False + assert "blocked" in reason + + def test_check_tool_allowlist(self): + policy = GovernancePolicy(allowed_tools=["search", "read"]) + ga = GovernedAgent(FakeAgent(), policy) + ok, _ = ga._check_tool("search") + assert ok is True + ok, reason = ga._check_tool("shell") + assert ok is False + assert "not in allowed list" in reason + + def test_check_tool_limit(self): + policy = GovernancePolicy(max_tool_calls=2) + ga = GovernedAgent(FakeAgent(), policy) + ga._context.tool_calls = 2 + ok, reason = ga._check_tool("anything") + assert ok is False + assert "limit" in reason + + # ── on_messages ─────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_on_messages_forwards(self): + agent = FakeAgent() + ga = GovernedAgent(agent, GovernancePolicy()) + msgs = [SimpleNamespace(content="hello")] + result = await ga.on_messages(msgs) + assert result == "ok" + agent.on_messages.assert_awaited_once_with(msgs, None) + + @pytest.mark.asyncio + async def test_on_messages_records_audit(self): + ga = GovernedAgent(FakeAgent(), GovernancePolicy()) + await ga.on_messages([SimpleNamespace(content="hi")]) + assert ga._context.message_count == 1 + assert len(ga._context.events) == 1 + assert ga._context.events[0]["type"] == "messages_received" + + @pytest.mark.asyncio + async def test_on_messages_blocks_pattern(self): + policy = GovernancePolicy(blocked_patterns=["rm -rf"]) + violations = [] + ga = GovernedAgent(FakeAgent(), policy, on_violation=violations.append) + with pytest.raises(PolicyViolationError): + await ga.on_messages([SimpleNamespace(content="run rm -rf /")]) + assert len(violations) == 1 + assert violations[0].policy_name == "content_filter" + + @pytest.mark.asyncio + async def test_on_messages_limit(self): + policy = GovernancePolicy(max_messages=1) + ga = GovernedAgent(FakeAgent(), policy) + await ga.on_messages([SimpleNamespace(content="first")]) + with pytest.raises(PolicyViolationError, match="Message limit"): + await ga.on_messages([SimpleNamespace(content="second")]) + + @pytest.mark.asyncio + async def test_on_messages_no_on_messages_attr(self): + """Agent without on_messages returns None.""" + ga = GovernedAgent(object(), GovernancePolicy()) + result = await ga.on_messages([SimpleNamespace(content="hi")]) + assert result is None + + @pytest.mark.asyncio + async def test_on_messages_default_violation_handler(self, caplog): + """Default handler logs violations.""" + policy = GovernancePolicy(blocked_patterns=["bad"]) + ga = GovernedAgent(FakeAgent(), policy) + with pytest.raises(PolicyViolationError): + with caplog.at_level(logging.ERROR): + await ga.on_messages([SimpleNamespace(content="bad content")]) + + # ── on_messages_stream ──────────────────────────────────────── + + @pytest.mark.asyncio + async def test_on_messages_stream(self): + agent = FakeStreamAgent(chunks=["a", "b", "c"]) + ga = GovernedAgent(agent, GovernancePolicy()) + chunks = [] + async for c in ga.on_messages_stream([SimpleNamespace(content="hi")]): + chunks.append(c) + assert chunks == ["a", "b", "c"] + assert ga._context.message_count == 1 + + @pytest.mark.asyncio + async def test_on_messages_stream_blocks(self): + policy = GovernancePolicy(blocked_patterns=["evil"]) + ga = GovernedAgent(FakeStreamAgent(), policy) + with pytest.raises(PolicyViolationError): + async for _ in ga.on_messages_stream([SimpleNamespace(content="evil plan")]): + pass + + @pytest.mark.asyncio + async def test_on_messages_stream_no_stream_attr(self): + """Agent without on_messages_stream yields nothing.""" + ga = GovernedAgent(object(), GovernancePolicy()) + chunks = [] + async for c in ga.on_messages_stream([SimpleNamespace(content="hi")]): + chunks.append(c) + assert chunks == [] + + +# ── GovernedTeam ────────────────────────────────────────────────── + + +class TestGovernedTeam: + def test_wraps_agents(self): + agents = [FakeAgent("a1"), FakeAgent("a2")] + team = GovernedTeam(agents=agents) + assert len(team.agents) == 2 + assert team.agents[0].name == "a1" + assert team.agents[1].name == "a2" + + def test_default_policy(self): + team = GovernedTeam(agents=[FakeAgent()]) + assert team._policy.max_messages == 100 + + def test_custom_policy(self): + policy = GovernancePolicy(max_messages=5) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + assert team._policy.max_messages == 5 + + # ── Content check ───────────────────────────────────────────── + + def test_check_content_passes(self): + team = GovernedTeam(agents=[FakeAgent()]) + ok, reason = team._check_content("valid task") + assert ok is True + + def test_check_content_blocks_pattern(self): + policy = GovernancePolicy(blocked_patterns=["DELETE FROM"]) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + ok, reason = team._check_content("DELETE FROM users") + assert ok is False + + def test_check_content_max_length(self): + policy = GovernancePolicy(max_message_length=5) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + ok, reason = team._check_content("too long content") + assert ok is False + + # ── run ──────────────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_blocks_bad_task(self): + policy = GovernancePolicy(blocked_patterns=["rm -rf"]) + violations = [] + team = GovernedTeam( + agents=[FakeAgent()], + policy=policy, + on_violation=violations.append, + ) + with pytest.raises(PolicyViolationError): + await team.run("rm -rf /") + assert len(violations) == 1 + + @pytest.mark.asyncio + async def test_run_records_audit(self): + """run() records team_run_start even if autogen_agentchat unavailable.""" + team = GovernedTeam(agents=[FakeAgent()]) + # This will hit ImportError fallback for RoundRobinGroupChat + await team.run("simple task") + events = [e for e in team._context.events if e["type"] == "team_run_start"] + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_run_without_violation_handler(self): + """run() works when on_violation is None and content is blocked.""" + policy = GovernancePolicy(blocked_patterns=["bad"]) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + with pytest.raises(PolicyViolationError): + await team.run("bad task") + + # ── run_stream ──────────────────────────────────────────────── + + @pytest.mark.asyncio + async def test_run_stream_blocks_bad_task(self): + policy = GovernancePolicy(blocked_patterns=["DROP"]) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + with pytest.raises(PolicyViolationError): + async for _ in team.run_stream("DROP TABLE"): + pass + + @pytest.mark.asyncio + async def test_run_stream_without_violation_handler(self): + policy = GovernancePolicy(blocked_patterns=["bad"]) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + with pytest.raises(PolicyViolationError): + async for _ in team.run_stream("bad task"): + pass + + # ── Audit & Stats ───────────────────────────────────────────── + + def test_get_audit_log_empty(self): + team = GovernedTeam(agents=[FakeAgent()]) + assert team.get_audit_log() == [] + + @pytest.mark.asyncio + async def test_get_audit_log_combined(self): + team = GovernedTeam(agents=[FakeAgent(), FakeAgent()]) + # Trigger some events + await team.run("task") + log = team.get_audit_log() + assert len(log) >= 1 # At least team_run_start + + def test_get_stats(self): + agents = [FakeAgent("a1"), FakeAgent("a2")] + policy = GovernancePolicy( + max_messages=20, + max_tool_calls=10, + blocked_patterns=["X"], + ) + team = GovernedTeam(agents=agents, policy=policy) + stats = team.get_stats() + assert stats["agent_count"] == 2 + assert stats["total_messages"] == 0 + assert stats["total_tool_calls"] == 0 + assert stats["policy"]["max_messages"] == 20 + assert stats["policy"]["max_tool_calls"] == 10 + assert stats["policy"]["blocked_patterns_count"] == 1 + + @pytest.mark.asyncio + async def test_get_stats_after_messages(self): + team = GovernedTeam(agents=[FakeAgent()]) + await team.run("task") + stats = team.get_stats() + # ImportError fallback sends messages to first governed agent + assert stats["total_messages"] >= 0 + + +# ── Integration ─────────────────────────────────────────────────── + + +class TestIntegration: + @pytest.mark.asyncio + async def test_full_workflow(self): + """End-to-end: create policy → wrap agents → run → audit.""" + policy = GovernancePolicy( + max_messages=50, + max_tool_calls=10, + blocked_patterns=["DROP TABLE", "rm -rf", "DELETE FROM"], + blocked_tools=["shell_execute"], + ) + + a1 = FakeAgent("analyst") + a2 = FakeAgent("reviewer") + + violations = [] + team = GovernedTeam( + agents=[a1, a2], + policy=policy, + on_violation=violations.append, + ) + + # Good task succeeds + await team.run("Analyze Q4 sales data") + assert len(violations) == 0 + + # Bad task fails + with pytest.raises(PolicyViolationError): + await team.run("DROP TABLE users") + assert len(violations) == 1 + + # Audit has entries + log = team.get_audit_log() + assert len(log) >= 1 + + # Stats are correct + stats = team.get_stats() + assert stats["agent_count"] == 2 + + @pytest.mark.asyncio + async def test_multiple_patterns(self): + """All blocked patterns are enforced.""" + policy = GovernancePolicy( + blocked_patterns=["DROP TABLE", "rm -rf", "DELETE FROM", "EXEC xp_"] + ) + team = GovernedTeam(agents=[FakeAgent()], policy=policy) + + for bad in ["DROP TABLE x", "rm -rf /", "DELETE FROM y", "EXEC xp_cmdshell"]: + with pytest.raises(PolicyViolationError): + await team.run(bad)