Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion langfun/core/coding/python/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from typing import Any
import langfun.core as lf
from langfun.core.coding.python import execution
from langfun.core.structured import schema as schema_lib
import pyglove as pg


class CodeWithError(pg.Object):
"""Python code with error."""

schema_definition: str | None
code: str
error: str

Expand All @@ -42,6 +44,7 @@ def run_with_correction(
returns_code: bool = False,
returns_stdout: bool = False,
outputs_intermediate: bool = False,
schema: schema_lib.Schema | None = None,
) -> Any | tuple[Any, str]:
"""Correct code with a language model via self-play.

Expand All @@ -68,6 +71,7 @@ def run_with_correction(
outputs_intermediate: If True, intermediate output will be outputted as a
dict, with the last line's value accessible by key '__result__'. Otherwise
the value of the last line will be returned.
schema: Optional schema for the expected output.

Returns:
Run result if `returns_code` is set to False (default), otherwise a tuple
Expand All @@ -83,6 +87,13 @@ def run_with_correction(
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top

if schema is not None:
if isinstance(schema, type):
schema = schema_lib.Schema.from_value(schema)
schema_definition = schema.schema_str(protocol="python")
else:
schema_definition = None

if max_attempts == 0:
result = _maybe_custom_validate(
execution.run(
Expand Down Expand Up @@ -126,7 +137,14 @@ def result_and_error(code: str) -> tuple[Any, str | None]:
try:
# Disable autofix for code correction to avoid recursion.
correction = querying.query(
CodeWithError(code=code, error=error), CorrectedCode, lm=lm, autofix=0
CodeWithError(
schema_definition=schema_definition,
code=code,
error=error,
),
CorrectedCode,
lm=lm,
autofix=0,
)
except pg.coding.CodeError:
break
Expand All @@ -148,6 +166,7 @@ def result_and_error(code: str) -> tuple[Any, str | None]:
def correct(
code: str,
error: str | None = None,
schema: schema_lib.Schema | None = None,
*,
global_vars: dict[str, Any] | None = None,
lm: lf.LanguageModel = lf.contextual(),
Expand All @@ -162,6 +181,7 @@ def correct(
error: An optional initial error for `code` when it's problematic, usually
caught from elsewhere when it ran. If None, code will be executed once to
verify if its good and obtain a feedback error message.
schema: Optional schema for the expected output.
global_vars: A dict of str to value as the global variables that could be
accessed within the corrected code.
lm: Language model to be used. If not specified, it will try to use the `lm`
Expand All @@ -183,6 +203,7 @@ def correct(
return run_with_correction(
code,
error=error,
schema=schema,
global_vars=global_vars,
lm=lm,
max_attempts=max_attempts,
Expand Down
37 changes: 37 additions & 0 deletions langfun/core/coding/python/correction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,43 @@ def test_run_with_correction(self):
)
self.assertEqual(result, 4)

def test_run_with_correction_with_schema(self):
class Flight(pg.Object):
airline: str
flight_number: str

class Result(pg.Object):
flights: list[Flight]

result = correction.run_with_correction(
inspect.cleandoc("""
Result(
flights=[
Flight(airline='DELTA', flight_number='DL123'),
Flight(airline='UNITED', flight_number='UA456'),
]
)
"""),
schema=Result,
global_vars=dict(Result=Result, Flight=Flight),
lm=fake.StaticSequence([
inspect.cleandoc("""
CorrectedCode(
corrected_code='Result(flights=[Flight(airline="DELTA", flight_number="DL123"), Flight(airline="UNITED", flight_number="UA456")])',
)
"""),
]),
)
self.assertEqual(
result,
Result(
flights=[
Flight(airline='DELTA', flight_number='DL123'),
Flight(airline='UNITED', flight_number='UA456'),
]
),
)

def test_run_with_correction_upon_custom_validation(self):

class Foo(pg.Object):
Expand Down
6 changes: 5 additions & 1 deletion langfun/core/structured/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import typing
from typing import Any, Literal, Sequence, Type, Union
import langfun.core as lf
from langfun.core.coding.python import correction
import pyglove as pg


Expand Down Expand Up @@ -747,6 +746,7 @@ def parse(
global_vars.update({d.__name__: d for d in dependencies})
return structure_from_python(
text,
schema=schema,
global_vars=global_vars,
autofix=autofix,
autofix_lm=autofix_lm,
Expand All @@ -757,6 +757,7 @@ def parse(
def structure_from_python(
code: str,
*,
schema: Schema | None = None,
global_vars: dict[str, Any] | None = None,
permission: pg.coding.CodePermission = (
pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
Expand All @@ -765,6 +766,8 @@ def structure_from_python(
autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any:
"""Evaluates structure from Python code with access to symbols."""
from langfun.core.coding.python import correction # pylint: disable=g-import-not-at-top # pytype: disable=import-error

global_vars = global_vars or {}
global_vars.update({
'pg': pg,
Expand All @@ -787,6 +790,7 @@ def structure_from_python(
max_attempts=autofix,
lm=autofix_lm,
permission=permission,
schema=schema,
)


Expand Down
22 changes: 22 additions & 0 deletions langfun/core/structured/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,28 @@ class A(pg.Object):
A([Foo(1), Foo(2)], y='bar'),
)

def test_parse_with_correction_with_schema(self):
class Flight(pg.Object):
airline: str
flight_number: str

class Result(pg.Object):
flights: list[Flight]

self.assertEqual(
schema_lib.ValuePythonRepr().parse(
"Result(flights=[Flight(airline='DELTA', flight_number='DL123')])",
schema_lib.Schema(Result),
autofix=1,
autofix_lm=fake.StaticResponse(inspect.cleandoc("""
CorrectedCode(
corrected_code="Result(flights=[Flight(airline='DELTA', flight_number='DL123')])",
)
""")),
),
Result(flights=[Flight(airline='DELTA', flight_number='DL123')]),
)

def test_parse_class_def(self):
self.assertTrue(
inspect.isclass(
Expand Down
Loading