Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions langfun/core/agentic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=g-import-not-at-top

from langfun.core.agentic.action import Action
from langfun.core.agentic.action import ActionSampling
from langfun.core.agentic.action import ActionInvocation
from langfun.core.agentic.action import Session

Expand Down
51 changes: 50 additions & 1 deletion langfun/core/agentic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time

import typing
from typing import Annotated, Any, ContextManager, Iterable, Iterator, Optional, Type, Union
from typing import Annotated, Any, Callable, ContextManager, Iterable, Iterator, Optional, Type, Union
import langfun.core as lf
from langfun.core import structured as lf_structured
import pyglove as pg
Expand Down Expand Up @@ -84,6 +84,55 @@ def call(self, session: 'Session', **kwargs) -> Any:
The result of the action.
"""

def sample(
self,
session: Optional['Session'] = None,
n: int = 1,
*,
show_progress: bool = True,
reduce_fn: Callable[[list[Any]], Any] | None = None,
**kwargs
) -> Any:
"""Samples the action multiple times."""
return ActionSampling(
action=self,
n=n,
reduce_fn=reduce_fn,
)(
session=session,
show_progress=show_progress,
**kwargs
)


class ActionSampling(Action):
"""Sampling an action multiple times."""

action: Annotated[
Action,
'Action to sample.'
]

n: Annotated[
int,
'Number of samples to generate.'
] = 1

reduce_fn: Annotated[
Callable[[list[Any]], Any] | None,
'Function to reduce the samples to a single result.'
] = None

def call(self, session: 'Session', **kwargs) -> Any:
"""Calls the action."""
results = []
for _ in range(self.n):
result = self.action.clone(deep=True)(session=session, **kwargs)
results.append(result)
if self.reduce_fn is not None:
return self.reduce_fn(results)
return results


# Type definition for traced item during execution.
TracedItem = Union[
Expand Down
23 changes: 23 additions & 0 deletions langfun/core/agentic/action_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ def make_additional_query(self, lm):
json_str = session.to_json_str(save_ref_value=True)
self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)

def test_sample(self):
class Foo(action_lib.Action):

def call(self, session, x=1):
return x

session = action_lib.Session()
self.assertEqual(Foo().sample(session, n=2, x=1, reduce_fn=sum), 2)
self.assertIsNotNone(session.root)
self.assertEqual(len(session.root.execution), 1)
topmost_invocation = session.root.execution.items[0]
self.assertEqual(topmost_invocation.result, 2)
self.assertEqual(
topmost_invocation.action,
action_lib.ActionSampling(Foo(), n=2, reduce_fn=sum)
)
self.assertEqual(len(topmost_invocation.execution.items), 2)
self.assertEqual(topmost_invocation.execution.items[0].action, Foo())
self.assertEqual(topmost_invocation.execution.items[0].result, 1)

# No reduce_fn.
self.assertEqual(Foo().sample(session, n=2, x=1), [1, 1])

def test_log(self):
session = action_lib.Session()
session.debug('hi', x=1, y=2)
Expand Down
Loading