diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35f..b269f68f 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -14,5 +14,6 @@ from . import memory from . import sessions +from . import termination from . import version __version__ = version.__version__ diff --git a/src/google/adk_community/termination/__init__.py b/src/google/adk_community/termination/__init__.py new file mode 100644 index 00000000..589507d3 --- /dev/null +++ b/src/google/adk_community/termination/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community termination conditions for ADK multi-agent workflows.""" + +from __future__ import annotations + +from .external_termination import ExternalTermination +from .function_call_termination import FunctionCallTermination +from .max_iterations_termination import MaxIterationsTermination +from .termination_condition import AndTerminationCondition +from .termination_condition import OrTerminationCondition +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult +from .text_mention_termination import TextMentionTermination +from .timeout_termination import TimeoutTermination +from .token_usage_termination import TokenUsageTermination + +__all__ = [ + 'AndTerminationCondition', + 'ExternalTermination', + 'FunctionCallTermination', + 'MaxIterationsTermination', + 'OrTerminationCondition', + 'TerminationCondition', + 'TerminationResult', + 'TextMentionTermination', + 'TimeoutTermination', + 'TokenUsageTermination', +] diff --git a/src/google/adk_community/termination/external_termination.py b/src/google/adk_community/termination/external_termination.py new file mode 100644 index 00000000..423841d8 --- /dev/null +++ b/src/google/adk_community/termination/external_termination.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A termination condition controlled programmatically via ``set()``.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class ExternalTermination(TerminationCondition): + """A termination condition that is controlled externally by calling ``set()``. + + Useful for integrating external stop signals such as a UI "Stop" button + or application-level logic. + + Example:: + + stop_button = ExternalTermination() + + agent = LoopAgent( + name='my_loop', + sub_agents=[...], + termination_condition=stop_button, + ) + + # Elsewhere (e.g. from a UI event handler): + stop_button.set() + """ + + def __init__(self) -> None: + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + def set(self) -> None: + """Signals that the conversation should terminate at the next check.""" + self._terminated = True + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return TerminationResult(reason='Externally terminated') + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk_community/termination/function_call_termination.py b/src/google/adk_community/termination/function_call_termination.py new file mode 100644 index 00000000..a253cd06 --- /dev/null +++ b/src/google/adk_community/termination/function_call_termination.py @@ -0,0 +1,60 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when a specific function (tool) has been executed.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class FunctionCallTermination(TerminationCondition): + """Terminates when a tool with a specific name has been executed. + + The condition checks ``FunctionResponse`` parts in events. + + Example:: + + # Stop when the "approve" tool is called + condition = FunctionCallTermination('approve') + """ + + def __init__(self, function_name: str) -> None: + self._function_name = function_name + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + for response in event.get_function_responses(): + if response.name == self._function_name: + self._terminated = True + return TerminationResult( + reason=f"Function '{self._function_name}' was executed" + ) + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk_community/termination/max_iterations_termination.py b/src/google/adk_community/termination/max_iterations_termination.py new file mode 100644 index 00000000..c91f7cea --- /dev/null +++ b/src/google/adk_community/termination/max_iterations_termination.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates after a maximum number of events have been processed.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class MaxIterationsTermination(TerminationCondition): + """Terminates the conversation after a maximum number of events. + + Example:: + + # Stop after 10 events + condition = MaxIterationsTermination(10) + """ + + def __init__(self, max_iterations: int) -> None: + if max_iterations <= 0: + raise ValueError('max_iterations must be a positive integer.') + self._max_iterations = max_iterations + self._count = 0 + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + self._count += len(events) + + if self._count >= self._max_iterations: + self._terminated = True + return TerminationResult( + reason=( + f'Maximum iterations of {self._max_iterations} reached,' + f' current count: {self._count}' + ) + ) + return None + + async def reset(self) -> None: + self._terminated = False + self._count = 0 diff --git a/src/google/adk_community/termination/termination_condition.py b/src/google/adk_community/termination/termination_condition.py new file mode 100644 index 00000000..94c34caf --- /dev/null +++ b/src/google/adk_community/termination/termination_condition.py @@ -0,0 +1,171 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base termination condition and compound combinators.""" + +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event + + +@dataclass +class TerminationResult: + """The result returned by a termination condition when the conversation should stop.""" + + reason: str + """A human-readable description of why the conversation was terminated.""" + + +class TerminationCondition(abc.ABC): + """Abstract base class for all termination conditions. + + A termination condition is evaluated after each event in the agent loop. + When ``check()`` returns a ``TerminationResult``, the loop stops and the + ``reason`` is surfaced in the final event's ``actions.termination_reason``. + + Conditions are stateful but reset automatically at the start of each run. + They can be combined with ``.and_()`` and ``.or_()`` to create compound + logic. + + Example:: + + condition = MaxIterationsTermination(10).or_( + TextMentionTermination('TERMINATE') + ) + """ + + @property + @abc.abstractmethod + def terminated(self) -> bool: + """Whether this termination condition has been reached.""" + + @abc.abstractmethod + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + """Checks whether the termination condition is met. + + Called after each event emitted by the agent. Returns a + ``TerminationResult`` if the loop should stop, or ``None`` to continue. + + Args: + events: The delta sequence of events since the last check. + """ + + @abc.abstractmethod + async def reset(self) -> None: + """Resets this condition to its initial state. + + Called automatically at the start of each run so the same instance can + be reused across multiple runs. + """ + + def and_(self, other: TerminationCondition) -> TerminationCondition: + """Returns a new condition that terminates only when BOTH conditions are met. + + Args: + other: The other termination condition. + """ + return AndTerminationCondition(self, other) + + def or_(self, other: TerminationCondition) -> TerminationCondition: + """Returns a new condition that terminates when EITHER condition is met. + + Args: + other: The other termination condition. + """ + return OrTerminationCondition(self, other) + + def __and__(self, other: TerminationCondition) -> TerminationCondition: + """Supports ``condition_a & condition_b`` syntax.""" + return self.and_(other) + + def __or__(self, other: TerminationCondition) -> TerminationCondition: + """Supports ``condition_a | condition_b`` syntax.""" + return self.or_(other) + + +class AndTerminationCondition(TerminationCondition): + """A compound condition that terminates only when ALL children have fired.""" + + def __init__( + self, + left: TerminationCondition, + right: TerminationCondition, + ) -> None: + self._left = left + self._right = right + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + # Forward to both children so each accumulates its own state. + await self._left.check(events) + await self._right.check(events) + + if self._left.terminated and self._right.terminated: + self._terminated = True + return TerminationResult(reason='All termination conditions met') + return None + + async def reset(self) -> None: + self._terminated = False + await self._left.reset() + await self._right.reset() + + +class OrTerminationCondition(TerminationCondition): + """A compound condition that terminates when ANY child fires first.""" + + def __init__( + self, + left: TerminationCondition, + right: TerminationCondition, + ) -> None: + self._left = left + self._right = right + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + left_result = await self._left.check(events) + if left_result: + self._terminated = True + return left_result + + right_result = await self._right.check(events) + if right_result: + self._terminated = True + return right_result + + return None + + async def reset(self) -> None: + self._terminated = False + await self._left.reset() + await self._right.reset() diff --git a/src/google/adk_community/termination/text_mention_termination.py b/src/google/adk_community/termination/text_mention_termination.py new file mode 100644 index 00000000..7b04f1fa --- /dev/null +++ b/src/google/adk_community/termination/text_mention_termination.py @@ -0,0 +1,77 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when a specific text string is found in event content.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +def _stringify_event_content(event: Event) -> str: + """Extracts a text representation from an event's content.""" + if not event.content or not event.content.parts: + return '' + texts = [] + for part in event.content.parts: + if part.text: + texts.append(part.text) + return ' '.join(texts) + + +class TextMentionTermination(TerminationCondition): + """Terminates the conversation when a specific text is found in event content. + + Example:: + + # Stop when any agent says "TERMINATE" + condition = TextMentionTermination('TERMINATE') + + # Stop only when the "critic" agent says "APPROVE" + condition = TextMentionTermination('APPROVE', sources=['critic']) + """ + + def __init__( + self, + text: str, + sources: Optional[Sequence[str]] = None, + ) -> None: + self._text = text + self._sources = list(sources) if sources else None + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + if self._sources and (event.author or '') not in self._sources: + continue + + if self._text in _stringify_event_content(event): + self._terminated = True + return TerminationResult(reason=f"Text '{self._text}' mentioned") + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk_community/termination/timeout_termination.py b/src/google/adk_community/termination/timeout_termination.py new file mode 100644 index 00000000..b50f6f3d --- /dev/null +++ b/src/google/adk_community/termination/timeout_termination.py @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates after a specified duration has elapsed.""" + +from __future__ import annotations + +import time +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class TimeoutTermination(TerminationCondition): + """Terminates the conversation after a specified duration has elapsed. + + The timer starts on the first ``check()`` call. + + Example:: + + # Stop after 30 seconds + condition = TimeoutTermination(30) + """ + + def __init__(self, timeout_seconds: float) -> None: + if timeout_seconds <= 0: + raise ValueError('timeout_seconds must be a positive number.') + self._timeout_seconds = timeout_seconds + self._start_time: Optional[float] = None + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + if self._start_time is None: + self._start_time = time.monotonic() + + elapsed = time.monotonic() - self._start_time + if elapsed >= self._timeout_seconds: + self._terminated = True + return TerminationResult( + reason=( + f'Timeout of {self._timeout_seconds}s reached' + f' (elapsed: {elapsed:.2f}s)' + ) + ) + return None + + async def reset(self) -> None: + self._terminated = False + self._start_time = None diff --git a/src/google/adk_community/termination/token_usage_termination.py b/src/google/adk_community/termination/token_usage_termination.py new file mode 100644 index 00000000..c941cdcc --- /dev/null +++ b/src/google/adk_community/termination/token_usage_termination.py @@ -0,0 +1,129 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when cumulative token usage exceeds a limit.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from google.adk.events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class TokenUsageTermination(TerminationCondition): + """Terminates when cumulative token usage exceeds configured limits. + + At least one of the token limits must be provided. + + Example:: + + # Stop after 10000 total tokens + condition = TokenUsageTermination(max_total_tokens=10_000) + + # Stop after 5000 prompt tokens OR 2000 completion tokens + condition = TokenUsageTermination( + max_prompt_tokens=5_000, + max_completion_tokens=2_000, + ) + """ + + def __init__( + self, + *, + max_total_tokens: Optional[int] = None, + max_prompt_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + ) -> None: + if ( + max_total_tokens is None + and max_prompt_tokens is None + and max_completion_tokens is None + ): + raise ValueError( + 'At least one of max_total_tokens, max_prompt_tokens, or' + ' max_completion_tokens must be provided.' + ) + self._max_total_tokens = max_total_tokens + self._max_prompt_tokens = max_prompt_tokens + self._max_completion_tokens = max_completion_tokens + self._total_tokens = 0 + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + if not event.usage_metadata: + continue + + self._total_tokens += event.usage_metadata.total_token_count or 0 + self._prompt_tokens += event.usage_metadata.prompt_token_count or 0 + self._completion_tokens += ( + event.usage_metadata.candidates_token_count or 0 + ) + + if ( + self._max_total_tokens is not None + and self._total_tokens >= self._max_total_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + f'Token limit exceeded: total_tokens={self._total_tokens}' + f' >= max_total_tokens={self._max_total_tokens}' + ) + ) + + if ( + self._max_prompt_tokens is not None + and self._prompt_tokens >= self._max_prompt_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + f'Token limit exceeded: prompt_tokens={self._prompt_tokens}' + f' >= max_prompt_tokens={self._max_prompt_tokens}' + ) + ) + + if ( + self._max_completion_tokens is not None + and self._completion_tokens >= self._max_completion_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + 'Token limit exceeded:' + f' completion_tokens={self._completion_tokens}' + f' >= max_completion_tokens={self._max_completion_tokens}' + ) + ) + + return None + + async def reset(self) -> None: + self._terminated = False + self._total_tokens = 0 + self._prompt_tokens = 0 + self._completion_tokens = 0 diff --git a/tests/unittests/termination/__init__.py b/tests/unittests/termination/__init__.py new file mode 100644 index 00000000..0a2669d7 --- /dev/null +++ b/tests/unittests/termination/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/termination/test_termination_conditions.py b/tests/unittests/termination/test_termination_conditions.py new file mode 100644 index 00000000..5dd300d8 --- /dev/null +++ b/tests/unittests/termination/test_termination_conditions.py @@ -0,0 +1,472 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for termination conditions.""" + +from __future__ import annotations + +import asyncio +import time + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk_community.termination.external_termination import ExternalTermination +from google.adk_community.termination.function_call_termination import FunctionCallTermination +from google.adk_community.termination.max_iterations_termination import MaxIterationsTermination +from google.adk_community.termination.termination_condition import AndTerminationCondition +from google.adk_community.termination.termination_condition import OrTerminationCondition +from google.adk_community.termination.text_mention_termination import TextMentionTermination +from google.adk_community.termination.timeout_termination import TimeoutTermination +from google.adk_community.termination.token_usage_termination import TokenUsageTermination +from google.genai import types +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_text_event(text: str, author: str = 'agent') -> Event: + return Event( + invocation_id='inv-1', + author=author, + actions=EventActions(), + content=types.Content( + role='model', + parts=[types.Part(text=text)], + ), + ) + + +def _make_token_event( + total_tokens: int, + prompt_tokens: int, + completion_tokens: int, +) -> Event: + return Event( + invocation_id='inv-1', + author='agent', + actions=EventActions(), + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=total_tokens, + prompt_token_count=prompt_tokens, + candidates_token_count=completion_tokens, + ), + ) + + +def _make_function_response_event(function_name: str) -> Event: + return Event( + invocation_id='inv-1', + author='agent', + actions=EventActions(), + content=types.Content( + role='model', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=function_name, + response={'result': 'ok'}, + ) + ) + ], + ), + ) + + +# --------------------------------------------------------------------------- +# MaxIterationsTermination +# --------------------------------------------------------------------------- + + +class TestMaxIterationsTermination: + + def test_raises_if_not_positive(self): + with pytest.raises(ValueError): + MaxIterationsTermination(0) + with pytest.raises(ValueError): + MaxIterationsTermination(-1) + + @pytest.mark.asyncio + async def test_does_not_terminate_before_limit(self): + condition = MaxIterationsTermination(3) + result = await condition.check([_make_text_event('hello')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_at_limit(self): + condition = MaxIterationsTermination(3) + await condition.check([_make_text_event('a'), _make_text_event('b')]) + result = await condition.check([_make_text_event('c')]) + assert result is not None + assert '3' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_fire_again_after_termination(self): + condition = MaxIterationsTermination(1) + await condition.check([_make_text_event('first')]) + assert condition.terminated is True + second = await condition.check([_make_text_event('second')]) + assert second is None + + @pytest.mark.asyncio + async def test_reset(self): + condition = MaxIterationsTermination(1) + await condition.check([_make_text_event('first')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + result = await condition.check([_make_text_event('first again')]) + assert result is not None + + +# --------------------------------------------------------------------------- +# TextMentionTermination +# --------------------------------------------------------------------------- + + +class TestTextMentionTermination: + + @pytest.mark.asyncio + async def test_terminates_when_text_found(self): + condition = TextMentionTermination('TERMINATE') + result = await condition.check([_make_text_event('Please TERMINATE now.')]) + assert result is not None + assert 'TERMINATE' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_when_absent(self): + condition = TextMentionTermination('TERMINATE') + result = await condition.check([_make_text_event('Keep going!')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_respects_sources_filter(self): + condition = TextMentionTermination('APPROVE', sources=['critic']) + + # Wrong source — should NOT fire + no_fire = await condition.check([_make_text_event('APPROVE', 'primary')]) + assert no_fire is None + + # Correct source — should fire + fire = await condition.check([_make_text_event('APPROVE', 'critic')]) + assert fire is not None + + @pytest.mark.asyncio + async def test_reset(self): + condition = TextMentionTermination('DONE') + await condition.check([_make_text_event('DONE')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + result = await condition.check([_make_text_event('not done yet')]) + assert result is None + + +# --------------------------------------------------------------------------- +# TokenUsageTermination +# --------------------------------------------------------------------------- + + +class TestTokenUsageTermination: + + def test_raises_if_no_limit(self): + with pytest.raises(ValueError): + TokenUsageTermination() + + @pytest.mark.asyncio + async def test_terminates_on_total_tokens(self): + condition = TokenUsageTermination(max_total_tokens=100) + result = await condition.check([_make_token_event(101, 50, 51)]) + assert result is not None + assert 'total_tokens' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_terminates_on_prompt_tokens(self): + condition = TokenUsageTermination(max_prompt_tokens=50) + result = await condition.check([_make_token_event(60, 55, 5)]) + assert result is not None + assert 'prompt_tokens' in result.reason + + @pytest.mark.asyncio + async def test_terminates_on_completion_tokens(self): + condition = TokenUsageTermination(max_completion_tokens=30) + result = await condition.check([_make_token_event(40, 5, 35)]) + assert result is not None + assert 'completion_tokens' in result.reason + + @pytest.mark.asyncio + async def test_accumulates_across_events(self): + condition = TokenUsageTermination(max_total_tokens=100) + await condition.check([_make_token_event(60, 40, 20)]) + assert condition.terminated is False + + result = await condition.check([_make_token_event(50, 30, 20)]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_ignores_events_without_usage(self): + condition = TokenUsageTermination(max_total_tokens=10) + result = await condition.check([_make_text_event('no tokens here')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_reset(self): + condition = TokenUsageTermination(max_total_tokens=100) + await condition.check([_make_token_event(200, 100, 100)]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + result = await condition.check([_make_token_event(50, 30, 20)]) + assert result is None + + +# --------------------------------------------------------------------------- +# TimeoutTermination +# --------------------------------------------------------------------------- + + +class TestTimeoutTermination: + + def test_raises_if_not_positive(self): + with pytest.raises(ValueError): + TimeoutTermination(0) + with pytest.raises(ValueError): + TimeoutTermination(-5) + + @pytest.mark.asyncio + async def test_does_not_terminate_before_timeout(self): + condition = TimeoutTermination(60) + result = await condition.check([_make_text_event('hello')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_after_timeout(self): + condition = TimeoutTermination(0.01) # 10ms + # Warm up the start time. + await condition.check([_make_text_event('trigger start')]) + # Wait slightly longer than the timeout. + await asyncio.sleep(0.02) + + result = await condition.check([_make_text_event('after timeout')]) + assert result is not None + assert 'Timeout' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset(self): + condition = TimeoutTermination(0.01) + await condition.check([_make_text_event('start')]) + await asyncio.sleep(0.02) + await condition.check([_make_text_event('fires')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + # After reset a fresh check starts a new timer. + result = await condition.check([_make_text_event('fresh start')]) + assert result is None + + +# --------------------------------------------------------------------------- +# FunctionCallTermination +# --------------------------------------------------------------------------- + + +class TestFunctionCallTermination: + + @pytest.mark.asyncio + async def test_terminates_on_matching_function(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_function_response_event('approve')]) + assert result is not None + assert 'approve' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_for_different_function(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_function_response_event('search')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_does_not_terminate_on_text_only(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_text_event('approve this')]) + assert result is None + + @pytest.mark.asyncio + async def test_reset(self): + condition = FunctionCallTermination('approve') + await condition.check([_make_function_response_event('approve')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + +# --------------------------------------------------------------------------- +# ExternalTermination +# --------------------------------------------------------------------------- + + +class TestExternalTermination: + + @pytest.mark.asyncio + async def test_does_not_terminate_before_set(self): + condition = ExternalTermination() + result = await condition.check([_make_text_event('anything')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_after_set(self): + condition = ExternalTermination() + condition.set() + result = await condition.check([_make_text_event('anything')]) + assert result is not None + assert 'Externally terminated' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset(self): + condition = ExternalTermination() + condition.set() + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + result = await condition.check([_make_text_event('should not fire')]) + assert result is None + + +# --------------------------------------------------------------------------- +# OrTerminationCondition (.or_()) +# --------------------------------------------------------------------------- + + +class TestOrTerminationCondition: + + @pytest.mark.asyncio + async def test_terminates_on_first(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('DONE')) + result = await condition.check([_make_text_event('any')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_terminates_on_second(self): + condition = MaxIterationsTermination(100).or_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_when_neither_fires(self): + condition = MaxIterationsTermination(100).or_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('keep going')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_reset_both_children(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('DONE')) + await condition.check([_make_text_event('fires')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + def test_is_or_instance(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('X')) + assert isinstance(condition, OrTerminationCondition) + + @pytest.mark.asyncio + async def test_pipe_operator(self): + condition = MaxIterationsTermination(1) | TextMentionTermination('DONE') + result = await condition.check([_make_text_event('any')]) + assert result is not None + assert isinstance(condition, OrTerminationCondition) + + +# --------------------------------------------------------------------------- +# AndTerminationCondition (.and_()) +# --------------------------------------------------------------------------- + + +class TestAndTerminationCondition: + + @pytest.mark.asyncio + async def test_does_not_terminate_when_only_first_fires(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('DONE')) + result = await condition.check([_make_text_event('no keyword here')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_does_not_terminate_when_only_second_fires(self): + condition = MaxIterationsTermination(100).and_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('DONE')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_when_both_fire(self): + left = MaxIterationsTermination(1) + right = TextMentionTermination('DONE') + condition = left.and_(right) + + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset_both_children(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('DONE')) + await condition.check([_make_text_event('DONE')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + def test_is_and_instance(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('X')) + assert isinstance(condition, AndTerminationCondition) + + @pytest.mark.asyncio + async def test_ampersand_operator(self): + condition = MaxIterationsTermination(1) & TextMentionTermination('DONE') + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert isinstance(condition, AndTerminationCondition)