diff --git a/langfun/core/agentic/__init__.py b/langfun/core/agentic/__init__.py index 1c57f22d..18fc1683 100644 --- a/langfun/core/agentic/__init__.py +++ b/langfun/core/agentic/__init__.py @@ -22,6 +22,10 @@ from langfun.core.agentic.action import Action +from langfun.core.agentic.action import ActionCheckpoint +from langfun.core.agentic.action import Checkpointer +from langfun.core.agentic.action import FileCheckpointer + from langfun.core.agentic.action import ExecutionUnit from langfun.core.agentic.action import ActionInvocation from langfun.core.agentic.action import ParallelExecutions diff --git a/langfun/core/agentic/action.py b/langfun/core/agentic/action.py index 35e9ef1b..32065308 100644 --- a/langfun/core/agentic/action.py +++ b/langfun/core/agentic/action.py @@ -19,6 +19,7 @@ import dataclasses import functools import itertools +import os import threading import time import typing @@ -316,6 +317,109 @@ def call(self, session: 'Session', **kwargs) -> Any: The result of the action. """ + @property + def checkpoint_type_id(self) -> str: + """Stable identifier for checkpoint type validation.""" + return f'{self.__class__.__module__}.{self.__class__.__name__}' + + def on_checkpoint(self) -> 'ActionCheckpoint': + """Override to return checkpoint data. Default: empty state.""" + return ActionCheckpoint( + action_type=self.checkpoint_type_id, + state={}, + step=0, + ) + + def on_restore(self, checkpoint: 'ActionCheckpoint') -> None: + """Override to restore from checkpoint. Validates action type.""" + if checkpoint.action_type != self.checkpoint_type_id: + raise ValueError( + f'Checkpoint type mismatch: expected {self.checkpoint_type_id}, ' + f'got {checkpoint.action_type}' + ) + + @classmethod + def restore_from_checkpoint( + cls, + checkpointer: 'Checkpointer', + **kwargs, + ) -> tuple['Action', 'ActionCheckpoint | None']: + """Convenience method: create action and restore from checkpoint if exists. + + Args: + checkpointer: The checkpointer to use for loading. + **kwargs: Keyword arguments to pass to the action constructor. + + Returns: + (action, checkpoint) tuple. checkpoint is None if no checkpoint found. + """ + action = cls(**kwargs) + checkpoint = checkpointer.load() + if checkpoint is not None: + action.on_restore(checkpoint) + return action, checkpoint + + +# +# Checkpointing. +# + + +class ActionCheckpoint(pg.Object): + """Checkpoint data for an action's context.""" + + action_type: Annotated[str, 'Fully qualified class name for validation'] + state: Annotated[dict[str, Any], 'Action-specific state to persist'] + step: Annotated[int, 'Last completed step'] = 0 + + +class Checkpointer(pg.Object): + """Interface for action context checkpointing.""" + + def _on_bound(self): + super()._on_bound() + if type(self) is Checkpointer: # pylint: disable=unidiomatic-typecheck + raise TypeError( + 'Checkpointer is abstract and cannot be instantiated directly. ' + 'Use a subclass like FileCheckpointer.' + ) + + def save(self, checkpoint: ActionCheckpoint) -> None: + raise NotImplementedError + + def load(self) -> ActionCheckpoint | None: + raise NotImplementedError + + +class FileCheckpointer(Checkpointer): + """Checkpoints to a file using pg.io (supports CNS/GCS/local).""" + + path: Annotated[str, 'File path for checkpoint (CNS/GCS/local)'] + + def save(self, checkpoint: ActionCheckpoint) -> None: + content = checkpoint.to_json_str(json_indent=2) + dir_path = os.path.dirname(self.path) + if dir_path: + pg.io.mkdirs(dir_path, exist_ok=True) + tmp_path = os.path.join(dir_path, f'tmp.{os.path.basename(self.path)}') + pg.io.writefile(tmp_path, content) + pg.io.rename(tmp_path, self.path) + + def load(self) -> ActionCheckpoint | None: + if not pg.io.path_exists(self.path): + return None + file_content = pg.io.readfile(self.path) + if file_content is None: + return None + if isinstance(file_content, bytes): + content = file_content.decode('utf-8') + else: + content = file_content + try: + return pg.from_json_str(content) + except Exception: # pylint: disable=broad-except + return None + # # Execution tracking. diff --git a/langfun/core/agentic/action_test.py b/langfun/core/agentic/action_test.py index 3651088d..74490603 100644 --- a/langfun/core/agentic/action_test.py +++ b/langfun/core/agentic/action_test.py @@ -14,7 +14,10 @@ """Tests for base action.""" import asyncio +import os +import tempfile import time +from typing import Any import unittest import langfun.core as lf @@ -659,5 +662,258 @@ def test_query_with_track_if(self): self.assertIsNone(session.root.queries[0].error) +class CheckpointTest(unittest.TestCase): + """Tests for action checkpointing functionality.""" + + def test_action_checkpoint_serialization(self): + """Round-trip via to_json_str/from_json_str preserves all fields.""" + checkpoint = action_lib.ActionCheckpoint( + action_type='test.module.TestAction', + state={'key': 'value', 'count': 42}, + step=5, + ) + json_str = checkpoint.to_json_str() + restored = pg.from_json_str(json_str) + self.assertEqual(restored.action_type, 'test.module.TestAction') + self.assertEqual(restored.state, {'key': 'value', 'count': 42}) + self.assertEqual(restored.step, 5) + + def test_file_checkpointer_save_load(self): + """Save then load returns equal checkpoint.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'checkpoint.json') + checkpointer = action_lib.FileCheckpointer(path=path) + checkpoint = action_lib.ActionCheckpoint( + action_type='test.TestAction', + state={'items': [1, 2, 3]}, + step=10, + ) + checkpointer.save(checkpoint) + loaded = checkpointer.load() + self.assertIsNotNone(loaded) + self.assertEqual(loaded.action_type, checkpoint.action_type) + self.assertEqual(loaded.state, checkpoint.state) + self.assertEqual(loaded.step, checkpoint.step) + + def test_file_checkpointer_missing_file(self): + """load() returns None for missing path.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'nonexistent.json') + checkpointer = action_lib.FileCheckpointer(path=path) + self.assertIsNone(checkpointer.load()) + + def test_file_checkpointer_creates_dirs(self): + """Nested path dirs created via pg.io.mkdirs.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'nested', 'deep', 'checkpoint.json') + checkpointer = action_lib.FileCheckpointer(path=path) + checkpoint = action_lib.ActionCheckpoint( + action_type='test.TestAction', + state={}, + step=0, + ) + checkpointer.save(checkpoint) + self.assertTrue(os.path.exists(path)) + + def test_action_default_on_checkpoint(self): + """Default on_checkpoint returns ActionCheckpoint with empty state.""" + action = Bar() + checkpoint = action.on_checkpoint() + self.assertIsInstance(checkpoint, action_lib.ActionCheckpoint) + self.assertIn('Bar', checkpoint.action_type) + self.assertEqual(checkpoint.state, {}) + self.assertEqual(checkpoint.step, 0) + + def test_action_default_on_restore_validates_type(self): + """Mismatched action_type raises ValueError.""" + action = Bar() + wrong_checkpoint = action_lib.ActionCheckpoint( + action_type='wrong.module.WrongAction', + state={}, + step=0, + ) + with self.assertRaises(ValueError) as ctx: + action.on_restore(wrong_checkpoint) + self.assertIn('Checkpoint type mismatch', str(ctx.exception)) + + def test_custom_action_checkpoint_hooks(self): + """Subclass state persists across checkpoint/restore.""" + + class StatefulAction(action_lib.Action): + + def _on_bound(self): + super()._on_bound() + self._counter = 0 + self._items = [] + + def on_checkpoint(self): + return action_lib.ActionCheckpoint( + action_type=self.checkpoint_type_id, + state={'counter': self._counter, 'items': self._items}, + step=self._counter, + ) + + def on_restore(self, checkpoint): + super().on_restore(checkpoint) + self._counter = checkpoint.state.get('counter', 0) + self._items = checkpoint.state.get('items', []) + + def call(self, session, **kwargs): + self._counter += 1 + self._items.append(f'item_{self._counter}') + return self._counter + + action1 = StatefulAction() + action1._counter = 5 + action1._items = ['a', 'b', 'c'] + checkpoint = action1.on_checkpoint() + + action2 = StatefulAction() + self.assertEqual(action2._counter, 0) + self.assertEqual(action2._items, []) + + action2.on_restore(checkpoint) + self.assertEqual(action2._counter, 5) + self.assertEqual(action2._items, ['a', 'b', 'c']) + + def test_checkpoint_with_pg_object_state(self): + """State containing pg.Object instances serializes correctly.""" + + class TestData(pg.Object): + name: str + value: int + + checkpoint = action_lib.ActionCheckpoint( + action_type='test.TestAction', + state={'data': TestData(name='test', value=42)}, + step=1, + ) + json_str = checkpoint.to_json_str() + restored = pg.from_json_str(json_str) + self.assertIsInstance(restored.state['data'], TestData) + self.assertEqual(restored.state['data'].name, 'test') + self.assertEqual(restored.state['data'].value, 42) + + def test_restore_from_checkpoint_convenience(self): + """Creates action and restores state in one call.""" + + class TestAction(action_lib.Action): + x: int + + def _on_bound(self): + super()._on_bound() + self._state_value = 0 + + def on_checkpoint(self): + return action_lib.ActionCheckpoint( + action_type=self.checkpoint_type_id, + state={'state_value': self._state_value}, + step=self._state_value, + ) + + def on_restore(self, checkpoint): + super().on_restore(checkpoint) + self._state_value = checkpoint.state.get('state_value', 0) + + def call(self, session, **kwargs): + return self.x + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'checkpoint.json') + checkpointer = action_lib.FileCheckpointer(path=path) + + # Save a checkpoint with some state + original = TestAction(x=10) + original._state_value = 99 + checkpointer.save(original.on_checkpoint()) + + # Use convenience method to restore + restored, checkpoint = TestAction.restore_from_checkpoint( + checkpointer, x=10 + ) + self.assertEqual(restored._state_value, 99) + self.assertIsNotNone(checkpoint) + self.assertEqual(checkpoint.step, 99) + + # Test with no existing checkpoint + checkpointer2 = action_lib.FileCheckpointer( + path=os.path.join(tmpdir, 'new.json') + ) + fresh, checkpoint = TestAction.restore_from_checkpoint( + checkpointer2, x=20 + ) + self.assertEqual(fresh._state_value, 0) # Default from _on_bound + self.assertIsNone(checkpoint) + + def test_agent_checkpoint_round_trip_with_pg_objects(self): + """Mimics LangfunAgent pattern: list of pg.Objects in state round-trips.""" + + class Step(pg.Object): + step: int + thoughts: str + results: Any + + class AgentLikeAction(action_lib.Action): + _CKPT_KEY_STEPS = 'steps' + _CKPT_KEY_START_INDEX = 'start_index' + _CKPT_KEY_FILE_DIR = 'file_dir' + + def _on_bound(self): + super()._on_bound() + self._steps = [] + self._start_index = 0 + self._restored = False + + def on_checkpoint(self): + return action_lib.ActionCheckpoint( + action_type=self.checkpoint_type_id, + state={ + self._CKPT_KEY_STEPS: self._steps, + self._CKPT_KEY_START_INDEX: self._start_index, + self._CKPT_KEY_FILE_DIR: '/tmp/files', + }, + step=len(self._steps), + ) + + def on_restore(self, checkpoint): + super().on_restore(checkpoint) + self._steps = checkpoint.state.get(self._CKPT_KEY_STEPS, []) + self._start_index = checkpoint.state.get(self._CKPT_KEY_START_INDEX, 0) + self._restored = True + + def call(self, session, **kwargs): + return None + + original = AgentLikeAction() + original._steps = [ + Step(step=0, thoughts='thinking', results='found it'), + Step(step=1, thoughts='analyzing', results={'key': 'value'}), + ] + original._start_index = 1 + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'agent_checkpoint.json') + checkpointer = action_lib.FileCheckpointer(path=path) + + checkpointer.save(original.on_checkpoint()) + checkpoint = checkpointer.load() + self.assertIsNotNone(checkpoint) + self.assertEqual(checkpoint.step, 2) + self.assertEqual( + checkpoint.state[AgentLikeAction._CKPT_KEY_FILE_DIR], '/tmp/files' + ) + + restored = AgentLikeAction() + self.assertFalse(restored._restored) + restored.on_restore(checkpoint) + self.assertTrue(restored._restored) + self.assertEqual(len(restored._steps), 2) + self.assertIsInstance(restored._steps[0], Step) + self.assertEqual(restored._steps[0].step, 0) + self.assertEqual(restored._steps[0].thoughts, 'thinking') + self.assertEqual(restored._steps[1].results, {'key': 'value'}) + self.assertEqual(restored._start_index, 1) + + if __name__ == '__main__': unittest.main()