Skip to content

Commit be8da8c

Browse files
committed
run tests on 0d arrays
1 parent 75506e9 commit be8da8c

1 file changed

Lines changed: 24 additions & 9 deletions

File tree

tests/test_stats.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
Array: TypeAlias = CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray
2525

26-
DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool]
26+
DTypeIn = np.float32 | np.float64 | np.int32 | np.bool
2727
DTypeOut = type[np.float32 | np.float64 | np.int64]
2828

2929
class BenchFun(Protocol): # noqa: D101
@@ -44,29 +44,42 @@ def __call__( # noqa: D102
4444
ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str(at)}
4545

4646

47-
@pytest.fixture(scope="session", params=[0, 1, None])
47+
@pytest.fixture(scope="session", params=[0, 1, None], ids=["ax0", "ax1", "all"])
4848
def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]:
4949
return cast("Literal[0, 1, None]", request.param)
5050

5151

52+
@pytest.fixture(scope="session", params=[1, 2], ids=["1d", "2d"])
53+
def ndim(request: pytest.FixtureRequest) -> Literal[1, 2]:
54+
return cast("Literal[1, 2]", request.param)
55+
56+
5257
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
53-
def dtype_in(request: pytest.FixtureRequest) -> DTypeIn:
54-
return cast("DTypeIn", request.param)
58+
def dtype_in(request: pytest.FixtureRequest) -> type[DTypeIn]:
59+
return cast("type[DTypeIn]", request.param)
5560

5661

5762
@pytest.fixture(scope="session", params=[np.float32, np.float64, None])
5863
def dtype_arg(request: pytest.FixtureRequest) -> DTypeOut | None:
5964
return cast("DTypeOut | None", request.param)
6065

6166

67+
@pytest.fixture(scope="session")
68+
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
69+
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)
70+
if ndim == 1:
71+
np_arr = np_arr.flatten()
72+
return np_arr
73+
74+
6275
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
6376
def test_sum(
6477
array_type: ArrayType[Array],
65-
dtype_in: DTypeIn,
78+
dtype_in: type[DTypeIn],
6679
dtype_arg: DTypeOut | None,
6780
axis: Literal[0, 1, None],
81+
np_arr: NDArray[DTypeIn],
6882
) -> None:
69-
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)
7083
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
7184
pytest.skip("CuPy sparse matrices only support floats")
7285
arr = array_type(np_arr.copy())
@@ -106,9 +119,11 @@ def test_sum(
106119
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
107120
@pytest.mark.parametrize(("axis", "expected"), [(None, 3.5), (0, [2.5, 3.5, 4.5]), (1, [2.0, 5.0])])
108121
def test_mean(
109-
array_type: ArrayType[Array], axis: Literal[0, 1, None], expected: float | list[float]
122+
array_type: ArrayType[Array],
123+
axis: Literal[0, 1, None],
124+
expected: float | list[float],
125+
np_arr: NDArray[DTypeIn],
110126
) -> None:
111-
np_arr = np.array([[1, 2, 3], [4, 5, 6]])
112127
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
113128
pytest.skip("CuPy sparse matrices only support floats")
114129
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected)
@@ -132,8 +147,8 @@ def test_mean_var(
132147
axis: Literal[0, 1, None],
133148
mean_expected: float | list[float],
134149
var_expected: float | list[float],
150+
np_arr: NDArray[DTypeIn],
135151
) -> None:
136-
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
137152
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
138153
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)
139154

0 commit comments

Comments
 (0)