2323
2424 Array : TypeAlias = CpuArray | GpuArray | DiskArray | types .CSDataset | types .DaskArray
2525
26- DTypeIn = type [ np .float32 | np .float64 | np .int32 | np .bool ]
26+ DTypeIn = np .float32 | np .float64 | np .int32 | np .bool
2727 DTypeOut = type [np .float32 | np .float64 | np .int64 ]
2828
2929 class BenchFun (Protocol ): # noqa: D101
@@ -44,29 +44,42 @@ def __call__( # noqa: D102
4444ATS_CUPY_SPARSE = {at for at in SUPPORTED_TYPES if "cupyx.scipy" in str (at )}
4545
4646
47- @pytest .fixture (scope = "session" , params = [0 , 1 , None ])
47+ @pytest .fixture (scope = "session" , params = [0 , 1 , None ], ids = [ "ax0" , "ax1" , "all" ] )
4848def axis (request : pytest .FixtureRequest ) -> Literal [0 , 1 , None ]:
4949 return cast ("Literal[0, 1, None]" , request .param )
5050
5151
52+ @pytest .fixture (scope = "session" , params = [1 , 2 ], ids = ["1d" , "2d" ])
53+ def ndim (request : pytest .FixtureRequest ) -> Literal [1 , 2 ]:
54+ return cast ("Literal[1, 2]" , request .param )
55+
56+
5257@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , np .int32 , np .bool ])
53- def dtype_in (request : pytest .FixtureRequest ) -> DTypeIn :
54- return cast ("DTypeIn" , request .param )
58+ def dtype_in (request : pytest .FixtureRequest ) -> type [ DTypeIn ] :
59+ return cast ("type[ DTypeIn] " , request .param )
5560
5661
5762@pytest .fixture (scope = "session" , params = [np .float32 , np .float64 , None ])
5863def dtype_arg (request : pytest .FixtureRequest ) -> DTypeOut | None :
5964 return cast ("DTypeOut | None" , request .param )
6065
6166
67+ @pytest .fixture (scope = "session" )
68+ def np_arr (dtype_in : type [DTypeIn ], ndim : Literal [1 , 2 ]) -> NDArray [DTypeIn ]:
69+ np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in )
70+ if ndim == 1 :
71+ np_arr = np_arr .flatten ()
72+ return np_arr
73+
74+
6275@pytest .mark .array_type (skip = ATS_SPARSE_DS )
6376def test_sum (
6477 array_type : ArrayType [Array ],
65- dtype_in : DTypeIn ,
78+ dtype_in : type [ DTypeIn ] ,
6679 dtype_arg : DTypeOut | None ,
6780 axis : Literal [0 , 1 , None ],
81+ np_arr : NDArray [DTypeIn ],
6882) -> None :
69- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = dtype_in )
7083 if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
7184 pytest .skip ("CuPy sparse matrices only support floats" )
7285 arr = array_type (np_arr .copy ())
@@ -106,9 +119,11 @@ def test_sum(
106119@pytest .mark .array_type (skip = ATS_SPARSE_DS )
107120@pytest .mark .parametrize (("axis" , "expected" ), [(None , 3.5 ), (0 , [2.5 , 3.5 , 4.5 ]), (1 , [2.0 , 5.0 ])])
108121def test_mean (
109- array_type : ArrayType [Array ], axis : Literal [0 , 1 , None ], expected : float | list [float ]
122+ array_type : ArrayType [Array ],
123+ axis : Literal [0 , 1 , None ],
124+ expected : float | list [float ],
125+ np_arr : NDArray [DTypeIn ],
110126) -> None :
111- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
112127 if array_type in ATS_CUPY_SPARSE and np_arr .dtype .kind != "f" :
113128 pytest .skip ("CuPy sparse matrices only support floats" )
114129 np .testing .assert_array_equal (np .mean (np_arr , axis = axis ), expected )
@@ -132,8 +147,8 @@ def test_mean_var(
132147 axis : Literal [0 , 1 , None ],
133148 mean_expected : float | list [float ],
134149 var_expected : float | list [float ],
150+ np_arr : NDArray [DTypeIn ],
135151) -> None :
136- np_arr = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]], dtype = np .float64 )
137152 np .testing .assert_array_equal (np .mean (np_arr , axis = axis ), mean_expected )
138153 np .testing .assert_array_equal (np .var (np_arr , axis = axis , correction = 1 ), var_expected )
139154
0 commit comments