diff --git a/src/tensor_layouts/analysis.py b/src/tensor_layouts/analysis.py index 098089f..8c79511 100644 --- a/src/tensor_layouts/analysis.py +++ b/src/tensor_layouts/analysis.py @@ -114,7 +114,7 @@ def is_injective(layout: LayoutExpr) -> bool: return len(image(layout)) == size(layout) -def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool: +def is_surjective(layout: LayoutExpr, codomain_size: int | None = None) -> bool: """True if every offset in [0, codomain_size) is produced. A surjective layout has no gaps --- the image covers the entire @@ -131,9 +131,22 @@ def is_surjective(layout: LayoutExpr, codomain_size: int = None) -> bool: is_surjective(Layout(4, 1)) # True (image == codomain) is_surjective(Layout(4, 2)) # False (image has gaps) """ + offsets = image(layout) if codomain_size is None: - codomain_size = cosize(layout) - return len(image(layout)) == codomain_size + if not offsets: + return True + lo, hi = offsets[0], offsets[-1] + return len(offsets) == (hi - lo + 1) + + if not is_int(codomain_size): + raise TypeError("codomain_size must be an integer") + if codomain_size < 0: + raise ValueError("codomain_size must be non-negative") + if codomain_size == 0: + return len(offsets) == 0 + if len(offsets) != codomain_size: + return False + return offsets[0] == 0 and offsets[-1] == codomain_size - 1 def is_bijective(layout: LayoutExpr) -> bool: diff --git a/tests/analysis.py b/tests/analysis.py index 8ed758a..6ce3a7f 100644 --- a/tests/analysis.py +++ b/tests/analysis.py @@ -1556,8 +1556,26 @@ def test_is_surjective_custom_codomain(): layout = Layout(4, 2) # Not surjective onto [0, 7) -- has gaps assert not is_surjective(layout) - # Surjective if codomain is exactly the image size - assert is_surjective(layout, codomain_size=4) + # Still not surjective onto [0, 4) because the image is {0,2,4,6} + assert not is_surjective(layout, codomain_size=4) + # Surjective when codomain exactly matches requested range [0, 7) + assert is_surjective(Layout(7, 1), codomain_size=7) + + +def test_is_surjective_custom_codomain_zero(): + """codomain_size=0 is surjective only for empty images.""" + assert is_surjective(Layout(0, 3), codomain_size=0) + assert not is_surjective(Layout(1, 0), codomain_size=0) + + +def test_is_surjective_custom_codomain_validation(): + """codomain_size must be a non-negative integer.""" + with pytest.raises(TypeError): + is_surjective(Layout(4, 1), codomain_size=4.0) + with pytest.raises(TypeError): + is_surjective(Layout(4, 1), codomain_size=True) + with pytest.raises(ValueError): + is_surjective(Layout(4, 1), codomain_size=-1) def test_is_bijective_contiguous(): @@ -1597,6 +1615,12 @@ def test_is_bijective_negative_stride_dense(): assert is_contiguous(layout) +def test_is_surjective_negative_stride_custom_codomain(): + """Negative-stride dense spans are not surjective onto [0, n) by default.""" + layout = Layout(4, -1) + assert not is_surjective(layout, codomain_size=4) + + def test_image_injectivity_consistency(): """image size equals domain size iff injective.""" layouts = [