diff --git a/langfun/core/structured/querying.py b/langfun/core/structured/querying.py index c62e5600..233d4c75 100644 --- a/langfun/core/structured/querying.py +++ b/langfun/core/structured/querying.py @@ -85,6 +85,7 @@ class _MyLfQuery(LFQuery): _DEFAULT_PROTOCOL_VERSIONS: ClassVar[dict[str, str]] = { 'python': '2.0', 'json': '1.0', + 'markdown': '1.0', } def __init_subclass__(cls) -> Any: @@ -273,6 +274,53 @@ class Answer: ) +class _LfQueryMarkdownV1(LfQuery): + """Query a structured value using Markdown as the protocol.""" + + preamble = """ + Please respond to the last {{ input_title }} with {{ output_title }} only according to {{ schema_title }}. + + {{ input_title }}: + 1 + 1 = + + {{ schema_title }}: + Answer + + ## final_answer + ... + + {{ output_title }}: + ## final_answer + 2 + """ + version = '1.0' + protocol = 'markdown' + input_title = 'REQUEST' + schema_title = 'OUTPUT MARKDOWN SCHEMA' + output_title = 'OUTPUT MARKDOWN' + mapping_template = lf.Template(""" + {%- if example.context -%} + {{ context_title}}: + {{ example.context | indent(2, True)}} + + {% endif -%} + + {{ input_title }}: + {{ example.input_repr(protocol, compact=False) | indent(2, True) }} + + {% if example.schema -%} + {{ schema_title }}: + {{ example.schema_repr(protocol) | indent(2, True) }} + + {% endif -%} + + {{ output_title }}: + {%- if example.has_output %} + {{ example.output_repr(protocol, compact=False) | indent(2, True) }} + {% endif -%} + """) + + def query( prompt: Union[str, lf.Template, lf.Message, Any], schema: schema_lib.SchemaType | None = None, diff --git a/langfun/core/structured/schema/__init__.py b/langfun/core/structured/schema/__init__.py index d56e1b81..f4a23e27 100644 --- a/langfun/core/structured/schema/__init__.py +++ b/langfun/core/structured/schema/__init__.py @@ -47,3 +47,6 @@ from langfun.core.structured.schema.python import class_definition from langfun.core.structured.schema.python import class_definitions from langfun.core.structured.schema.python import include_method_in_prompt + +# Markdown protocol. +from langfun.core.structured.schema.markdown import MarkdownPromptingProtocol diff --git a/langfun/core/structured/schema/markdown.py b/langfun/core/structured/schema/markdown.py new file mode 100644 index 00000000..5b1c1e69 --- /dev/null +++ b/langfun/core/structured/schema/markdown.py @@ -0,0 +1,831 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Markdown-based prompting protocol.""" + +import re +from typing import Any +import langfun.core as lf +from langfun.core.structured.schema import base +import pyglove as pg + + +class MarkdownPromptingProtocol(base.PromptingProtocol): + """Markdown-based prompting protocol.""" + + NAME = 'markdown' + + def schema_repr(self, schema: base.Schema, **kwargs) -> str: + """Returns markdown representation of the schema.""" + del kwargs + lines = [] + + # Only process Object schemas + if not isinstance(schema.spec, pg.typing.Object): + raise ValueError( + f'Markdown protocol only supports Object schemas, got: {schema.spec}' + ) + + cls = schema.spec.cls + for key, field in cls.__schema__.items(): + if not isinstance(key, pg.typing.ConstStrKey): + continue + + field_name = str(key) + lines.append(f'## {field_name}') + + # Handle Union type + if isinstance(field.value, pg.typing.Union): + # Separate code classes from other candidates + code_classes = [] # [(cls, field_name, language)] + other_candidates = [] + + for candidate in field.value.candidates: + if isinstance(candidate, pg.typing.Object): + # Check if this is a code class (single *_code field) + cls_fields = list(candidate.cls.__schema__.items()) + if len(cls_fields) == 1: + cls_key, cls_field = cls_fields[0] + if isinstance(cls_key, pg.typing.ConstStrKey): + cls_field_name = str(cls_key) + if cls_field_name.endswith('_code') and isinstance( + cls_field.value, pg.typing.Str + ): + # This is a code class + language = self._detect_code_language(cls_field_name) + code_classes.append((candidate.cls, cls_field_name, language)) + continue + # Not a code class + other_candidates.append(candidate) + else: + other_candidates.append(candidate) + + if code_classes: + # Special format: code blocks OR python objects + lines.append('Choose ONE of:') + lines.append('') + + # Add code block options + for i, (cls, _, language) in enumerate(code_classes): + lines.append(f'```{language}') + lines.append('...') + lines.append('```') + if i < len(code_classes) - 1 or other_candidates: + lines.append('') + lines.append('OR') + lines.append('') + + # Add other object options + if other_candidates: + lines.append('```pyobject') + for i, candidate in enumerate(other_candidates): + if isinstance(candidate, pg.typing.Object): + lines.append(f'{candidate.cls.__name__}(...)') + elif candidate == pg.typing.MISSING_VALUE: + lines.append('None') + if i < len(other_candidates) - 1: + lines.append('# OR') + lines.append('```') + else: + # Normal Union format (no code classes) + union_annotation = base.annotation(field.value) + lines.append('```pyobject') + lines.append(union_annotation) + lines.append('```') + # Handle List type + elif isinstance(field.value, pg.typing.List): + element_spec = field.value.element.value + if isinstance(element_spec, pg.typing.Object): + # List of Objects - use pyobject code block for type + list_annotation = base.annotation(field.value) + lines.append('```pyobject') + lines.append(list_annotation) + lines.append('```') + lines.append('') + # Show nested structure + lines.append(f'### {element_spec.cls.__name__} 1') + lines.append('') + for obj_key, _ in element_spec.cls.__schema__.items(): + if isinstance(obj_key, pg.typing.ConstStrKey): + obj_field_name = str(obj_key) + lines.append(f'#### {obj_field_name}') + lines.append('...') + lines.append('') + else: + # List of primitives - use angle brackets + list_annotation = base.annotation(field.value) + lines.append(f'<{list_annotation}>') + lines.append('') + lines.append('- item 1') + lines.append('- item 2') + lines.append('- ...') + # Handle string type + elif isinstance(field.value, pg.typing.Str): + lines.append('') + lines.append('') + # Detect if this is a code field and suggest code block + if field_name.endswith('_code') or field_name == 'code': + language = self._detect_code_language(field_name) + lines.append(f'```{language}') + lines.append('...') + lines.append('```') + else: + lines.append('...') + # Handle other primitive types + elif isinstance(field.value, pg.typing.Int): + lines.append('') + lines.append('') + lines.append('...') + elif isinstance(field.value, pg.typing.Float): + lines.append('') + lines.append('') + lines.append('...') + elif isinstance(field.value, pg.typing.Bool): + lines.append('bool') + lines.append('') + lines.append('...') + else: + # Unknown type, just show placeholder + lines.append('...') + lines.append('') + + # Add Python class definitions for all dependent types + # This helps LLM understand the structure of Object types used in + # Union fields + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured.schema import python + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + py_protocol = python.PythonPromptingProtocol() + class_defs = py_protocol.class_definitions(schema, markdown=True) + + if class_defs: + lines.append('---') + lines.append('') + lines.append('**Type Definitions:**') + lines.append('') + lines.append(class_defs) + + return '\n'.join(lines) + + def _detect_code_language(self, field_name: str) -> str: + """Detects programming language from field name.""" + # Check for language-specific prefixes + if field_name.startswith('cpp_') or field_name.startswith('c++_'): + return 'cpp' + elif field_name.startswith('bash_') or field_name.startswith('shell_'): + return 'bash' + elif field_name.startswith('terminal_'): + return 'bash' + elif field_name.startswith('python_'): + return 'python' + elif field_name.startswith('java_'): + return 'java' + elif field_name.startswith('javascript_') or field_name.startswith('js_'): + return 'javascript' + # Default to python for generic *_code fields + return 'python' + + def value_repr( + self, value: Any, schema: base.Schema | None = None, **kwargs + ) -> str: + """Returns markdown representation of a value.""" + del schema, kwargs + if not isinstance(value, pg.Object): + return str(value) + + lines = [] + for key, val in value.sym_items(): + field_name = str(key) + lines.append(f'## {field_name}') + + # Handle List type + if isinstance(val, list): + for idx, item in enumerate(val, 1): + if isinstance(item, pg.Object): + # Nested Object - use ### for item header + item_type = item.__class__.__name__ + lines.append(f'### {item_type} {idx}') + lines.append('') + # Recursively render object fields with #### + for item_key, item_val in item.sym_items(): + item_field_name = str(item_key) + lines.append(f'#### {item_field_name}') + # Handle code in nested objects + if isinstance(item_val, str): + language = self._detect_code_language_from_content( + item_field_name, item_val + ) + if language: + lines.append(f'```{language}') + lines.append(item_val) + lines.append('```') + else: + lines.append(item_val) + else: + lines.append(str(item_val)) + lines.append('') + else: + # Simple type - use list item + lines.append(f'- {item}') + # Check if value looks like code + elif isinstance(val, str): + language = self._detect_code_language_from_content(field_name, val) + if language: + lines.append(f'```{language}') + lines.append(val) + lines.append('```') + else: + lines.append(val) + # Handle pg.Object values - use pyobject code block + elif isinstance(val, pg.Object): + lines.append('```pyobject') + # Use pg.format to get proper Python representation + lines.append( + pg.format(val, compact=True, verbose=False, python_format=True) + ) + lines.append('```') + else: + lines.append(str(val)) + lines.append('') + + return '\n'.join(lines) + + def _detect_code_language_from_content( + self, field_name: str, content: str + ) -> str | None: + """Detects if content is code and returns language.""" + # First check field name + if field_name.endswith('_code') or field_name == 'code': + # Check content for language hints + if content.strip().startswith('#include'): + return 'cpp' + elif content.strip().startswith('#!/bin/bash') or 'cat >' in content: + return 'bash' + elif 'def ' in content or 'class ' in content: + return 'python' + # Use field name detection as fallback + return self._detect_code_language(field_name) + return None + + def parse_value( + self, + text: str, + schema: base.Schema | None = None, + *, + autofix=0, + autofix_lm: lf.LanguageModel = lf.contextual(), + **kwargs, + ) -> Any: + """Parses markdown text into a structured object.""" + del kwargs + if schema is None: + raise ValueError('Schema is required for markdown parsing') + + # Without autofix: parse directly + if autofix == 0: + return self._parse_markdown(text, schema) + + # With autofix: use correction mechanism + error = None + for attempt in range(autofix + 1): + try: + return self._parse_markdown(text, schema) + except Exception as e: # pylint: disable=broad-exception-caught + error = e + if attempt < autofix: + # Try to fix the markdown using LLM + text = self._fix_markdown(text, schema, error, autofix_lm) + else: + raise + + # Should not reach here, but just in case + raise error # type: ignore + + def _parse_markdown(self, text: str, schema: base.Schema) -> Any: + """Internal method to parse markdown text.""" + if not isinstance(schema.spec, pg.typing.Object): + raise ValueError( + f'Markdown protocol only supports Object schemas, got: {schema.spec}' + ) + + cls = schema.spec.cls + result = {} + + # Get all class dependencies from schema (like Python protocol does) + dependencies = schema.class_dependencies( + include_base_classes=False, include_subclasses=False + ) + all_dependencies = {d.__name__: d for d in dependencies} + + # Extract sections for each field + for key, field in cls.__schema__.items(): + if not isinstance(key, pg.typing.ConstStrKey): + continue + + field_name = str(key) + section_content = self._extract_section(text, field_name) + + if section_content is None: + # Field not found - check if it's required + if not field.value.is_noneable: + raise ValueError( + f'Required field "{field_name}" not found in markdown' + ) + result[field_name] = None + continue + + # Parse based on field type + if isinstance(field.value, pg.typing.Union): + # Handle Union type - try each candidate in order + result[field_name] = self._parse_union_field( + section_content, field.value, field_name, all_dependencies + ) + elif isinstance(field.value, pg.typing.List): + # Parse List type + element_spec = field.value.element.value + + # Check if this is a pyobject block with Python list literal + code_info = self._extract_code_block(section_content) + if code_info and code_info[1] == 'pyobject': + # Parse as Python list literal + # Delay import at runtime to avoid circular dependency. + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured.schema import python + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + # Build global_vars with all classes + global_vars = all_dependencies.copy() + + try: + parsed_list = python.structure_from_python( + code_info[0], + global_vars=global_vars, + permission=pg.coding.CodePermission.CALL, + ) + # Verify it's a list + if isinstance(parsed_list, list): + result[field_name] = parsed_list + else: + raise TypeError( + f'Expected list, got {type(parsed_list).__name__}' + ) + except Exception as e: + raise ValueError( + f'Failed to parse list for field "{field_name}": {e}' + ) from e + elif isinstance(element_spec, pg.typing.Object): + # List of Objects - extract items with ### headers + items = self._extract_list_objects(section_content, element_spec.cls) + result[field_name] = items + else: + # List of primitives - extract list items + items = self._extract_list_items(section_content) + # Convert to appropriate type + if isinstance(element_spec, pg.typing.Int): + result[field_name] = [int(item) for item in items] + elif isinstance(element_spec, pg.typing.Float): + result[field_name] = [float(item) for item in items] + elif isinstance(element_spec, pg.typing.Bool): + result[field_name] = [ + item.lower() in ('true', 'yes', '1') for item in items + ] + else: + result[field_name] = items + elif isinstance(field.value, pg.typing.Str): + # Try to extract code block first + code_info = self._extract_code_block(section_content) + result[field_name] = ( + code_info[0] if code_info else section_content.strip() + ) + elif isinstance(field.value, (pg.typing.Int, pg.typing.Float)): + result[field_name] = field.value.value_type(section_content.strip()) + elif isinstance(field.value, pg.typing.Bool): + result[field_name] = section_content.strip().lower() in ( + 'true', + 'yes', + '1', + ) + elif isinstance(field.value, pg.typing.Object): + # Handle Object type - check if it's a pyobject block + code_info = self._extract_code_block(section_content) + if code_info and code_info[1] == 'pyobject': + # Delegate to Python protocol for parsing + # Delay import at runtime to avoid circular dependency. + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured.schema import python + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + code = code_info[0] + # Strip any backticks that LLMs might add + code = code.strip('`') + + # Build global_vars with all class dependencies + global_vars = all_dependencies.copy() + + try: + result[field_name] = python.structure_from_python( + code, + global_vars=global_vars, + permission=pg.coding.CodePermission.CALL, + ) + except Exception as e: + raise ValueError( + f'Failed to parse pyobject for field "{field_name}": {e}' + ) from e + else: + # Not a pyobject block - treat as error + raise ValueError( + f'Object field "{field_name}" must use pyobject code block' + ) + else: + result[field_name] = section_content.strip() + + # Create object instance + return cls(**result) + + def _fix_markdown( + self, + text: str, + schema: base.Schema, + error: Exception, + lm: lf.LanguageModel, + ) -> str: + """Fix malformed markdown using LLM.""" + # Delay import at runtime to avoid circular dependency. + # This follows the same pattern as python/correction.py + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured import querying + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + # Build schema description + schema_desc = self.schema_repr(schema) + + # Build correction prompt + correction_prompt = f"""The following markdown output has an error: + +```markdown +{text} +``` + +Error: {error} + +Expected schema: +{schema_desc} + + +Please provide the corrected markdown output that matches the expected schema.""" + + # Query LLM for correction (disable autofix to avoid recursion) + corrected = querying.query( + correction_prompt, + str, + lm=lm, + autofix=0, + ) + + return corrected + + def _extract_section(self, text: str, section_name: str) -> str | None: + """Extract content from a markdown section.""" + # Match: ## section_name\n + # (until next ## that's not ### or ####, or end) + # Use negative lookahead to ensure ## is not followed by another # + pattern = rf'##\s+{re.escape(section_name)}\s*\n(.*?)(?=\n##(?!#)|\Z)' + match = re.search(pattern, text, re.DOTALL | re.IGNORECASE) + if match: + return match.group(1).strip() + return None + + def _extract_code_block(self, section_content: str) -> tuple[str, str] | None: + """Extract code and language from markdown code block. + + Args: + section_content: The markdown section content to extract from. + + Returns: + Tuple of (code_content, language) if found, None otherwise. + Language defaults to 'python' if not specified. + """ + # Match: ```language\n\n``` + pattern = r'```([\w]*)\s*\n(.*?)\n```' + match = re.search(pattern, section_content, re.DOTALL) + if match: + language = match.group(1) or 'python' + code = match.group(2) + return (code, language) + return None + + def _extract_list_items(self, section_content: str) -> list[str]: + """Extract items from a markdown list.""" + if not section_content: + return [] + + items = [] + for line in section_content.split('\n'): + line = line.strip() + if line.startswith('- ') or line.startswith('* '): + items.append(line[2:].strip()) + + return items + + def _extract_list_objects( + self, section_content: str, obj_cls: type[pg.Object] + ) -> list[pg.Object]: + """Extract list of objects from markdown with ### headers.""" + if not section_content: + return [] + + # Split by ### headers to get individual items + # Pattern: ### ClassName N or ### Item N + pattern = r'###\s+(?:\w+\s+)?\d+\s*\n' + parts = re.split(pattern, section_content) + + # First part before any ### is usually empty or description + items = [] + for part in parts[1:]: # Skip first empty part + if not part.strip(): + continue + + # Parse object fields from this part + obj_data = {} + for key, field in obj_cls.__schema__.items(): + if not isinstance(key, pg.typing.ConstStrKey): + continue + + field_name = str(key) + # Extract #### field_name content + field_pattern = ( + rf'####\s+{re.escape(field_name)}\s*\n(.*?)(?=\n####|\Z)' + ) + field_match = re.search(field_pattern, part, re.DOTALL | re.IGNORECASE) + + if field_match: + field_content = field_match.group(1).strip() + + # Parse based on field type + if isinstance(field.value, pg.typing.Str): + code_info = self._extract_code_block(field_content) + obj_data[field_name] = code_info[0] if code_info else field_content + elif isinstance(field.value, pg.typing.Int): + obj_data[field_name] = int(field_content) + elif isinstance(field.value, pg.typing.Float): + obj_data[field_name] = float(field_content) + elif isinstance(field.value, pg.typing.Bool): + obj_data[field_name] = field_content.lower() in ( + 'true', + 'yes', + '1', + ) + else: + obj_data[field_name] = field_content + elif not field.value.is_noneable: + raise ValueError( + f'Required field "{field_name}" not found in list item' + ) + + items.append(obj_cls(**obj_data)) + + return items + + def _parse_union_field( + self, + section_content: str, + union_spec: pg.typing.Union, + field_name: str, + all_dependencies: dict[str, type[pg.Object]], + ) -> Any: + """Parse a Union type field by trying each candidate in order.""" + if not section_content: + # Empty content - check if None is allowed + if union_spec.is_noneable: + return None + raise ValueError(f'Field "{field_name}" is empty but not optional') + + # Check if this is a pyobject code block + code_info = self._extract_code_block(section_content) + if code_info and code_info[1] == 'pyobject': + # Delegate to Python protocol for parsing + # This handles Union matching automatically through PyGlove's type system + # Delay import at runtime to avoid circular dependency. + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured.schema import python + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + code = code_info[0] + # Strip any backticks that LLMs might add + code = code.strip('`') + + # Build global_vars with all class dependencies + global_vars = all_dependencies.copy() + + try: + result = python.structure_from_python( + code, + global_vars=global_vars, + permission=pg.coding.CodePermission.CALL, + ) + return result + except Exception as e: + raise ValueError( + f'Failed to parse pyobject for field "{field_name}": {e}' + ) from e + + # Use the dependencies passed from _parse_markdown (like Python protocol) + all_classes = all_dependencies.copy() + code_classes = {} + + for candidate in union_spec.candidates: + if isinstance(candidate, pg.typing.Object): + # Add the candidate class itself + all_classes[candidate.cls.__name__] = candidate.cls + + # Check if this is a code class + cls_fields = list(candidate.cls.__schema__.items()) + if len(cls_fields) == 1: + cls_key, cls_field = cls_fields[0] + if isinstance(cls_key, pg.typing.ConstStrKey): + cls_field_name = str(cls_key) + if cls_field_name.endswith('_code') and isinstance( + cls_field.value, pg.typing.Str + ): + code_classes[candidate.cls.__name__] = ( + candidate.cls, + cls_field_name, + ) + + # Sort candidates: Objects first, then primitives, then List last + # This prevents List from matching too eagerly + def candidate_priority(candidate): + if isinstance(candidate, pg.typing.Object): + return 0 # Highest priority + elif isinstance(candidate, pg.typing.List): + return 2 # Lowest priority + else: + return 1 # Medium priority + + sorted_candidates = sorted(union_spec.candidates, key=candidate_priority) + + # Try each candidate type in priority order + errors = [] + for candidate in sorted_candidates: + try: + # Try to parse as this candidate type + if isinstance(candidate, pg.typing.Object): + if candidate.cls.__name__ in code_classes: + # This is a code class - check language marker + code_info = self._extract_code_block(section_content) + if code_info: + code_content, actual_lang = code_info + # Check if this code class matches the language + current_expected_lang = self._detect_code_language( + code_classes[candidate.cls.__name__][1] + ) + + # If language matches this code class, use it + if actual_lang == current_expected_lang: + code_cls, code_field_name = code_classes[candidate.cls.__name__] + return code_cls(**{code_field_name: code_content}) + + # If language doesn't match, skip this candidate + raise ValueError( + f'Code block language "{actual_lang}" does not match expected' + f' language "{current_expected_lang}" for' + f' {candidate.cls.__name__}' + ) + else: + # This is a complex object - should have been handled by + # pyobject above + raise ValueError( + f'Complex object {candidate.cls.__name__} must use' + ' pyobject marker' + ) + elif isinstance(candidate, pg.typing.List): + # Check if this is a pyobject block with Python list literal + code_info = self._extract_code_block(section_content) + if code_info: + _, language = code_info + if language == 'pyobject': + # Should have been handled above + raise ValueError('pyobject blocks handled earlier') + + # Try markdown-style list parsing + # (has ### headers or - list items) + if '###' in section_content or section_content.strip().startswith( + '-' + ): + element_spec = candidate.element.value + if isinstance(element_spec, pg.typing.Object): + items = self._extract_list_objects( + section_content, element_spec.cls + ) + if items: # Only return if we actually found items + return items + else: + items = self._extract_list_items(section_content) + if items: # Only return if we actually found items + if isinstance(element_spec, pg.typing.Int): + return [int(item) for item in items] + elif isinstance(element_spec, pg.typing.Float): + return [float(item) for item in items] + else: + return items + # If doesn't look like a list, skip this candidate + raise ValueError('Content does not look like a list') + elif isinstance(candidate, pg.typing.Str): + code_info = self._extract_code_block(section_content) + if code_info: + # If this is a pyobject block, skip Str candidate + # Let Object candidates handle it + if code_info[1] == 'pyobject': + raise ValueError( + 'pyobject code blocks should be parsed as Objects, not Str' + ) + return code_info[0] + return section_content.strip() + elif isinstance(candidate, pg.typing.Int): + return int(section_content.strip()) + elif isinstance(candidate, pg.typing.Float): + return float(section_content.strip()) + elif isinstance(candidate, pg.typing.Bool): + return section_content.strip().lower() in ('true', 'yes', '1') + else: + # Unknown type, skip + continue + except Exception as e: # pylint: disable=broad-exception-caught + errors.append((candidate, e)) + continue + + # If we get here, all candidates failed + error_msg = ( + f'Failed to parse field "{field_name}" as any Union candidate:\\n' + ) + for candidate, error in errors: + error_msg += f' - {candidate}: {error}\\n' + raise ValueError(error_msg) + + def _parse_as_object( + self, + section_content: str, + obj_cls: type[pg.Object], + all_classes: dict[str, type[pg.Object]] | None = None, + ) -> pg.Object: + """Parse content as a PyGlove Object using Python eval.""" + # Delay import at runtime to avoid circular dependency. + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from langfun.core.structured.schema import python + # pytype: enable=import-error + # pylint: enable=g-import-not-at-top + + # Extract code if wrapped in triple backticks (```...```) + code_info = self._extract_code_block(section_content) + if code_info: + code = code_info[0] + else: + code = section_content.strip() + + # Strip any leading/trailing backticks that LLMs might add + # (e.g., `BrowseWeb(...)` instead of BrowseWeb(...)) + code = code.strip('`') + + # Build global_vars with all classes + global_vars = all_classes.copy() if all_classes else {} + # Ensure the target class is included + global_vars[obj_cls.__name__] = obj_cls + + # Use Python protocol to parse the object + try: + result = python.structure_from_python( + code, + global_vars=global_vars, + permission=pg.coding.CodePermission.CALL, + ) + # Verify it's the right type + if isinstance(result, obj_cls): + return result + raise TypeError( + f'Expected {obj_cls.__name__}, got {type(result).__name__}' + ) + except Exception as e: + raise ValueError(f'Failed to parse as {obj_cls.__name__}: {e}') from e diff --git a/langfun/core/structured/schema/markdown_test.py b/langfun/core/structured/schema/markdown_test.py new file mode 100644 index 00000000..de107716 --- /dev/null +++ b/langfun/core/structured/schema/markdown_test.py @@ -0,0 +1,1219 @@ +# Copyright 2025 The Langfun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for markdown prompting protocol.""" + +import unittest +from langfun.core.llms import fake +from langfun.core.structured.schema import base +from langfun.core.structured.schema import markdown +import pyglove as pg + + +class MarkdownPromptingProtocolSchemaReprTest(unittest.TestCase): + """Tests for schema representation in markdown.""" + + def test_repr_simple(self): + """Test markdown schema with simple fields.""" + + class Solution(pg.Object): + reasoning: str + answer: int + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + expected = """## reasoning + + +... + +## answer + + +... + +--- + +**Type Definitions:** + +```python +class Solution: + reasoning: str + answer: int +```""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_with_code_field(self): + """Test automatic code block suggestion for 'code' field.""" + + class Solution(pg.Object): + reasoning: str + code: str + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + expected = """## reasoning + + +... + +## code + + +```python +... +``` + +--- + +**Type Definitions:** + +```python +class Solution: + reasoning: str + code: str +```""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_with_code_fields(self): + """Test automatic code block detection for *_code fields.""" + + class SolutionWithTests(pg.Object): + cpp_code: str + terminal_code: str + bash_code: str + python_code: str # Added to test the default python branch + + schema = base.Schema(SolutionWithTests) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + expected = """## cpp_code + + +```cpp +... +``` + +## terminal_code + + +```bash +... +``` + +## bash_code + + +```bash +... +``` + +## python_code + + +```python +... +``` + +--- + +**Type Definitions:** + +```python +class SolutionWithTests: + cpp_code: str + terminal_code: str + bash_code: str + python_code: str +```""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_nested_object(self): + """Test markdown schema with nested objects.""" + + class Inner(pg.Object): + value: int + + class Outer(pg.Object): + inner: Inner + name: str + + schema = base.Schema(Outer) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + expected = """## inner +... + +## name + + +... + +--- + +**Type Definitions:** + +```python +class Inner: + value: int + +class Outer: + inner: Inner + name: str +```""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_union_with_pyobject(self): + """Test that Union types use 'pyobject' language marker in schema.""" + + class FileRead(pg.Object): + file_path: str + mode: str + + class FileWrite(pg.Object): + file_path: str + content: str + + class FinalizeAnswer(pg.Object): + pass + + class NextStep(pg.Object): + next_step: FileRead | FileWrite | FinalizeAnswer | None + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + expected_parts = [ + '## next_step', + '```pyobject', + 'Union[FileRead, FileWrite, FinalizeAnswer, None]', + '```', + '**Type Definitions:**', + 'class FileRead:', + 'class FileWrite:', + 'class FinalizeAnswer:', + 'class NextStep:', + ] + + # Verify all expected parts are present + for part in expected_parts: + self.assertIn(part, markdown_repr) + + # Verify python Union is NOT used + self.assertNotIn('```python\nUnion[', markdown_repr) + + def test_repr_union_with_code_classes_and_objects(self): + """Test Union with both code classes and other objects uses correct markers.""" + + class BashCode(pg.Object): + bash_code: str + + class PythonCode(pg.Object): + python_code: str + + class FileRead(pg.Object): + file_path: str + + class NextStep(pg.Object): + next_step: BashCode | PythonCode | FileRead | None + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + # Should have bash and python code block examples + self.assertIn('```bash', markdown_repr) + self.assertIn('```python', markdown_repr) + # Should use pyobject for non-code objects + self.assertIn('```pyobject', markdown_repr) + self.assertIn('FileRead(...)', markdown_repr) + + def test_repr_list_of_objects_with_pyobject(self): + """Test that List[Object] uses 'pyobject' language marker in schema.""" + + class Item(pg.Object): + name: str + value: int + + class Container(pg.Object): + items: list[Item] + + schema = base.Schema(Container) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + # Should use pyobject for List[Object] type annotation + self.assertIn('```pyobject', markdown_repr) + self.assertIn('list[Item]', markdown_repr) + # Should NOT use python for type annotation + self.assertNotIn('```python\nlist[', markdown_repr) + + def test_print_schema_example(self): + """Print example schema output to demonstrate the pyobject format.""" + + class BashCode(pg.Object): + """Execute bash commands.""" + + bash_code: str + + class FileRead(pg.Object): + """Read a file from the filesystem.""" + + file_path: str + mode: str + + class FinalizeAnswer(pg.Object): + """Finalize the answer to the question.""" + + pass + + class NextStep(pg.Object): + """Next step in the research process.""" + + think_step_by_step: str + next_step: BashCode | FileRead | FinalizeAnswer | None + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + # Print the actual output for demonstration + print('\n' + '=' * 80) + print('EXAMPLE SCHEMA OUTPUT WITH PYOBJECT FORMAT:') + print('=' * 80) + print(markdown_repr) + print('=' * 80 + '\n') + + # Verify it has the expected markers + self.assertIn('```bash', markdown_repr) + self.assertIn('```pyobject', markdown_repr) + self.assertIn('FileRead(...)', markdown_repr) + self.assertIn('FinalizeAnswer(...)', markdown_repr) + + +class MarkdownPromptingProtocolValueReprTest(unittest.TestCase): + """Tests for value representation in markdown.""" + + def test_repr(self): + """Test markdown value representation.""" + + class Solution(pg.Object): + reasoning: str + cpp_code: str + + solution = Solution( + reasoning='Use dynamic programming', + cpp_code='#include \nint main() { return 0; }', + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(solution) + + expected = """## reasoning +Use dynamic programming + +## cpp_code +```cpp +#include +int main() { return 0; } +``` +""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_bash_detection(self): + """Test bash code detection in markdown repr.""" + + class ScriptSolution(pg.Object): + bash_script_code: str + + solution = ScriptSolution( + bash_script_code='#!/bin/bash\necho "Hello"', + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(solution) + + expected = """## bash_script_code +```bash +#!/bin/bash +echo "Hello" +``` +""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_bash_detection_cat(self): + """Test bash code detection with cat > pattern.""" + + class TestCode(pg.Object): + test_code: str + + solution = TestCode( + test_code='cat > test.txt\nSome content', + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(solution) + + expected = """## test_code +```bash +cat > test.txt +Some content +``` +""" + + self.assertEqual(markdown_repr, expected) + + def test_repr_python_detection(self): + """Test python code detection (default fallback).""" + + class PythonSolution(pg.Object): + python_code: str + + solution = PythonSolution( + python_code='def foo():\n return 42', + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(solution) + + expected = """## python_code +```python +def foo(): + return 42 +``` +""" + + self.assertEqual(markdown_repr, expected) + + +class MarkdownPromptingProtocolParseValueTest(unittest.TestCase): + """Tests for parsing markdown into structured values.""" + + def test_parse_simple(self): + """Test parsing markdown text into structured object.""" + + class Solution(pg.Object): + reasoning: str + code: str + + markdown_text = """ +## reasoning +This is my reasoning + +## code +```python +def foo(): + pass +``` +""" + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertEqual(result.reasoning, 'This is my reasoning') + self.assertEqual(result.code, 'def foo():\n pass') + + def test_parse_with_code_blocks(self): + """Test parsing code blocks from markdown.""" + + class SolutionWithTests(pg.Object): + cpp_code: str + terminal_code: str + + markdown_text = """ +## cpp_code +```cpp +#include +int main() { return 0; } +``` + +## terminal_code +```bash +g++ -o solution solution.cpp +./solution +``` +""" + + schema = base.Schema(SolutionWithTests) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertIn('#include ', result.cpp_code) + self.assertIn('g++ -o solution', result.terminal_code) + + def test_parse_missing_required_field(self): + """Test that missing required fields raise ValueError.""" + + class Solution(pg.Object): + reasoning: str + code: str + + markdown_text = """ +## reasoning +This is my reasoning +""" + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + with self.assertRaisesRegex(ValueError, 'Required field "code" not found'): + protocol.parse_value(markdown_text, schema) + + def test_parse_optional_field(self): + """Test parsing with optional fields.""" + + class Solution(pg.Object): + reasoning: str + code: str | None + + markdown_text = """ +## reasoning +This is my reasoning +""" + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertEqual(result.reasoning, 'This is my reasoning') + self.assertIsNone(result.code) + + def test_extract_section(self): + """Test section extraction helper method.""" + markdown_text = """ +## section1 +Content 1 + +## section2 +Content 2 +""" + + protocol = markdown.MarkdownPromptingProtocol() + section1 = protocol._extract_section(markdown_text, 'section1') + section2 = protocol._extract_section(markdown_text, 'section2') + + self.assertEqual(section1, 'Content 1') + self.assertEqual(section2, 'Content 2') + + def test_extract_code_block(self): + """Test code block extraction helper method.""" + section_content = """ +Some text +```python +def foo(): + pass +``` +More text +""" + + protocol = markdown.MarkdownPromptingProtocol() + code_info = protocol._extract_code_block(section_content) + + self.assertIsNotNone(code_info) + code, language = code_info + self.assertEqual(code, 'def foo():\n pass') + self.assertEqual(language, 'python') + + def test_autofix_with_missing_field(self): + """Test that autofix is triggered when a required field is missing.""" + + class Solution(pg.Object): + reasoning: str + code: str + + # Markdown missing the 'code' field + markdown_text = """ +## reasoning +This is my reasoning +""" + + # Create a fake LLM that will provide the missing field + corrected_markdown = """ +## reasoning +This is my reasoning + +## code +```python +def solution(): + pass +``` +""" + + fix_lm = fake.StaticResponse(corrected_markdown) + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + + # With autofix=1, should call fix_lm and succeed + result = protocol.parse_value( + markdown_text, + schema, + autofix=1, + autofix_lm=fix_lm, + ) + + self.assertEqual(result.reasoning, 'This is my reasoning') + self.assertEqual(result.code, 'def solution():\n pass') + + def test_autofix_not_triggered_when_disabled(self): + """Test that autofix is not triggered when autofix=0.""" + + class Solution(pg.Object): + reasoning: str + code: str + + # Markdown missing the 'code' field + markdown_text = """ +## reasoning +This is my reasoning +""" + + schema = base.Schema(Solution) + protocol = markdown.MarkdownPromptingProtocol() + + # With autofix=0, should raise ValueError + with self.assertRaisesRegex(ValueError, 'Required field "code" not found'): + protocol.parse_value( + markdown_text, + schema, + autofix=0, + ) + + +class MarkdownPromptingProtocolListTest(unittest.TestCase): + """Tests for List type support in markdown protocol.""" + + def test_simple_list_schema_repr(self): + """Test schema representation for simple list.""" + + class SimpleList(pg.Object): + title: str + items: list[str] + + schema = base.Schema(SimpleList) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + self.assertIn('## title', markdown_repr) + self.assertIn('## items', markdown_repr) + self.assertIn('- item 1', markdown_repr) + + def test_object_list_schema_repr(self): + """Test schema representation for list of objects.""" + + class TestCase(pg.Object): + description: str + code: str + + class TestSuite(pg.Object): + name: str + test_cases: list[TestCase] + + schema = base.Schema(TestSuite) + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.schema_repr(schema) + + self.assertIn('## test_cases', markdown_repr) + self.assertIn('### TestCase 1', markdown_repr) + self.assertIn('#### description', markdown_repr) + self.assertIn('#### code', markdown_repr) + + def test_simple_list_value_repr(self): + """Test value representation for simple list.""" + + class SimpleList(pg.Object): + title: str + items: list[str] + + value = SimpleList(title='Shopping List', items=['Milk', 'Eggs', 'Bread']) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(value) + + self.assertIn('## title', markdown_repr) + self.assertIn('Shopping List', markdown_repr) + self.assertIn('## items', markdown_repr) + self.assertIn('- Milk', markdown_repr) + self.assertIn('- Eggs', markdown_repr) + self.assertIn('- Bread', markdown_repr) + + def test_object_list_value_repr(self): + """Test value representation for list of objects.""" + + class TestCase(pg.Object): + description: str + input_code: str + + class TestSuite(pg.Object): + name: str + test_cases: list[TestCase] + + value = TestSuite( + name='Math Tests', + test_cases=[ + TestCase( + description='Test addition', + input_code='assert add(1, 2) == 3', + ), + TestCase( + description='Test subtraction', + input_code='assert subtract(5, 3) == 2', + ), + ], + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(value) + + self.assertIn('## name', markdown_repr) + self.assertIn('Math Tests', markdown_repr) + self.assertIn('## test_cases', markdown_repr) + self.assertIn('### TestCase 1', markdown_repr) + self.assertIn('#### description', markdown_repr) + self.assertIn('Test addition', markdown_repr) + self.assertIn('#### input_code', markdown_repr) + self.assertIn('assert add(1, 2) == 3', markdown_repr) + self.assertIn('### TestCase 2', markdown_repr) + self.assertIn('Test subtraction', markdown_repr) + + def test_parse_simple_list(self): + """Test parsing simple list from markdown.""" + + class SimpleList(pg.Object): + title: str + items: list[str] + + markdown_text = """ +## title +Shopping List + +## items +- Milk +- Eggs +- Bread +""" + + schema = base.Schema(SimpleList) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertEqual(result.title, 'Shopping List') + self.assertEqual(result.items, ['Milk', 'Eggs', 'Bread']) + + def test_parse_object_list(self): + """Test parsing list of objects from markdown.""" + + class TestCase(pg.Object): + description: str + input_code: str + expected_output: str + + class TestSuite(pg.Object): + name: str + test_cases: list[TestCase] + + markdown_text = """ +## name +Math Tests + +## test_cases + +### TestCase 1 + +#### description +Test addition + +#### input_code +```python +assert add(1, 2) == 3 +``` + +#### expected_output +Pass + +### TestCase 2 + +#### description +Test subtraction + +#### input_code +```python +assert subtract(5, 3) == 2 +``` + +#### expected_output +Pass +""" + + schema = base.Schema(TestSuite) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertEqual(result.name, 'Math Tests') + self.assertEqual(len(result.test_cases), 2) + self.assertEqual(result.test_cases[0].description, 'Test addition') + self.assertEqual(result.test_cases[0].input_code, 'assert add(1, 2) == 3') + self.assertEqual(result.test_cases[0].expected_output, 'Pass') + self.assertEqual(result.test_cases[1].description, 'Test subtraction') + self.assertEqual( + result.test_cases[1].input_code, 'assert subtract(5, 3) == 2' + ) + self.assertEqual(result.test_cases[1].expected_output, 'Pass') + + def test_parse_int_list(self): + """Test parsing list of integers.""" + + class NumberList(pg.Object): + numbers: list[int] + + markdown_text = """ +## numbers +- 1 +- 2 +- 3 +""" + + schema = base.Schema(NumberList) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertEqual(result.numbers, [1, 2, 3]) + + def test_extract_list_items(self): + """Test list item extraction helper method.""" + section_content = """ +- Item 1 +- Item 2 +- Item 3 +""" + + protocol = markdown.MarkdownPromptingProtocol() + items = protocol._extract_list_items(section_content) + + self.assertEqual(items, ['Item 1', 'Item 2', 'Item 3']) + + +class MarkdownPromptingProtocolUnionTest(unittest.TestCase): + """Tests for Union type support in markdown protocol.""" + + def test_parse_union_object(self): + """Test parsing Union of objects.""" + + class Action1(pg.Object): + name: str + value: int + + class Action2(pg.Object): + title: str + + class TestUnion(pg.Object): + action: Action1 | Action2 + + # Test first candidate + markdown_text = """## action +```pyobject +Action1(name='test', value=42) +``` +""" + + schema = base.Schema(TestUnion) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertIsInstance(result.action, Action1) + self.assertEqual(result.action.name, 'test') + self.assertEqual(result.action.value, 42) + + def test_parse_union_with_nested_objects(self): + """Test parsing Union with nested objects (e.g., BrowseWeb with Question).""" + + class Question(pg.Object): + question: str + context: dict[str, str] + + class BrowseWeb(pg.Object): + question: Question + + class FileRead(pg.Object): + file_path: str + + class NextStep(pg.Object): + next_step: BrowseWeb | FileRead | None + + # Test BrowseWeb with nested Question object + markdown_text = """## next_step +```pyobject +BrowseWeb(question=Question(question='What is the answer?', context={})) +``` +""" + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertIsInstance(result.next_step, BrowseWeb) + self.assertIsInstance(result.next_step.question, Question) + self.assertEqual(result.next_step.question.question, 'What is the answer?') + self.assertEqual(result.next_step.question.context, {}) + + def test_parse_union_list_vs_single(self): + """Test parsing Union of list vs single object.""" + + class Item(pg.Object): + name: str + + class TestUnion(pg.Object): + items: list[Item] | Item | None + + # Test single object + markdown_text1 = """## items +```pyobject +Item(name='single') +``` +""" + + schema = base.Schema(TestUnion) + protocol = markdown.MarkdownPromptingProtocol() + result1 = protocol.parse_value(markdown_text1, schema) + + self.assertIsInstance(result1.items, Item) + self.assertEqual(result1.items.name, 'single') + + def test_parse_union_code_class(self): + """Test parsing Union with code class (BashCode).""" + + class BashCode(pg.Object): + bash_code: str + + class FileRead(pg.Object): + file_path: str + + class TerminalNextStep(pg.Object): + think_step_by_step: str + next_step: BashCode | FileRead | None + + # Test BashCode with comment (like LLM generates from schema example) + markdown_text = """## think_step_by_step +Calculate 1002 * 0.04 and round up. + +## next_step +```bash +# BashCode +python -c "import math; print(math.ceil(1002 * 0.04))" +``` +""" + + schema = base.Schema(TerminalNextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + self.assertIsInstance(result.next_step, BashCode) + # The bash_code should include the comment (it's valid bash) + self.assertIn('# BashCode', result.next_step.bash_code) + self.assertIn('python -c', result.next_step.bash_code) + + def test_parse_union_python_vs_bash(self): + """Test that Python code blocks are not misidentified as BashCode.""" + + class BashCode(pg.Object): + bash_code: str + + class PythonCode(pg.Object): + python_code: str + + class NextStep(pg.Object): + reasoning: str + action: BashCode | PythonCode | None + + # Test Python code block - should be parsed as PythonCode + markdown_text = """## reasoning +Use Python to finalize the answer. + +## action +```python +FinalizeAnswer() +``` +""" + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + # Should be PythonCode, not BashCode + self.assertIsInstance(result.action, PythonCode) + self.assertEqual(result.action.python_code, 'FinalizeAnswer()') + + # Test Bash code block - should be parsed as BashCode + markdown_text2 = """## reasoning +Use bash to execute a command. + +## action +```bash +echo "Hello World" +``` +""" + + result2 = protocol.parse_value(markdown_text2, schema) + + # Should be BashCode + self.assertIsInstance(result2.action, BashCode) + self.assertEqual(result2.action.bash_code, 'echo "Hello World"') + + def test_finalize_answer_not_parsed_as_bash(self): + """Test the original user scenario: FinalizeAnswer() should not be BashCode. + + This reproduces the exact issue reported by the user where a Python code + block containing FinalizeAnswer() was being incorrectly parsed as BashCode. + """ + + class BashCode(pg.Object): + bash_code: str + + class FinalizeAnswer(pg.Object): + """Dummy FinalizeAnswer for testing.""" + + pass + + class TerminalNextStep(pg.Object): + think_step_by_step: str + next_step: BashCode | FinalizeAnswer | None + + # This is the user's exact scenario + markdown_text = """## think_step_by_step +The problem has been successfully solved in the previous steps. + +1. **Initial Attempt (Step 1):** The first script failed with a `KeyError`. +2. **Investigation (Step 2):** A second script was used to inspect the structure. +3. **Successful Execution (Step 3):** A third script was created and succeeded. +4. **Result:** The script executed successfully and printed the final answer. + +The goal has been achieved. The correct action is to finalize the answer. + +## next_step +```pyobject +FinalizeAnswer() +``` +""" + + schema = base.Schema(TerminalNextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + # Critical assertion: Python code should be parsed as FinalizeAnswer + self.assertIsInstance( + result.next_step, + FinalizeAnswer, + f'Expected FinalizeAnswer but got {type(result.next_step).__name__}. ' + 'Python code blocks should not be misidentified as BashCode!', + ) + + def test_parse_list_of_union_with_nested_objects(self): + """Test parsing list of Union types with nested objects in pyobject format. + + This tests the scenario where an LLM generates a Python list literal + containing multiple objects with nested dependencies (e.g., Terminal and + BrowseWeb, each containing a Question object). + """ + + class Question(pg.Object): + question: str + context: dict[str, str] | None + + class Terminal(pg.Object): + question: Question + + class BrowseWeb(pg.Object): + question: Question + + class FileRead(pg.Object): + file_path: str + + class NextStep(pg.Object): + next_step: list[Terminal | BrowseWeb | FileRead] | None + + # This is the exact format an LLM might generate + markdown_text = """## next_step +```pyobject +[ + Terminal( + question=Question( + question="Please check the transcript of the audio file to identify the recommended reading page numbers for the Calculus mid-term. You can use available command-line tools or install Python libraries like SpeechRecognition to process the file. If the file is MP3, you might need to convert it to WAV first.", + context={ + 'file_path': './question/attachments/1f975693-876d-457b-a649-393859e79bf3.mp3' + } + ) + ), + BrowseWeb( + question=Question( + question="What are the recommended reading page numbers for Professor Willowbrook's Calculus mid-term?", + context=None + ) + ) +] +``` +""" + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + # Verify it's a list + self.assertIsInstance(result.next_step, list) + self.assertEqual(len(result.next_step), 2) + + # Verify first item is Terminal with nested Question + self.assertIsInstance(result.next_step[0], Terminal) + self.assertIsInstance(result.next_step[0].question, Question) + self.assertIn('Calculus mid-term', result.next_step[0].question.question) + self.assertIsInstance(result.next_step[0].question.context, dict) + self.assertIn('file_path', result.next_step[0].question.context) + + # Verify second item is BrowseWeb with nested Question + self.assertIsInstance(result.next_step[1], BrowseWeb) + self.assertIsInstance(result.next_step[1].question, Question) + self.assertIn( + 'Professor Willowbrook', result.next_step[1].question.question + ) + self.assertIsNone(result.next_step[1].question.context) + + def test_value_repr_with_object_field(self): + """Test that value_repr generates pyobject code blocks for Object fields. + + This ensures that few-shot examples are formatted correctly with pyobject + blocks instead of inline code or plain str(). Uses the actual example from + browse.NEXT_STEP_EXAMPLES. + """ + + class NavigateTo(pg.Object): + url: str + + class NextStep(pg.Object): + think_step_by_step: str + next_step: NavigateTo + + # This is the actual example from browse.NEXT_STEP_EXAMPLES + value = NextStep( + think_step_by_step=( + 'I should use the NavigateTo action to navigate to the Google' + ' homepage.' + ), + next_step=NavigateTo('https://www.google.com/'), + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(value) + + # Should use pyobject code block for the NavigateTo object + self.assertIn('```pyobject', markdown_repr) + self.assertIn('NavigateTo', markdown_repr) + self.assertIn('https://www.google.com/', markdown_repr) + # Should NOT use inline code (single backticks) around the object + # The entire object should be in a code block, not inline + lines = markdown_repr.split('\n') + # Find the next_step section + next_step_idx = None + for i, line in enumerate(lines): + if line.strip() == '## next_step': + next_step_idx = i + break + self.assertIsNotNone(next_step_idx) + # The line after ## next_step should be ```pyobject + self.assertEqual(lines[next_step_idx + 1], '```pyobject') + + def test_value_repr_no_backticks(self): + """Test that value_repr does NOT generate backticks around objects. + + This verifies the fix for the issue where LLMs were seeing backticks + in few-shot examples and copying them, causing parsing errors. + """ + + class NavigateTo(pg.Object): + url: str + + class NextStep(pg.Object): + think_step_by_step: str + next_step: NavigateTo + + value = NextStep( + think_step_by_step=( + 'I should use the NavigateTo action to navigate to the Google' + ' homepage.' + ), + next_step=NavigateTo(url='https://www.google.com/'), + ) + + protocol = markdown.MarkdownPromptingProtocol() + markdown_repr = protocol.value_repr(value) + + # Expected output should NOT have backticks around NavigateTo(...) + expected = """## think_step_by_step +I should use the NavigateTo action to navigate to the Google homepage. + +## next_step +```pyobject +NavigateTo(url='https://www.google.com/') +``` +""" + + self.assertEqual(markdown_repr, expected) + + def test_parse_with_backticks_defensive(self): + """Test that parser can handle backticks defensively if LLM adds them.""" + + class Question(pg.Object): + question: str + context: dict[str, str] | None + + class BrowseWeb(pg.Object): + question: Question + + class NextStep(pg.Object): + next_step: BrowseWeb | None + + # This is what caused the original error - backticks around the object + markdown_text = """## next_step +```pyobject +`BrowseWeb(question=Question(question='Find the photograph with accession number 2022.128 in the Whitney Museum of American Art collection.', context=None))` +``` +""" + + schema = base.Schema(NextStep) + protocol = markdown.MarkdownPromptingProtocol() + result = protocol.parse_value(markdown_text, schema) + + # Should successfully parse despite the backticks + self.assertIsInstance(result.next_step, BrowseWeb) + self.assertIsInstance(result.next_step.question, Question) + self.assertEqual( + result.next_step.question.question, + 'Find the photograph with accession number 2022.128 in the Whitney' + ' Museum of American Art collection.', + ) + self.assertIsNone(result.next_step.question.context) + + +if __name__ == '__main__': + unittest.main()