Skip to content

Commit b5e4ae7

Browse files
committed
fix types
1 parent b50b758 commit b5e4ae7

4 files changed

Lines changed: 7 additions & 4 deletions

File tree

tests/test_stats.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
8787

8888
@pytest.fixture
8989
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
90-
np_arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in)
90+
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
9191
np_arr.flags.writeable = False
9292
if ndim == 1:
9393
np_arr = np_arr.flatten()
@@ -149,7 +149,7 @@ def test_mean(
149149
}[None]
150150
if array_type in ATS_CUPY_SPARSE and np_arr.dtype.kind != "f":
151151
pytest.skip("CuPy sparse matrices only support floats")
152-
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected)
152+
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), expected) # type: ignore[arg-type]
153153

154154
arr = array_type(np_arr)
155155
result = stats.mean(arr, axis=axis) # type: ignore[arg-type] # https://github.com/python/mypy/issues/16777
@@ -172,8 +172,8 @@ def test_mean_var(
172172
var_expected: float | list[float],
173173
np_arr: NDArray[DTypeIn],
174174
) -> None:
175-
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected)
176-
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected)
175+
np.testing.assert_array_equal(np.mean(np_arr, axis=axis), mean_expected) # type: ignore[arg-type]
176+
np.testing.assert_array_equal(np.var(np_arr, axis=axis, correction=1), var_expected) # type: ignore[arg-type]
177177

178178
arr = array_type(np_arr)
179179
mean, var = stats.mean_var(arr, axis=axis, correction=1)

typings/cupy/_core/core.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from numpy.typing import NDArray
88
class ndarray:
99
dtype: np.dtype[Any]
1010
shape: tuple[int, ...]
11+
ndim: int
1112

1213
# cupy-specific
1314
def get(self) -> NDArray[Any]: ...

typings/cupyx/scipy/sparse/_base.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ from numpy.typing import NDArray
99
class spmatrix:
1010
dtype: np.dtype[Any]
1111
shape: tuple[int, int]
12+
ndim: int
1213
def toarray(self, order: Literal["C", "F", None] = None, out: None = None) -> cupy.ndarray: ...
1314
def __power__(self, other: int) -> Self: ...
1415
def __array__(self) -> NDArray[Any]: ...

typings/h5py.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class HLObject: ...
1212
class Dataset(HLObject):
1313
dtype: np.dtype[Any]
1414
shape: tuple[int, ...]
15+
ndim: int
1516

1617
class Group(HLObject): ...
1718

0 commit comments

Comments
 (0)