diff --git a/langfun/core/structured/mapping.py b/langfun/core/structured/mapping.py index a21c0377..cdd96319 100644 --- a/langfun/core/structured/mapping.py +++ b/langfun/core/structured/mapping.py @@ -489,6 +489,17 @@ def parse_result(self, lm_output: lf.Message) -> Any: response_text = '\n'.join( tc.text for tc in lm_output.metadata['tool_calls'] ) + + # Extract thought from metadata if present (for thought-aware autofix) + thought = None + if 'thought' in lm_output.metadata: + thought_message = lm_output.metadata['thought'] + thought = ( + thought_message.text + if hasattr(thought_message, 'text') + else str(thought_message) + ) + return schema.parse( response_text, protocol=self.protocol, @@ -496,6 +507,7 @@ def parse_result(self, lm_output: lf.Message) -> Any: autofix=self.autofix, autofix_lm=self.autofix_lm or self.lm, permission=self.permission, + thought=thought, ) def postprocess_response(self, response: lf.Message) -> lf.Message: @@ -513,4 +525,3 @@ def postprocess_result(self, result: Any) -> Any: def globals(self) -> dict[str, Any]: """Gets additional symbol definitions besides schema as globals.""" return {'ModalityRef': lf.modality.ModalityRef} - diff --git a/langfun/core/structured/mapping_test.py b/langfun/core/structured/mapping_test.py index 192c6607..eff67ed6 100644 --- a/langfun/core/structured/mapping_test.py +++ b/langfun/core/structured/mapping_test.py @@ -226,5 +226,60 @@ class Answer: ) +class ThoughtExtractionTest(unittest.TestCase): + """Test that thought is correctly extracted from lm_output.metadata.""" + + def test_thought_extraction_from_aimessage(self): + """Test that thought is extracted when it's an AIMessage.""" + # Create a mock LM output with thought as AIMessage in metadata + lm_output = lf.AIMessage('## answer\n2') + lm_output.metadata['thought'] = lf.AIMessage( + 'Let me think... 
class _LfQueryMarkdownV1(LfQuery):
  """Query a structured value using Markdown as the protocol.

  The prompt is made of REQUEST / OUTPUT SCHEMA / OUTPUT sections, and the
  model is expected to answer with one `## <field>` section per schema field.
  """

  # One-shot demonstration shown ahead of the real request. The schema part
  # uses the same layout as the markdown schema representation (a `## <field>`
  # header followed by a `<type>` placeholder), so the demonstration and the
  # actual `{{ schema_title }}` section agree on one format.
  preamble = """
    Please respond to the last {{ input_title }} with {{ output_title }} in markdown format according to {{ schema_title }}.

    {{ input_title }}:
      Write a solution for 1 + 1

    {{ schema_title }}:
      Provide your response in the following markdown format:

      ## reasoning
      <str>

      ## result
      <int>

    {{ output_title }}:
      ## reasoning
      Adding 1 and 1 gives us 2

      ## result
      2
    """
  version = '1.0'
  protocol = 'markdown'

  # Section titles substituted into `preamble` and `mapping_template`.
  input_title = 'REQUEST'
  schema_title = 'OUTPUT SCHEMA'
  output_title = 'OUTPUT'

  # Layout of a single example: optional context, the input, the optional
  # schema, and the output (emitted only when the example has one). Each value
  # is indented by two spaces beneath its title.
  mapping_template = lf.Template("""
      {%- if example.context -%}
      {{ context_title}}:
      {{ example.context | indent(2, True)}}

      {% endif -%}

      {{ input_title }}:
      {{ example.input_repr(protocol, compact=False) | indent(2, True) }}

      {% if example.schema -%}
      {{ schema_title }}:
      {{ example.schema_repr(protocol) | indent(2, True) }}

      {% endif -%}

      {{ output_title }}:
      {%- if example.has_output %}
      {{ example.output_repr(protocol, compact=False) | indent(2, True) }}
      {% endif -%}
      """)
class SchemaMarkdownRepr(SchemaRepr):
  """Markdown-representation for a schema.

  Every symbolic field of a `pg.Object` schema becomes a markdown section.
  Top-level fields use a `##` header and fields of nested objects use one more
  `#` per nesting level. Below the header a placeholder tells the model what to
  write: an empty fenced code block for string fields whose name contains
  "code", or `<type>` for the remaining scalar fields.
  """

  # (substring of the lower-cased field name, fence language), checked in
  # order. Code fields matching none of these default to Python.
  _FENCE_LANGUAGES = (
      ('cpp', 'cpp'),
      ('bash', 'bash'),
      ('terminal', 'bash'),
  )

  def repr(self, schema: Schema, **kwargs) -> str:
    """Returns the markdown instructions describing `schema`.

    Args:
      schema: The schema to describe. Only schemas whose spec is a
        `pg.typing.Object` produce sections; for any other spec just the
        introductory sentence is returned.
      **kwargs: Ignored.

    Returns:
      The instruction text, one section per field.
    """
    del kwargs
    out = io.StringIO()
    out.write('Provide your response in the following markdown format:\n\n')
    if isinstance(schema.spec, pg.typing.Object):
      self._write_fields(out, schema.spec.cls, level=2)
    return out.getvalue()

  def _write_fields(self, out: io.StringIO, cls: Any, level: int) -> None:
    """Writes one section per string-keyed field declared on `cls`."""
    for key, field in cls.__schema__.items():
      if isinstance(key, pg.typing.ConstStrKey):
        self._write_field(out, str(key), field.value, level)

  def _write_field(
      self,
      out: io.StringIO,
      field_name: str,
      field_spec: pg.typing.ValueSpec,
      level: int,
  ) -> None:
    """Writes the header and the placeholder for a single field."""
    header = '#' * level
    out.write(f'{header} {field_name}\n')

    if isinstance(field_spec, pg.typing.Object):
      # Nested object: its fields become sub-sections one level deeper.
      self._write_fields(out, field_spec.cls, level + 1)
    elif isinstance(field_spec, pg.typing.List):
      # Lists get no type placeholder, only the header and blank lines.
      out.write('\n\n')
    elif isinstance(field_spec, pg.typing.Str) and self._is_code_field(
        field_name
    ):
      out.write(f'```{self._fence_language(field_name)}\n\n```\n\n')
    else:
      out.write(f'<{field_spec.value_type.__name__}>\n\n')

  @staticmethod
  def _is_code_field(field_name: str) -> bool:
    """Returns True if the field name suggests that it holds source code."""
    return 'code' in field_name.lower()

  @classmethod
  def _fence_language(cls, field_name: str) -> str:
    """Picks the fence language hinted at by a code field's name.

    The match is case-insensitive so that it agrees with `_is_code_field`;
    e.g. both `cpp_code` and `CppCode` map to `cpp`.
    """
    lowered = field_name.lower()
    for marker, language in cls._FENCE_LANGUAGES:
      if marker in lowered:
        return language
    return 'python'
class ValueMarkdownRepr(ValueRepr):
  """Markdown-representation for value.

  A `pg.Object` is written as one `## <field>` section per symbolic field and
  is parsed back by locating those sections in the text. The parser handles
  flat schemas: `str`, `int`, `float` and `bool` fields are converted, any
  other field receives the stripped section text as is.
  """

  # Opening fence with an optional info string (e.g. `python`, `c++`), any
  # blank lines right after it, the lazily matched body, and the closing fence.
  _CODE_BLOCK_RE = re.compile(
      r'```[^\n`]*\n(?:[ \t]*\n)*(.*?)\n?[ \t]*```', re.DOTALL
  )

  def repr(self, value: Any, schema: Schema | None = None, **kwargs) -> str:
    """Converts `value` to markdown.

    Args:
      value: The value to render. Only `pg.Object` instances produce output.
      schema: Ignored.
      **kwargs: Ignored.

    Returns:
      One `## <field>` section per field. Fields whose value is None are
      omitted so the text parses back to None rather than to the string
      'None'. Multi-line strings in fields whose name contains "code" are
      wrapped in a fenced code block.
    """
    del schema, kwargs
    out = io.StringIO()
    if isinstance(value, pg.Object):
      for key, val in value.sym_items():
        if val is None:
          continue
        out.write(f'## {key}\n')
        is_code = (
            isinstance(val, str) and '\n' in val and 'code' in key.lower()
        )
        if is_code:
          out.write(f'```{self._sniff_language(val)}\n{val}\n```\n\n')
        else:
          out.write(f'{val}\n\n')
    return out.getvalue()

  @staticmethod
  def _sniff_language(code: str) -> str:
    """Guesses the fence language from tell-tale substrings of `code`."""
    if '#include' in code:
      return 'cpp'
    if '#!/bin/bash' in code or 'cat >' in code:
      return 'bash'
    return 'python'

  def parse(
      self,
      text: str,
      schema: Schema | None = None,
      autofix: int = 0,
      autofix_lm: lf.LanguageModel = lf.contextual(),
      **kwargs,
  ) -> Any:
    """Parses markdown text into an instance of the schema's class.

    Args:
      text: The markdown text to parse.
      schema: Schema whose spec must be a `pg.typing.Object`.
      autofix: Maximum number of LM-based repair rounds after a failed parse.
        Zero or a negative number disables repairing.
      autofix_lm: Language model used for the repair rounds.
      **kwargs: Ignored.

    Returns:
      The parsed object.

    Raises:
      ValueError: If `schema` is missing or is not an object schema, or if a
        required field is still absent after all repair rounds.
      Exception: Any other error of the final parse attempt is re-raised.
    """
    del kwargs
    if schema is None or not isinstance(schema.spec, pg.typing.Object):
      raise ValueError('Markdown protocol requires a pg.Object schema')

    fixes_left = max(autofix, 0)
    while True:
      try:
        return self._parse_markdown(text, schema)
      # Parsing may fail with ValueError (missing section, bad number) or with
      # whatever the object constructor raises, so everything is caught here.
      except Exception as e:  # pylint: disable=broad-exception-caught
        if fixes_left <= 0:
          raise
        fixes_left -= 1
        text = self._fix_markdown(text, schema, e, autofix_lm)

  def _parse_markdown(self, text: str, schema: Schema) -> Any:
    """Builds the schema's object from the sections found in `text`."""
    result = {}
    for key, field in schema.spec.cls.__schema__.items():
      if not isinstance(key, pg.typing.ConstStrKey):
        continue

      field_name = str(key)
      section = self._extract_section(text, field_name)
      if section is None:
        # A missing section is acceptable only for fields that accept None.
        if not field.value.is_noneable:
          raise ValueError(
              f'Required field "{field_name}" not found in markdown'
          )
        result[field_name] = None
        continue
      result[field_name] = self._convert_section(
          field_name, field.value, section
      )
    return schema.spec.cls(**result)

  def _convert_section(
      self, field_name: str, spec: pg.typing.ValueSpec, section: str
  ) -> Any:
    """Converts the text of one section according to the field's spec."""
    content = section.strip()
    if isinstance(spec, pg.typing.Str):
      return self._unwrap_text(field_name, content)
    if isinstance(spec, (pg.typing.Int, pg.typing.Float)):
      return spec.value_type(content)
    if isinstance(spec, pg.typing.Bool):
      # Anything other than these spellings is read as False.
      return content.lower() in ('true', 'yes', '1')
    return content

  def _unwrap_text(self, field_name: str, content: str) -> str:
    """Removes the code fence around a string field's content if appropriate.

    The fence is dropped for code fields (name contains "code") and for
    sections that consist of exactly one fenced block. Prose that merely
    contains a snippet is returned untouched, so the text around the snippet
    is not lost.
    """
    code = self._extract_code_block(content)
    if code is None:
      return content
    only_a_fenced_block = (
        content.startswith('```')
        and content.endswith('```')
        and content.count('```') == 2
    )
    if 'code' in field_name.lower() or only_a_fenced_block:
      # `code` may be the empty string (an empty block); that is a valid value.
      return code
    return content

  def _fix_markdown(
      self,
      text: str,
      schema: Schema,
      error: Exception,
      lm: lf.LanguageModel,
  ) -> str:
    """Asks `lm` to rewrite malformed markdown so that it matches `schema`.

    Args:
      text: The markdown that failed to parse.
      schema: The expected schema.
      error: The error raised by the failed parse.
      lm: Language model that produces the correction.

    Returns:
      The corrected markdown as returned by the model.
    """
    # Delay import at runtime to avoid circular dependency.
    # pylint: disable=g-import-not-at-top
    # pytype: disable=import-error
    from langfun.core.structured import querying
    # pytype: enable=import-error
    # pylint: enable=g-import-not-at-top

    # The prompt string is rendered as a template, so the untrusted pieces are
    # passed as template variables instead of being pasted into the template
    # source, where `{{` or `{%` (common in C++ and shell code) would be
    # interpreted as template syntax.
    correction_prompt = (
        'The following markdown output has an error:\n'
        '\n'
        '```markdown\n'
        '{{ malformed_markdown }}\n'
        '```\n'
        '\n'
        'Error: {{ parse_error }}\n'
        '\n'
        'Expected schema:\n'
        '{{ expected_schema }}\n'
        '\n'
        '\n'
        'Please provide the corrected markdown output that matches the'
        ' expected schema.'
    )

    # Query LLM for correction (disable autofix to avoid recursion)
    return querying.query(
        correction_prompt,
        str,
        lm=lm,
        autofix=0,
        malformed_markdown=text,
        parse_error=str(error),
        expected_schema=schema.schema_str('markdown'),
    )

  def _extract_section(self, text: str, section_name: str) -> str | None:
    """Returns the stripped content of the section titled `section_name`.

    A section starts at a header line (`##` or deeper, optionally indented)
    whose title equals `section_name` ignoring case, and runs up to the next
    such header line or the end of the text. Lines inside fenced code blocks
    never count as headers, unless the fences in `text` are unbalanced, in
    which case fences are ignored altogether.

    Args:
      text: The markdown text to search.
      section_name: Title of the wanted section.

    Returns:
      The section content, or None if there is no such section.
    """
    lines = text.splitlines()
    fence_count = sum(1 for line in lines if line.lstrip().startswith('```'))
    # With an unclosed fence everything after it would look like code, so
    # fall back to treating every `##` line as a header.
    honor_fences = fence_count % 2 == 0
    wanted = section_name.strip().lower()

    body = None
    in_fence = False
    for line in lines:
      stripped = line.strip()
      if honor_fences and stripped.startswith('```'):
        in_fence = not in_fence
      elif not in_fence and stripped.startswith('##'):
        if body is not None:
          break  # The next header closes the section being collected.
        heading = stripped.lstrip('#')
        if heading[:1].isspace() and heading.strip().lower() == wanted:
          body = []
        continue
      if body is not None:
        body.append(line)

    if body is None:
      return None
    return '\n'.join(body).strip()

  def _extract_code_block(self, section_content: str) -> str | None:
    """Returns the body of the first fenced code block, or None without one.

    Blank lines directly after the opening fence are dropped. An empty block
    yields the empty string.
    """
    match = self._CODE_BLOCK_RE.search(section_content)
    if match is None:
      return None
    return match.group(1)
class SchemaMarkdownReprTest(unittest.TestCase):
  """Covers rendering of object schemas as markdown instructions."""

  def _render(self, cls) -> str:
    """Returns the markdown schema representation of `cls`."""
    return schema_lib.SchemaMarkdownRepr().repr(schema_lib.Schema(cls))

  def test_repr_simple(self):
    """Plain string fields become sections; `code` gets a fence."""

    class Solution(pg.Object):
      reasoning: str
      code: str

    rendered = self._render(Solution)
    for expected in ('## reasoning', '## code', '```'):
      self.assertIn(expected, rendered)

  def test_repr_with_code_fields(self):
    """The fence language is derived from the name of a *_code field."""

    class SolutionWithTests(pg.Object):
      cpp_code: str
      terminal_code: str
      bash_code: str

    rendered = self._render(SolutionWithTests)
    self.assertIn('```cpp', rendered)
    self.assertIn('```bash', rendered)

  def test_repr_nested_object(self):
    """Fields of a nested object are rendered one header level deeper."""

    class Inner(pg.Object):
      value: int

    class Outer(pg.Object):
      inner: Inner
      name: str

    rendered = self._render(Outer)
    self.assertIn('## inner', rendered)
    self.assertIn('### value', rendered)


class ValueMarkdownReprTest(unittest.TestCase):
  """Covers rendering values to markdown and parsing them back."""

  # Response that provides `reasoning` but lacks a `code` section.
  _REASONING_ONLY = '\n## reasoning\nThis is my reasoning\n'

  def test_repr(self):
    """A multi-line C++ string is emitted inside a cpp fence."""

    class Solution(pg.Object):
      reasoning: str
      cpp_code: str

    rendered = schema_lib.ValueMarkdownRepr().repr(
        Solution(
            reasoning='Use dynamic programming',
            cpp_code='#include \nint main() { return 0; }',
        )
    )

    expected_parts = (
        '## reasoning',
        'Use dynamic programming',
        '## cpp_code',
        '```cpp',
        '#include ',
    )
    for part in expected_parts:
      self.assertIn(part, rendered)

  def test_parse_simple(self):
    """Sections are mapped onto fields and the code fence is removed."""

    class Solution(pg.Object):
      reasoning: str
      code: str

    markdown_text = (
        '\n'
        '## reasoning\n'
        'This is my reasoning\n'
        '\n'
        '## code\n'
        '```python\n'
        'def foo():\n'
        ' pass\n'
        '```\n'
    )
    parsed = schema_lib.ValueMarkdownRepr().parse(
        markdown_text, schema_lib.Schema(Solution)
    )

    self.assertEqual(parsed.reasoning, 'This is my reasoning')
    self.assertEqual(parsed.code, 'def foo():\n pass')

  def test_parse_with_code_blocks(self):
    """Each field receives the code of its own fenced block."""

    class SolutionWithTests(pg.Object):
      cpp_code: str
      terminal_code: str

    markdown_text = (
        '\n'
        '## cpp_code\n'
        '```cpp\n'
        '#include \n'
        'int main() { return 0; }\n'
        '```\n'
        '\n'
        '## terminal_code\n'
        '```bash\n'
        'g++ -o solution solution.cpp\n'
        './solution\n'
        '```\n'
    )
    parsed = schema_lib.ValueMarkdownRepr().parse(
        markdown_text, schema_lib.Schema(SolutionWithTests)
    )

    self.assertIn('#include ', parsed.cpp_code)
    self.assertIn('g++ -o solution', parsed.terminal_code)

  def test_parse_missing_required_field(self):
    """A required field without a section raises ValueError."""

    class Solution(pg.Object):
      reasoning: str
      code: str

    schema = schema_lib.Schema(Solution)
    parser = schema_lib.ValueMarkdownRepr()
    self.assertRaisesRegex(
        ValueError,
        'Required field "code" not found',
        parser.parse,
        self._REASONING_ONLY,
        schema,
    )

  def test_parse_optional_field(self):
    """An optional field without a section is parsed as None."""

    class Solution(pg.Object):
      reasoning: str
      code: str | None

    parsed = schema_lib.ValueMarkdownRepr().parse(
        self._REASONING_ONLY, schema_lib.Schema(Solution)
    )

    self.assertEqual(parsed.reasoning, 'This is my reasoning')
    self.assertIsNone(parsed.code)

  def test_extract_section(self):
    """The section helper returns the stripped body of each section."""
    markdown_text = '\n## section1\nContent 1\n\n## section2\nContent 2\n'
    parser = schema_lib.ValueMarkdownRepr()

    first = parser._extract_section(markdown_text, 'section1')
    second = parser._extract_section(markdown_text, 'section2')

    self.assertEqual(first, 'Content 1')
    self.assertEqual(second, 'Content 2')

  def test_extract_code_block(self):
    """The code-block helper ignores the text around the fence."""
    section_content = (
        '\n'
        'Some text\n'
        '```python\n'
        'def foo():\n'
        ' pass\n'
        '```\n'
        'More text\n'
    )

    extracted = schema_lib.ValueMarkdownRepr()._extract_code_block(
        section_content
    )

    self.assertEqual(extracted, 'def foo():\n pass')

  def test_autofix_with_missing_field(self):
    """With autofix=1 the fixing LM supplies the missing required field."""

    class Solution(pg.Object):
      reasoning: str
      code: str

    # Response of the fake LM: the original text plus the `code` section.
    corrected_markdown = (
        '\n'
        '## reasoning\n'
        'This is my reasoning\n'
        '\n'
        '## code\n'
        '```python\n'
        'def solution():\n'
        ' pass\n'
        '```\n'
    )
    fix_lm = fake.StaticResponse(corrected_markdown)
    schema = schema_lib.Schema(Solution)

    parsed = schema_lib.ValueMarkdownRepr().parse(
        self._REASONING_ONLY,
        schema,
        autofix=1,
        autofix_lm=fix_lm,
    )

    self.assertEqual(parsed.reasoning, 'This is my reasoning')
    self.assertEqual(parsed.code, 'def solution():\n pass')

  def test_autofix_not_triggered_when_disabled(self):
    """With autofix=0 the parse error propagates to the caller."""

    class Solution(pg.Object):
      reasoning: str
      code: str

    schema = schema_lib.Schema(Solution)
    parser = schema_lib.ValueMarkdownRepr()
    self.assertRaisesRegex(
        ValueError,
        'Required field "code" not found',
        parser.parse,
        self._REASONING_ONLY,
        schema,
        autofix=0,
    )