diff --git a/langfun/core/agentic/__init__.py b/langfun/core/agentic/__init__.py index d4a05487..1a590140 100644 --- a/langfun/core/agentic/__init__.py +++ b/langfun/core/agentic/__init__.py @@ -18,6 +18,7 @@ # pylint: disable=g-import-not-at-top from langfun.core.agentic.action import Action +from langfun.core.agentic.action import ActionSampling from langfun.core.agentic.action import ActionInvocation from langfun.core.agentic.action import Session diff --git a/langfun/core/agentic/action.py b/langfun/core/agentic/action.py index 818f5b32..02e1add2 100644 --- a/langfun/core/agentic/action.py +++ b/langfun/core/agentic/action.py @@ -19,7 +19,7 @@ import time import typing -from typing import Annotated, Any, ContextManager, Iterable, Iterator, Optional, Type, Union +from typing import Annotated, Any, Callable, ContextManager, Iterable, Iterator, Optional, Type, Union import langfun.core as lf from langfun.core import structured as lf_structured import pyglove as pg @@ -84,6 +84,55 @@ def call(self, session: 'Session', **kwargs) -> Any: The result of the action. """ + def sample( + self, + session: Optional['Session'] = None, + n: int = 1, + *, + show_progress: bool = True, + reduce_fn: Callable[[list[Any]], Any] | None = None, + **kwargs + ) -> Any: + """Samples the action multiple times.""" + return ActionSampling( + action=self, + n=n, + reduce_fn=reduce_fn, + )( + session=session, + show_progress=show_progress, + **kwargs + ) + + +class ActionSampling(Action): + """Sampling an action multiple times.""" + + action: Annotated[ + Action, + 'Action to sample.' + ] + + n: Annotated[ + int, + 'Number of samples to generate.' + ] = 1 + + reduce_fn: Annotated[ + Callable[[list[Any]], Any] | None, + 'Function to reduce the samples to a single result.' + ] = None + + def call(self, session: 'Session', **kwargs) -> Any: + """Calls the action.""" + results = [] + for _ in range(self.n): + result = self.action.clone(deep=True)(session=session, **kwargs) + results.append(result) + if self.reduce_fn is not None: + return self.reduce_fn(results) + return results + # Type definition for traced item during execution. TracedItem = Union[ diff --git a/langfun/core/agentic/action_test.py b/langfun/core/agentic/action_test.py index 97e3eede..82d14eb8 100644 --- a/langfun/core/agentic/action_test.py +++ b/langfun/core/agentic/action_test.py @@ -119,6 +119,29 @@ def make_additional_query(self, lm): json_str = session.to_json_str(save_ref_value=True) self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session) + def test_sample(self): + class Foo(action_lib.Action): + + def call(self, session, x=1): + return x + + session = action_lib.Session() + self.assertEqual(Foo().sample(session, n=2, x=1, reduce_fn=sum), 2) + self.assertIsNotNone(session.root) + self.assertEqual(len(session.root.execution), 1) + topmost_invocation = session.root.execution.items[0] + self.assertEqual(topmost_invocation.result, 2) + self.assertEqual( + topmost_invocation.action, + action_lib.ActionSampling(Foo(), n=2, reduce_fn=sum) + ) + self.assertEqual(len(topmost_invocation.execution.items), 2) + self.assertEqual(topmost_invocation.execution.items[0].action, Foo()) + self.assertEqual(topmost_invocation.execution.items[0].result, 1) + + # No reduce_fn. + self.assertEqual(Foo().sample(session, n=2, x=1), [1, 1]) + def test_log(self): session = action_lib.Session() session.debug('hi', x=1, y=2)