From 1500e56e03b94f15a8b6956742b47babd4cac398 Mon Sep 17 00:00:00 2001 From: Soumyadip Sarkar Date: Sat, 2 May 2026 18:54:51 +0530 Subject: [PATCH] Fix surjectivity checks for explicit and shifted codomains --- src/tensor_layouts/analysis.py | 19 ++++++++++++++++--- tests/analysis.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 098089f..8c79511 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -114,7 +114,7 @@ def is_injective(layout: LayoutExpr) -> bool: return len(image(layout)) == size(layout) -def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool: +def is_surjective(layout: LayoutExpr, codomain_size: int | None = None) -> bool: """True if every offset in [0, codomain_size) is produced. A surjective layout has no gaps --- the image covers the entire @@ -131,9 +131,22 @@ def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool: is_surjective(Layout(4, 1)) # True (image == codomain) is_surjective(Layout(4, 2)) # False (image has gaps) """ + offsets = image(layout) if codomain_size is None: - codomain_size = cosize(layout) - return len(image(layout)) == codomain_size + if not offsets: + return True + lo, hi = offsets[0], offsets[-1] + return len(offsets) == (hi - lo + 1) + + if not is_int(codomain_size): + raise TypeError("codomain_size must be an integer") + if codomain_size < 0: + raise ValueError("codomain_size must be non-negative") + if codomain_size == 0: + return len(offsets) == 0 + if len(offsets) != codomain_size: + return False + return offsets[0] == 0 and offsets[-1] == codomain_size - 1 def is_bijective(layout: LayoutExpr) -> bool: diff --git a/tests/analysis.py b/tests/analysis.py index 8ed758a..6ce3a7f 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -1556,8 +1556,26 @@ def test_is_surjective_custom_codomain(): layout = Layout(4, 2) # Not surjective onto [0, 7) -- has gaps assert not is_surjective(layout) - # Surjective if codomain is exactly the image size - assert is_surjective(layout, codomain_size=4) + # Still not surjective onto [0, 4) because the image is {0,2,4,6} + assert not is_surjective(layout, codomain_size=4) + # Surjective when codomain exactly matches requested range [0, 7) + assert is_surjective(Layout(7, 1), codomain_size=7) + + +def test_is_surjective_custom_codomain_zero(): + """codomain_size=0 is surjective only for empty images.""" + assert is_surjective(Layout(0, 3), codomain_size=0) + assert not is_surjective(Layout(1, 0), codomain_size=0) + + +def test_is_surjective_custom_codomain_validation(): + """codomain_size must be a non-negative integer.""" + with pytest.raises(TypeError): + is_surjective(Layout(4, 1), codomain_size=4.0) + with pytest.raises(TypeError): + is_surjective(Layout(4, 1), codomain_size=True) + with pytest.raises(ValueError): + is_surjective(Layout(4, 1), codomain_size=-1) def test_is_bijective_contiguous(): @@ -1597,6 +1615,12 @@ def test_is_bijective_negative_stride_dense(): assert is_contiguous(layout) +def test_is_surjective_negative_stride_custom_codomain(): + """Negative-stride dense spans are not surjective onto [0, n) by default.""" + layout = Layout(4, -1) + assert not is_surjective(layout, codomain_size=4) + + def test_image_injectivity_consistency(): """image size equals domain size iff injective.""" layouts = [