Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/tensor_layouts/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def is_injective(layout: LayoutExpr) -> bool:
return len(image(layout)) == size(layout)


def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool:
def is_surjective(layout: LayoutExpr, codomain_size: int | None = None) -> bool:
"""True if every offset in [0, codomain_size) is produced.

A surjective layout has no gaps --- the image covers the entire
Expand All @@ -131,9 +131,22 @@ def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool:
is_surjective(Layout(4, 1)) # True (image == codomain)
is_surjective(Layout(4, 2)) # False (image has gaps)
"""
offsets = image(layout)
if codomain_size is None:
codomain_size = cosize(layout)
return len(image(layout)) == codomain_size
if not offsets:
return True
lo, hi = offsets[0], offsets[-1]
return len(offsets) == (hi - lo + 1)

if not is_int(codomain_size):
raise TypeError("codomain_size must be an integer")
if codomain_size < 0:
raise ValueError("codomain_size must be non-negative")
if codomain_size == 0:
return len(offsets) == 0
if len(offsets) != codomain_size:
return False
return offsets[0] == 0 and offsets[-1] == codomain_size - 1


def is_bijective(layout: LayoutExpr) -> bool:
Expand Down
28 changes: 26 additions & 2 deletions tests/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,8 +1556,26 @@ def test_is_surjective_custom_codomain():
layout = Layout(4, 2)
# Not surjective onto [0, 7) -- has gaps
assert not is_surjective(layout)
# Surjective if codomain is exactly the image size
assert is_surjective(layout, codomain_size=4)
# Still not surjective onto [0, 4) because the image is {0,2,4,6}
assert not is_surjective(layout, codomain_size=4)
# Surjective when codomain exactly matches requested range [0, 7)
assert is_surjective(Layout(7, 1), codomain_size=7)


def test_is_surjective_custom_codomain_zero():
"""codomain_size=0 is surjective only for empty images."""
assert is_surjective(Layout(0, 3), codomain_size=0)
assert not is_surjective(Layout(1, 0), codomain_size=0)


def test_is_surjective_custom_codomain_validation():
"""codomain_size must be a non-negative integer."""
with pytest.raises(TypeError):
is_surjective(Layout(4, 1), codomain_size=4.0)
with pytest.raises(TypeError):
is_surjective(Layout(4, 1), codomain_size=True)
with pytest.raises(ValueError):
is_surjective(Layout(4, 1), codomain_size=-1)


def test_is_bijective_contiguous():
Expand Down Expand Up @@ -1597,6 +1615,12 @@ def test_is_bijective_negative_stride_dense():
assert is_contiguous(layout)


def test_is_surjective_negative_stride_custom_codomain():
"""Negative-stride dense spans are not surjective onto [0, n) by default."""
layout = Layout(4, -1)
assert not is_surjective(layout, codomain_size=4)


def test_image_injectivity_consistency():
"""image size equals domain size iff injective."""
layouts = [
Expand Down
Loading