From 53657d71ae89d901f8dbd2e5afdbefdc1764e215 Mon Sep 17 00:00:00 2001 From: Langfun Authors Date: Sun, 23 Feb 2025 15:28:48 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 730215585 --- langfun/core/agentic/action.py | 78 ++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/langfun/core/agentic/action.py b/langfun/core/agentic/action.py index 818f5b32..444200ec 100644 --- a/langfun/core/agentic/action.py +++ b/langfun/core/agentic/action.py @@ -795,6 +795,84 @@ def query( **kwargs ) + def query_prompt( + self, + prompt: Union[str, lf.Template, Any], + schema: Union[ + lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None + ] = None, + **kwargs, + ) -> Any: + """Calls `lf.query_prompt` and associates it with the current invocation. + + The following code are equivalent: + + Code 1: + ``` + session.query_prompt(...) + ``` + + Code 2: + ``` + with session.track_queries() as queries: + output = lf.query_prompt(...) + ``` + The former is preferred when `lf.query_prompt` is directly called by the + action. + If `lf.query_prompt` is called by a function that does not have access to + the + session, the latter should be used. + + Args: + prompt: The prompt to query. + schema: The schema to use for the query. + **kwargs: Additional keyword arguments to pass to `lf.query_prompt`. + + Returns: + The result of the query. + """ + with self.track_queries(): + return lf_structured.query_prompt(prompt, schema=schema, **kwargs) + + def query_output( + self, + response: Union[str, lf.Template, Any], + schema: Union[ + lf_structured.Schema, Type[Any], list[Type[Any]], dict[str, Any], None + ] = None, + **kwargs, + ) -> Any: + """Calls `lf.query_output` and associates it with the current invocation. + + The following code are equivalent: + + Code 1: + ``` + session.query_output(...) + ``` + + Code 2: + ``` + with session.track_queries() as queries: + output = lf.query_output(...) + ``` + The former is preferred when `lf.query_output` is directly called by the + action. + If `lf.query_output` is called by a function that does not have access to + the + session, the latter should be used. + + Args: + response: The response to query. + schema: The schema to use for the query. + **kwargs: Additional keyword arguments to pass to `lf.query_prompt`. + + Returns: + The result of the query. + """ + with self.track_queries(): + return lf_structured.query_output(response, schema=schema, **kwargs) + def _log(self, level: lf.logging.LogLevel, message: str, **kwargs): self._current_action.current_phase.append( lf.logging.LogEntry(