Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions langfun/core/agentic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

from langfun.core.agentic.action import Action

from langfun.core.agentic.action import ActionCheckpoint
from langfun.core.agentic.action import Checkpointer
from langfun.core.agentic.action import FileCheckpointer

from langfun.core.agentic.action import ExecutionUnit
from langfun.core.agentic.action import ActionInvocation
from langfun.core.agentic.action import ParallelExecutions
Expand Down
104 changes: 104 additions & 0 deletions langfun/core/agentic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dataclasses
import functools
import itertools
import os
import threading
import time
import typing
Expand Down Expand Up @@ -316,6 +317,109 @@ def call(self, session: 'Session', **kwargs) -> Any:
The result of the action.
"""

@property
def checkpoint_type_id(self) -> str:
"""Stable identifier for checkpoint type validation."""
return f'{self.__class__.__module__}.{self.__class__.__name__}'

def on_checkpoint(self) -> 'ActionCheckpoint':
"""Override to return checkpoint data. Default: empty state."""
return ActionCheckpoint(
action_type=self.checkpoint_type_id,
state={},
step=0,
)

def on_restore(self, checkpoint: 'ActionCheckpoint') -> None:
"""Override to restore from checkpoint. Validates action type."""
if checkpoint.action_type != self.checkpoint_type_id:
raise ValueError(
f'Checkpoint type mismatch: expected {self.checkpoint_type_id}, '
f'got {checkpoint.action_type}'
)

@classmethod
def restore_from_checkpoint(
cls,
checkpointer: 'Checkpointer',
**kwargs,
) -> tuple['Action', 'ActionCheckpoint | None']:
"""Convenience method: create action and restore from checkpoint if exists.

Args:
checkpointer: The checkpointer to use for loading.
**kwargs: Keyword arguments to pass to the action constructor.

Returns:
(action, checkpoint) tuple. checkpoint is None if no checkpoint found.
"""
action = cls(**kwargs)
checkpoint = checkpointer.load()
if checkpoint is not None:
action.on_restore(checkpoint)
return action, checkpoint


#
# Checkpointing.
#


class ActionCheckpoint(pg.Object):
"""Checkpoint data for an action's context."""

action_type: Annotated[str, 'Fully qualified class name for validation']
state: Annotated[dict[str, Any], 'Action-specific state to persist']
step: Annotated[int, 'Last completed step'] = 0


class Checkpointer(pg.Object):
"""Interface for action context checkpointing."""

def _on_bound(self):
super()._on_bound()
if type(self) is Checkpointer: # pylint: disable=unidiomatic-typecheck
raise TypeError(
'Checkpointer is abstract and cannot be instantiated directly. '
'Use a subclass like FileCheckpointer.'
)

def save(self, checkpoint: ActionCheckpoint) -> None:
raise NotImplementedError

def load(self) -> ActionCheckpoint | None:
raise NotImplementedError


class FileCheckpointer(Checkpointer):
"""Checkpoints to a file using pg.io (supports CNS/GCS/local)."""

path: Annotated[str, 'File path for checkpoint (CNS/GCS/local)']

def save(self, checkpoint: ActionCheckpoint) -> None:
content = checkpoint.to_json_str(json_indent=2)
dir_path = os.path.dirname(self.path)
if dir_path:
pg.io.mkdirs(dir_path, exist_ok=True)
tmp_path = os.path.join(dir_path, f'tmp.{os.path.basename(self.path)}')
pg.io.writefile(tmp_path, content)
pg.io.rename(tmp_path, self.path)

def load(self) -> ActionCheckpoint | None:
if not pg.io.path_exists(self.path):
return None
file_content = pg.io.readfile(self.path)
if file_content is None:
return None
if isinstance(file_content, bytes):
content = file_content.decode('utf-8')
else:
content = file_content
try:
return pg.from_json_str(content)
except Exception: # pylint: disable=broad-except
return None


#
# Execution tracking.
Expand Down
256 changes: 256 additions & 0 deletions langfun/core/agentic/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
"""Tests for base action."""

import asyncio
import os
import tempfile
import time
from typing import Any
import unittest

import langfun.core as lf
Expand Down Expand Up @@ -659,5 +662,258 @@ def test_query_with_track_if(self):
self.assertIsNone(session.root.queries[0].error)


class CheckpointTest(unittest.TestCase):
"""Tests for action checkpointing functionality."""

def test_action_checkpoint_serialization(self):
"""Round-trip via to_json_str/from_json_str preserves all fields."""
checkpoint = action_lib.ActionCheckpoint(
action_type='test.module.TestAction',
state={'key': 'value', 'count': 42},
step=5,
)
json_str = checkpoint.to_json_str()
restored = pg.from_json_str(json_str)
self.assertEqual(restored.action_type, 'test.module.TestAction')
self.assertEqual(restored.state, {'key': 'value', 'count': 42})
self.assertEqual(restored.step, 5)

def test_file_checkpointer_save_load(self):
"""Save then load returns equal checkpoint."""
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'checkpoint.json')
checkpointer = action_lib.FileCheckpointer(path=path)
checkpoint = action_lib.ActionCheckpoint(
action_type='test.TestAction',
state={'items': [1, 2, 3]},
step=10,
)
checkpointer.save(checkpoint)
loaded = checkpointer.load()
self.assertIsNotNone(loaded)
self.assertEqual(loaded.action_type, checkpoint.action_type)
self.assertEqual(loaded.state, checkpoint.state)
self.assertEqual(loaded.step, checkpoint.step)

def test_file_checkpointer_missing_file(self):
"""load() returns None for missing path."""
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'nonexistent.json')
checkpointer = action_lib.FileCheckpointer(path=path)
self.assertIsNone(checkpointer.load())

def test_file_checkpointer_creates_dirs(self):
"""Nested path dirs created via pg.io.mkdirs."""
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'nested', 'deep', 'checkpoint.json')
checkpointer = action_lib.FileCheckpointer(path=path)
checkpoint = action_lib.ActionCheckpoint(
action_type='test.TestAction',
state={},
step=0,
)
checkpointer.save(checkpoint)
self.assertTrue(os.path.exists(path))

def test_action_default_on_checkpoint(self):
"""Default on_checkpoint returns ActionCheckpoint with empty state."""
action = Bar()
checkpoint = action.on_checkpoint()
self.assertIsInstance(checkpoint, action_lib.ActionCheckpoint)
self.assertIn('Bar', checkpoint.action_type)
self.assertEqual(checkpoint.state, {})
self.assertEqual(checkpoint.step, 0)

def test_action_default_on_restore_validates_type(self):
"""Mismatched action_type raises ValueError."""
action = Bar()
wrong_checkpoint = action_lib.ActionCheckpoint(
action_type='wrong.module.WrongAction',
state={},
step=0,
)
with self.assertRaises(ValueError) as ctx:
action.on_restore(wrong_checkpoint)
self.assertIn('Checkpoint type mismatch', str(ctx.exception))

def test_custom_action_checkpoint_hooks(self):
"""Subclass state persists across checkpoint/restore."""

class StatefulAction(action_lib.Action):

def _on_bound(self):
super()._on_bound()
self._counter = 0
self._items = []

def on_checkpoint(self):
return action_lib.ActionCheckpoint(
action_type=self.checkpoint_type_id,
state={'counter': self._counter, 'items': self._items},
step=self._counter,
)

def on_restore(self, checkpoint):
super().on_restore(checkpoint)
self._counter = checkpoint.state.get('counter', 0)
self._items = checkpoint.state.get('items', [])

def call(self, session, **kwargs):
self._counter += 1
self._items.append(f'item_{self._counter}')
return self._counter

action1 = StatefulAction()
action1._counter = 5
action1._items = ['a', 'b', 'c']
checkpoint = action1.on_checkpoint()

action2 = StatefulAction()
self.assertEqual(action2._counter, 0)
self.assertEqual(action2._items, [])

action2.on_restore(checkpoint)
self.assertEqual(action2._counter, 5)
self.assertEqual(action2._items, ['a', 'b', 'c'])

def test_checkpoint_with_pg_object_state(self):
"""State containing pg.Object instances serializes correctly."""

class TestData(pg.Object):
name: str
value: int

checkpoint = action_lib.ActionCheckpoint(
action_type='test.TestAction',
state={'data': TestData(name='test', value=42)},
step=1,
)
json_str = checkpoint.to_json_str()
restored = pg.from_json_str(json_str)
self.assertIsInstance(restored.state['data'], TestData)
self.assertEqual(restored.state['data'].name, 'test')
self.assertEqual(restored.state['data'].value, 42)

def test_restore_from_checkpoint_convenience(self):
"""Creates action and restores state in one call."""

class TestAction(action_lib.Action):
x: int

def _on_bound(self):
super()._on_bound()
self._state_value = 0

def on_checkpoint(self):
return action_lib.ActionCheckpoint(
action_type=self.checkpoint_type_id,
state={'state_value': self._state_value},
step=self._state_value,
)

def on_restore(self, checkpoint):
super().on_restore(checkpoint)
self._state_value = checkpoint.state.get('state_value', 0)

def call(self, session, **kwargs):
return self.x

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'checkpoint.json')
checkpointer = action_lib.FileCheckpointer(path=path)

# Save a checkpoint with some state
original = TestAction(x=10)
original._state_value = 99
checkpointer.save(original.on_checkpoint())

# Use convenience method to restore
restored, checkpoint = TestAction.restore_from_checkpoint(
checkpointer, x=10
)
self.assertEqual(restored._state_value, 99)
self.assertIsNotNone(checkpoint)
self.assertEqual(checkpoint.step, 99)

# Test with no existing checkpoint
checkpointer2 = action_lib.FileCheckpointer(
path=os.path.join(tmpdir, 'new.json')
)
fresh, checkpoint = TestAction.restore_from_checkpoint(
checkpointer2, x=20
)
self.assertEqual(fresh._state_value, 0) # Default from _on_bound
self.assertIsNone(checkpoint)

def test_agent_checkpoint_round_trip_with_pg_objects(self):
"""Mimics LangfunAgent pattern: list of pg.Objects in state round-trips."""

class Step(pg.Object):
step: int
thoughts: str
results: Any

class AgentLikeAction(action_lib.Action):
_CKPT_KEY_STEPS = 'steps'
_CKPT_KEY_START_INDEX = 'start_index'
_CKPT_KEY_FILE_DIR = 'file_dir'

def _on_bound(self):
super()._on_bound()
self._steps = []
self._start_index = 0
self._restored = False

def on_checkpoint(self):
return action_lib.ActionCheckpoint(
action_type=self.checkpoint_type_id,
state={
self._CKPT_KEY_STEPS: self._steps,
self._CKPT_KEY_START_INDEX: self._start_index,
self._CKPT_KEY_FILE_DIR: '/tmp/files',
},
step=len(self._steps),
)

def on_restore(self, checkpoint):
super().on_restore(checkpoint)
self._steps = checkpoint.state.get(self._CKPT_KEY_STEPS, [])
self._start_index = checkpoint.state.get(self._CKPT_KEY_START_INDEX, 0)
self._restored = True

def call(self, session, **kwargs):
return None

original = AgentLikeAction()
original._steps = [
Step(step=0, thoughts='thinking', results='found it'),
Step(step=1, thoughts='analyzing', results={'key': 'value'}),
]
original._start_index = 1

with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'agent_checkpoint.json')
checkpointer = action_lib.FileCheckpointer(path=path)

checkpointer.save(original.on_checkpoint())
checkpoint = checkpointer.load()
self.assertIsNotNone(checkpoint)
self.assertEqual(checkpoint.step, 2)
self.assertEqual(
checkpoint.state[AgentLikeAction._CKPT_KEY_FILE_DIR], '/tmp/files'
)

restored = AgentLikeAction()
self.assertFalse(restored._restored)
restored.on_restore(checkpoint)
self.assertTrue(restored._restored)
self.assertEqual(len(restored._steps), 2)
self.assertIsInstance(restored._steps[0], Step)
self.assertEqual(restored._steps[0].step, 0)
self.assertEqual(restored._steps[0].thoughts, 'thinking')
self.assertEqual(restored._steps[1].results, {'key': 'value'})
self.assertEqual(restored._start_index, 1)


if __name__ == '__main__':
unittest.main()
Loading