Skip to content

Commit 13e136e

Browse files
committed
Update json schema creation to dereference pydantic complex data types
1 parent 307c665 commit 13e136e

2 files changed

Lines changed: 332 additions & 0 deletions

File tree

py/src/braintrust/parameters.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Evaluation parameters support for Python SDK."""
22

3+
from copy import deepcopy
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING, Any, Literal, TypedDict
56

@@ -32,6 +33,7 @@ class ModelParameter(TypedDict):
3233
description: NotRequired[str | None]
3334

3435

36+
JSONValue = None | bool | int | float | str | list["JSONValue"] | dict[str, "JSONValue"]
3537
ValidatedParameters = dict[str, object]
3638
ParameterSchema = PromptParameter | ModelParameter | type[object] | None
3739
EvalParameters = dict[str, ParameterSchema]
@@ -97,8 +99,61 @@ def _get_pydantic_fields(schema: Any) -> dict[str, Any]:
9799
return getattr(schema, "__fields__", {})
98100

99101

102+
def _resolve_json_pointer(document: dict[str, JSONValue], pointer: str) -> JSONValue:
103+
if pointer == "#":
104+
return document
105+
if not pointer.startswith("#/"):
106+
raise ValueError(f"Unsupported JSON schema ref '{pointer}'")
107+
108+
current: JSONValue = document
109+
for raw_part in pointer[2:].split("/"):
110+
part = raw_part.replace("~1", "/").replace("~0", "~")
111+
if not isinstance(current, dict) or part not in current:
112+
raise ValueError(f"JSON schema ref '{pointer}' could not be resolved")
113+
current = current[part]
114+
return current
115+
116+
117+
def _resolve_local_json_schema_refs(
118+
node: JSONValue,
119+
root: dict[str, JSONValue],
120+
resolving: tuple[str, ...] = (),
121+
) -> JSONValue:
122+
if isinstance(node, list):
123+
return [_resolve_local_json_schema_refs(item, root, resolving) for item in node]
124+
125+
if not isinstance(node, dict):
126+
return node
127+
128+
ref = node.get("$ref")
129+
if isinstance(ref, str):
130+
if ref in resolving:
131+
raise ValueError(f"Cyclic JSON schema ref '{ref}'")
132+
133+
resolved = deepcopy(_resolve_json_pointer(root, ref))
134+
resolved = _resolve_local_json_schema_refs(resolved, root, resolving + (ref,))
135+
136+
siblings = {
137+
key: _resolve_local_json_schema_refs(value, root, resolving)
138+
for key, value in node.items()
139+
if key != "$ref"
140+
}
141+
if siblings:
142+
if not isinstance(resolved, dict):
143+
raise ValueError(f"Cannot merge sibling keys into non-object JSON schema ref '{ref}'")
144+
merged = dict(resolved)
145+
merged.update(siblings)
146+
return merged
147+
return resolved
148+
149+
return {key: _resolve_local_json_schema_refs(value, root, resolving) for key, value in node.items()}
150+
151+
100152
def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]:
101153
schema_json = _pydantic_to_json_schema(schema)
154+
schema_json = _resolve_local_json_schema_refs(schema_json, schema_json)
155+
schema_json.pop("$defs", None)
156+
schema_json.pop("definitions", None)
102157
fields = _get_pydantic_fields(schema)
103158
if len(fields) == 1 and "value" in fields:
104159
properties = schema_json.get("properties")

py/src/braintrust/test_parameters.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,24 @@
44
from braintrust.parameters import (
55
RemoteEvalParameters,
66
parameters_to_json_schema,
7+
serialize_eval_parameters,
78
validate_parameters,
89
)
910

1011

1112
HAS_PYDANTIC = importlib.util.find_spec("pydantic") is not None
1213

1314

15+
def _contains_json_schema_ref(node):
16+
if isinstance(node, dict):
17+
if "$ref" in node or "$defs" in node or "definitions" in node:
18+
return True
19+
return any(_contains_json_schema_ref(value) for value in node.values())
20+
if isinstance(node, list):
21+
return any(_contains_json_schema_ref(value) for value in node)
22+
return False
23+
24+
1425
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
1526
def test_validate_local_parameters_with_prompt_and_model_defaults():
1627
from pydantic import BaseModel
@@ -130,6 +141,272 @@ class PrefixParam(BaseModel):
130141
}
131142

132143

144+
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
145+
def test_parameters_to_json_schema_keeps_complex_single_value_models_self_contained():
146+
from pydantic import BaseModel, Field
147+
148+
class ComplexValue(BaseModel):
149+
a: int = Field(default=1)
150+
b: list[int] = Field(default=[2, 3])
151+
152+
class ComplexParameter(BaseModel):
153+
value: ComplexValue = Field(
154+
default=ComplexValue(),
155+
description="Complex example parameter",
156+
)
157+
158+
schema = parameters_to_json_schema({"complex": ComplexParameter})
159+
160+
assert schema["properties"]["complex"] == {
161+
"type": "object",
162+
"properties": {
163+
"a": {
164+
"default": 1,
165+
"title": "A",
166+
"type": "integer",
167+
},
168+
"b": {
169+
"default": [2, 3],
170+
"items": {
171+
"type": "integer",
172+
},
173+
"title": "B",
174+
"type": "array",
175+
},
176+
},
177+
"default": {
178+
"a": 1,
179+
"b": [2, 3],
180+
},
181+
"description": "Complex example parameter",
182+
"title": "ComplexValue",
183+
}
184+
185+
186+
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
187+
def test_serialize_eval_parameters_does_not_emit_dangling_ref_for_complex_single_value_models():
188+
from pydantic import BaseModel, Field
189+
190+
class ComplexValue(BaseModel):
191+
a: int = Field(default=1)
192+
b: list[int] = Field(default=[2, 3])
193+
194+
class ComplexParameter(BaseModel):
195+
value: ComplexValue = Field(
196+
default=ComplexValue(),
197+
description="Complex example parameter",
198+
)
199+
200+
serialized = serialize_eval_parameters({"complex": ComplexParameter})
201+
202+
assert serialized["complex"] == {
203+
"type": "data",
204+
"schema": {
205+
"type": "object",
206+
"properties": {
207+
"a": {
208+
"default": 1,
209+
"title": "A",
210+
"type": "integer",
211+
},
212+
"b": {
213+
"default": [2, 3],
214+
"items": {
215+
"type": "integer",
216+
},
217+
"title": "B",
218+
"type": "array",
219+
},
220+
},
221+
"default": {
222+
"a": 1,
223+
"b": [2, 3],
224+
},
225+
"description": "Complex example parameter",
226+
"title": "ComplexValue",
227+
},
228+
"default": {
229+
"a": 1,
230+
"b": [2, 3],
231+
},
232+
"description": "Complex example parameter",
233+
}
234+
235+
236+
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
237+
def test_parameters_to_json_schema_keeps_array_of_objects_single_value_models_self_contained():
238+
from pydantic import BaseModel, Field
239+
240+
class ComplexItem(BaseModel):
241+
a: int = Field(default=1)
242+
b: str = Field(default="x")
243+
244+
class ComplexParameter(BaseModel):
245+
value: list[ComplexItem] = Field(
246+
default=[ComplexItem()],
247+
description="Array example parameter",
248+
)
249+
250+
schema = parameters_to_json_schema({"complex_array": ComplexParameter})
251+
252+
assert schema["properties"]["complex_array"] == {
253+
"type": "array",
254+
"items": {
255+
"type": "object",
256+
"properties": {
257+
"a": {
258+
"default": 1,
259+
"title": "A",
260+
"type": "integer",
261+
},
262+
"b": {
263+
"default": "x",
264+
"title": "B",
265+
"type": "string",
266+
},
267+
},
268+
"title": "ComplexItem",
269+
},
270+
"default": [
271+
{
272+
"a": 1,
273+
"b": "x",
274+
}
275+
],
276+
"description": "Array example parameter",
277+
"title": "Value",
278+
}
279+
assert not _contains_json_schema_ref(schema["properties"]["complex_array"])
280+
281+
282+
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
283+
def test_parameters_to_json_schema_inlines_refs_for_multi_field_models():
284+
from pydantic import BaseModel, Field
285+
286+
class Address(BaseModel):
287+
street: str = Field(default="Main")
288+
zip_code: int = Field(default=12345)
289+
290+
class ComplexParameter(BaseModel):
291+
address: Address = Field(default=Address())
292+
enabled: bool = Field(default=True)
293+
294+
schema = parameters_to_json_schema({"complex": ComplexParameter})
295+
296+
assert not _contains_json_schema_ref(schema["properties"]["complex"])
297+
assert schema["properties"]["complex"]["properties"]["address"] == {
298+
"type": "object",
299+
"properties": {
300+
"street": {
301+
"default": "Main",
302+
"title": "Street",
303+
"type": "string",
304+
},
305+
"zip_code": {
306+
"default": 12345,
307+
"title": "Zip Code",
308+
"type": "integer",
309+
},
310+
},
311+
"default": {
312+
"street": "Main",
313+
"zip_code": 12345,
314+
},
315+
"title": "Address",
316+
}
317+
318+
319+
def test_parameters_to_json_schema_inlines_legacy_definitions_refs():
320+
class _FakeField:
321+
required = False
322+
323+
class LegacyParameter:
324+
__fields__ = {"value": _FakeField()}
325+
326+
@classmethod
327+
def parse_obj(cls, value):
328+
return value
329+
330+
@classmethod
331+
def schema(cls):
332+
return {
333+
"type": "object",
334+
"properties": {
335+
"value": {
336+
"$ref": "#/definitions/ComplexValue",
337+
"description": "Legacy example parameter",
338+
"default": {"a": 1},
339+
},
340+
},
341+
"definitions": {
342+
"ComplexValue": {
343+
"type": "object",
344+
"properties": {
345+
"a": {
346+
"type": "integer",
347+
"title": "A",
348+
"default": 1,
349+
},
350+
},
351+
"title": "ComplexValue",
352+
},
353+
},
354+
}
355+
356+
schema = parameters_to_json_schema({"legacy": LegacyParameter})
357+
358+
assert schema["properties"]["legacy"] == {
359+
"type": "object",
360+
"properties": {
361+
"a": {
362+
"type": "integer",
363+
"title": "A",
364+
"default": 1,
365+
},
366+
},
367+
"title": "ComplexValue",
368+
"description": "Legacy example parameter",
369+
"default": {"a": 1},
370+
}
371+
assert not _contains_json_schema_ref(schema["properties"]["legacy"])
372+
373+
374+
def test_parameters_to_json_schema_raises_for_cyclic_local_refs():
375+
class _FakeField:
376+
required = False
377+
378+
class CyclicParameter:
379+
__fields__ = {"value": _FakeField()}
380+
381+
@classmethod
382+
def parse_obj(cls, value):
383+
return value
384+
385+
@classmethod
386+
def schema(cls):
387+
return {
388+
"type": "object",
389+
"properties": {
390+
"value": {
391+
"$ref": "#/definitions/Node",
392+
},
393+
},
394+
"definitions": {
395+
"Node": {
396+
"type": "object",
397+
"properties": {
398+
"child": {
399+
"$ref": "#/definitions/Node",
400+
},
401+
},
402+
},
403+
},
404+
}
405+
406+
with pytest.raises(ValueError, match="Cyclic JSON schema ref"):
407+
parameters_to_json_schema({"cyclic": CyclicParameter})
408+
409+
133410
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
134411
def test_parameters_to_json_schema_marks_prompt_and_model_without_defaults_required():
135412
schema = parameters_to_json_schema(

0 commit comments

Comments
 (0)