Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion langfun/core/structured/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,13 +489,25 @@ def parse_result(self, lm_output: lf.Message) -> Any:
response_text = '\n'.join(
tc.text for tc in lm_output.metadata['tool_calls']
)

# Extract thought from metadata if present (for thought-aware autofix)
thought = None
if 'thought' in lm_output.metadata:
thought_message = lm_output.metadata['thought']
thought = (
thought_message.text
if hasattr(thought_message, 'text')
else str(thought_message)
)

return schema.parse(
response_text,
protocol=self.protocol,
additional_context=self.globals(),
autofix=self.autofix,
autofix_lm=self.autofix_lm or self.lm,
permission=self.permission,
thought=thought,
)

def postprocess_response(self, response: lf.Message) -> lf.Message:
Expand All @@ -513,4 +525,3 @@ def postprocess_result(self, result: Any) -> Any:
def globals(self) -> dict[str, Any]:
"""Gets additional symbol definitions besides schema as globals."""
return {'ModalityRef': lf.modality.ModalityRef}

55 changes: 55 additions & 0 deletions langfun/core/structured/mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,5 +226,60 @@ class Answer:
)


class ThoughtExtractionTest(unittest.TestCase):
"""Test that thought is correctly extracted from lm_output.metadata."""

def test_thought_extraction_from_aimessage(self):
"""Test that thought is extracted when it's an AIMessage."""
# Create a mock LM output with thought as AIMessage in metadata
lm_output = lf.AIMessage('## answer\n2')
lm_output.metadata['thought'] = lf.AIMessage(
'Let me think... 1 + 1 equals 2'
)

# Simulate the extraction logic from parse_result
thought = None
if 'thought' in lm_output.metadata:
thought_message = lm_output.metadata['thought']
thought = (
thought_message.text
if hasattr(thought_message, 'text')
else str(thought_message)
)

self.assertEqual(thought, 'Let me think... 1 + 1 equals 2')

def test_thought_extraction_from_string(self):
"""Test that thought is extracted when it's a plain string."""
lm_output = lf.AIMessage('## answer\n2')
lm_output.metadata['thought'] = 'Direct string thought'

thought = None
if 'thought' in lm_output.metadata:
thought_message = lm_output.metadata['thought']
thought = (
thought_message.text
if hasattr(thought_message, 'text')
else str(thought_message)
)

self.assertEqual(thought, 'Direct string thought')

def test_no_thought_in_metadata(self):
"""Test that thought is None when not present in metadata."""
lm_output = lf.AIMessage('## answer\n2')

thought = None
if 'thought' in lm_output.metadata:
thought_message = lm_output.metadata['thought']
thought = (
thought_message.text
if hasattr(thought_message, 'text')
else str(thought_message)
)

self.assertIsNone(thought)


if __name__ == '__main__':
unittest.main()
48 changes: 48 additions & 0 deletions langfun/core/structured/querying.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class _MyLfQuery(LFQuery):
_DEFAULT_PROTOCOL_VERSIONS: ClassVar[dict[str, str]] = {
'python': '2.0',
'json': '1.0',
'markdown': '1.0',
}

def __init_subclass__(cls) -> Any:
Expand Down Expand Up @@ -273,6 +274,53 @@ class Answer:
)


class _LfQueryMarkdownV1(LfQuery):
"""Query a structured value using Markdown as the protocol."""

preamble = """
Please respond to the last {{ input_title }} with {{ output_title }} in markdown format according to {{ schema_title }}.

{{ input_title }}:
Write a solution for 1 + 1

{{ schema_title }}:
Answer with fields: reasoning (str), result (int)

{{ output_title }}:
## reasoning
Adding 1 and 1 gives us 2

## result
2
"""
version = '1.0'
protocol = 'markdown'
input_title = 'REQUEST'
schema_title = 'OUTPUT SCHEMA'
output_title = 'OUTPUT'
mapping_template = lf.Template("""
{%- if example.context -%}
{{ context_title}}:
{{ example.context | indent(2, True)}}

{% endif -%}

{{ input_title }}:
{{ example.input_repr(protocol, compact=False) | indent(2, True) }}

{% if example.schema -%}
{{ schema_title }}:
{{ example.schema_repr(protocol) | indent(2, True) }}

{% endif -%}

{{ output_title }}:
{%- if example.has_output %}
{{ example.output_repr(protocol, compact=False) | indent(2, True) }}
{% endif -%}
""")


def query(
prompt: Union[str, lf.Template, lf.Message, Any],
schema: schema_lib.SchemaType | None = None,
Expand Down
75 changes: 75 additions & 0 deletions langfun/core/structured/querying_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,5 +1747,80 @@ def make_query(prompt):
self.assertEqual(len(queries), 2)


class LfQueryMarkdownV1Test(unittest.TestCase):

def test_from_protocol(self):
"""Test that markdown protocol is registered correctly."""
self.assertIs(
querying.LfQuery.from_protocol('markdown'), querying._LfQueryMarkdownV1
)
self.assertIs(
querying.LfQuery.from_protocol('markdown:1.0'),
querying._LfQueryMarkdownV1,
)

def test_render_no_examples(self):
"""Test markdown protocol rendering without examples."""

class Answer(pg.Object):
reasoning: str
result: int

l = querying.LfQuery.from_protocol('markdown:1.0')(
input=lf.AIMessage('Solve 1 + 1'), schema=Answer
)
rendered = l.render().text
# Check that the markdown format is used
self.assertIn('REQUEST:', rendered)
self.assertIn('OUTPUT SCHEMA:', rendered)
self.assertIn('OUTPUT:', rendered)
# Check preamble example
self.assertIn('## reasoning', rendered)
self.assertIn('## result', rendered)

def test_render_with_examples(self):
"""Test markdown protocol rendering with examples."""

class Answer(pg.Object):
reasoning: str
result: int

l = querying.LfQuery.from_protocol('markdown:1.0')(
input=lf.AIMessage('Solve 2 + 2'),
schema=Answer,
examples=[
mapping.MappingExample(
input='Solve 1 + 1',
output=Answer(reasoning='Adding 1 and 1', result=2),
),
],
)
rendered = l.render().text
# Check that examples are included
self.assertIn('Solve 1 + 1', rendered)
self.assertIn('Adding 1 and 1', rendered)
self.assertIn('Solve 2 + 2', rendered)

def test_query_with_markdown_protocol(self):
"""Test end-to-end query with markdown protocol."""

class Answer(pg.Object):
reasoning: str
result: int

lm = fake.StaticResponse("""
## reasoning
Adding 1 and 1 gives us 2

## result
2
""")

result = querying.query('Solve 1 + 1', Answer, lm=lm, protocol='markdown')

self.assertEqual(result.reasoning, 'Adding 1 and 1 gives us 2')
self.assertEqual(result.result, 2)


if __name__ == '__main__':
unittest.main()
Loading