Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions mojo_bindgen/layout_tests/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BitfieldGroupMember,
MojoModule,
OpaqueStorageMember,
PaddingMember,
StoredMember,
Struct,
StructDecl,
Expand Down Expand Up @@ -92,27 +93,29 @@ def _layout_record_check(
if _is_opaque_storage_decl(mojo_decl):
return LayoutRecordCheck(record_name=mojo_name, checks=tuple(checks))

member_names = _member_names(mojo_decl)
member_indices = _member_indices(mojo_decl)
for field_fact in facts.plain_fields:
field = decl.fields[field_fact.index]
member_name = field_mojo_name(field, field_fact.index)
if member_name not in member_names:
mojo_index = member_indices.get(member_name)
if mojo_index is None:
continue
checks.append(
LayoutCheck(
label=f"{mojo_name}.{member_name}.offset",
expression=f"r.field_offset[index={field_fact.index}]()",
expression=f"r.field_offset[index={mojo_index}]()",
expected=field_fact.byte_offset,
)
)

for run in facts.bitfield_runs:
if run.name not in member_names:
mojo_index = member_indices.get(run.name)
if mojo_index is None:
continue
checks.append(
LayoutCheck(
label=f"{mojo_name}.{run.name}.offset",
expression=f"r.field_offset[index={run.first_index}]()",
expression=f"r.field_offset[index={mojo_index}]()",
expected=run.byte_offset,
)
)
Expand All @@ -124,18 +127,18 @@ def _is_opaque_storage_decl(decl: StructDecl) -> bool:
return len(decl.members) == 1 and isinstance(decl.members[0], OpaqueStorageMember)


def _member_names(decl: StructDecl) -> set[str]:
names: set[str] = set()
for member in decl.members:
def _member_indices(decl: StructDecl) -> dict[str, int]:
indices: dict[str, int] = {}
for index, member in enumerate(decl.members):
if isinstance(member, StoredMember):
names.add(member.name)
indices[member.name] = index
elif isinstance(member, BitfieldGroupMember):
names.add(member.storage_name)
indices[member.storage_name] = index
elif isinstance(member, OpaqueStorageMember):
names.add(member.name)
else:
names.add(member.name)
return names
indices[member.name] = index
elif isinstance(member, PaddingMember):
indices[member.name] = index
return indices


def render_layout_test_module(
Expand All @@ -154,6 +157,7 @@ def render_layout_test_module(
"# Generated by mojo_bindgen - do not edit by hand.",
f"# layout tests for: {main_module_name}",
"",
"from std.testing import TestSuite",
"from std.sys.info import align_of, size_of",
"from std.reflection import reflect",
]
Expand Down Expand Up @@ -181,11 +185,7 @@ def render_layout_test_module(
lines.append(" pass")

lines.extend(["", "", "def main() raises:"])
if checks:
for record in checks:
lines.append(f" test_layout_{record.record_name}()")
else:
lines.append(" pass")
lines.append(" TestSuite.discover_tests[__functions_in_module()]().run()")
return "\n".join(lines) + "\n"


Expand Down
111 changes: 110 additions & 1 deletion tests/unit/test_layout_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MojoBuiltin,
MojoModule,
OpaqueStorageMember,
PaddingMember,
StoredMember,
Struct,
StructDecl,
Expand Down Expand Up @@ -92,6 +93,55 @@ def test_collect_layout_record_checks_for_plain_struct_with_padding() -> None:
("Sample.a.offset", 0),
("Sample.b.offset", 4),
]
assert [check.expression for check in checks[0].checks] == [
"size_of[Sample]()",
"align_of[Sample]()",
"r.field_offset[index=0]()",
"r.field_offset[index=2]()",
]


def test_collect_layout_record_checks_uses_emitted_indices_after_bitfield_padding() -> None:
decl = Struct(
decl_id="struct:Instance",
name="Instance",
c_name="Instance",
fields=[
Field(
name="transform", source_name="transform", type=_i32(), byte_offset=0, size_bytes=4
),
Field(
name="flags",
source_name="flags",
type=_i32(),
byte_offset=4,
size_bytes=4,
is_bitfield=True,
bit_offset=32,
bit_width=8,
),
Field(name="address", source_name="address", type=_i32(), byte_offset=8, size_bytes=4),
],
size_bytes=12,
align_bytes=4,
)
mojo_decl = StructDecl(
name="Instance",
members=[
StoredMember(0, "transform", BuiltinType(MojoBuiltin.C_INT), 0),
PaddingMember("__pad0", 4, 4),
StoredMember(2, "address", BuiltinType(MojoBuiltin.C_INT), 8),
],
)

checks = collect_layout_record_checks(
normalized_unit=_unit(decl),
mojo_module=_module(mojo_decl),
)

assert ("Instance.address.offset", "r.field_offset[index=2]()", 8) in [
(check.label, check.expression, check.expected) for check in checks[0].checks
]


def test_collect_layout_record_checks_for_opaque_storage_skips_field_offsets() -> None:
Expand Down Expand Up @@ -174,6 +224,62 @@ def test_collect_layout_record_checks_for_bitfield_storage_group_offset() -> Non
assert all("enabled.offset" not in check.label for check in checks[0].checks)


def test_collect_layout_record_checks_uses_emitted_indices_for_bitfield_groups() -> None:
decl = Struct(
decl_id="struct:MixedFlags",
name="MixedFlags",
c_name="MixedFlags",
fields=[
Field(name="prefix", source_name="prefix", type=_i32(), byte_offset=0, size_bytes=4),
Field(
name="enabled",
source_name="enabled",
type=_i32(),
byte_offset=4,
size_bytes=4,
is_bitfield=True,
bit_offset=32,
bit_width=1,
),
],
size_bytes=8,
align_bytes=4,
)
mojo_decl = StructDecl(
name="MixedFlags",
members=[
StoredMember(0, "prefix", BuiltinType(MojoBuiltin.C_INT), 0),
PaddingMember("__pad0", 0, 4),
BitfieldGroupMember(
storage_name="__bf0",
storage_type=BuiltinType(MojoBuiltin.C_UINT),
byte_offset=4,
first_index=1,
storage_width_bits=32,
fields=[
BitfieldField(
index=1,
name="enabled",
logical_type=BuiltinType(MojoBuiltin.C_INT),
bit_offset=32,
bit_width=1,
signed=True,
)
],
),
],
)

checks = collect_layout_record_checks(
normalized_unit=_unit(decl),
mojo_module=_module(mojo_decl),
)

assert ("MixedFlags.__bf0.offset", "r.field_offset[index=2]()", 4) in [
(check.label, check.expression, check.expected) for check in checks[0].checks
]


def test_collect_layout_record_checks_skips_incomplete_union_and_missing_record_decl() -> None:
incomplete = Struct(
decl_id="struct:Forward",
Expand Down Expand Up @@ -223,9 +329,12 @@ def test_render_layout_test_module_imports_records_and_calls_tests() -> None:
)

assert "from std.sys.info import align_of, size_of" in out
assert "from std.testing import TestSuite" in out
assert "from std.reflection import reflect" in out
assert "from bindings import Sample" in out
assert "def test_layout_Sample() raises:" in out
assert "comptime r = reflect[Sample]()" in out
assert "r.field_offset[index=0]()" in out
assert "def main() raises:\n test_layout_Sample()" in out
assert (
"def main() raises:\n TestSuite.discover_tests[__functions_in_module()]().run()" in out
)
Loading