diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a837e7e..cdf2900 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,6 +25,8 @@ jobs: - '3.10' - '3.11' - '3.12' + - '3.13' + - '3.14' sqlalchemy-version: - '1.4' - '2.0' @@ -42,7 +44,7 @@ jobs: - name: Run test suite with tox # Run tox using the version of Python in `PATH` - run: tox run -e clean,py-sqlalchemy${{ matrix.sqlalchemy-version }},report,flake8 -- --junit-xml=reports/pytest_${{ matrix.python-version }}_sqlalchemy${{ matrix.sqlalchemy-version }}.xml + run: tox run -e clean,py-sqlalchemy${{ matrix.sqlalchemy-version }},report,flake8,mypy -- --junit-xml=reports/pytest_${{ matrix.python-version }}_sqlalchemy${{ matrix.sqlalchemy-version }}.xml - name: Upload test result artifacts uses: actions/upload-artifact@v4 diff --git a/Makefile b/Makefile index b1290ef..f449ac6 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,11 @@ test: flake8: tox run -e flake8 +# Only run mypy (via tox; you can also just run "mypy" directly) +.PHONY: mypy +mypy: + tox run -e mypy + # Open HTML coverage report in browser .PHONY: open-coverage open-coverage: @@ -65,9 +70,19 @@ _docker-tox: docker-tox: _docker-tox # Run partial tox test suites in Docker -.PHONY: docker-test-py312-sqlalchemy1.4 docker-test-py312-sqlalchemy2.0 \ +.PHONY: docker-test-py314-sqlalchemy1.4 docker-test-py314-sqlalchemy2.0 \ + docker-test-py313-sqlalchemy1.4 docker-test-py313-sqlalchemy2.0 \ + docker-test-py312-sqlalchemy1.4 docker-test-py312-sqlalchemy2.0 \ docker-test-py311-sqlalchemy1.4 docker-test-py311-sqlalchemy2.0 \ docker-test-py310-sqlalchemy1.4 docker-test-py310-sqlalchemy2.0 +docker-test-py314-sqlalchemy1.4: TOX_ARGS="-e clean,py314-sqlalchemy1.4,py312-report" +docker-test-py314-sqlalchemy1.4: _docker-tox +docker-test-py314-sqlalchemy2.0: TOX_ARGS="-e clean,py314-sqlalchemy2.0,py312-report" +docker-test-py314-sqlalchemy2.0: _docker-tox +docker-test-py313-sqlalchemy1.4: TOX_ARGS="-e clean,py313-sqlalchemy1.4,py312-report" +docker-test-py313-sqlalchemy1.4: _docker-tox +docker-test-py313-sqlalchemy2.0: TOX_ARGS="-e clean,py313-sqlalchemy2.0,py312-report" +docker-test-py313-sqlalchemy2.0: _docker-tox docker-test-py312-sqlalchemy1.4: TOX_ARGS="-e clean,py312-sqlalchemy1.4,py312-report" docker-test-py312-sqlalchemy1.4: _docker-tox docker-test-py312-sqlalchemy2.0: TOX_ARGS="-e clean,py312-sqlalchemy2.0,py312-report" @@ -90,6 +105,10 @@ docker-test-all: make docker-test-py311-sqlalchemy2.0 make docker-test-py312-sqlalchemy1.4 make docker-test-py312-sqlalchemy2.0 + make docker-test-py313-sqlalchemy1.4 + make docker-test-py313-sqlalchemy2.0 + make docker-test-py314-sqlalchemy1.4 + make docker-test-py314-sqlalchemy2.0 # Pull the latest image of the multi-python Docker image .PHONY: docker-pull diff --git a/docs/02-using-search-queries.md b/docs/02-using-search-queries.md index 61c71ee..6af09c2 100644 --- a/docs/02-using-search-queries.md +++ b/docs/02-using-search-queries.md @@ -197,6 +197,7 @@ handle your special filter, and then continue using the methods like you would n Here is an example how you could implement this: ```python +from typing import TypeVar, override from sqlalchemy.orm import Session, Query from validataclass.validators import StringValidator from validataclass_search_queries.filters import SearchParamContains, BoundSearchFilter @@ -204,6 +205,8 @@ from validataclass_search_queries.pagination import PaginatedResult from validataclass_search_queries.repositories import SearchQueryRepositoryMixin from validataclass_search_queries.search_queries import search_query_dataclass, BaseSearchQuery +T_Query = TypeVar('T_Query') + # Stubs for SQLAlchemy models (Customer needs a 1:n relationship "addresses" to Address) class Customer: ... class Address: ... @@ -231,7 +234,8 @@ class CustomerRepository(SearchQueryRepositoryMixin[Customer]): query = self.session.query(Customer).join(Customer.addresses) return self._search_and_paginate(query, search_query) - def _apply_bound_search_filter(self, query: Query, bound_filter: BoundSearchFilter) -> Query: + @override + def _apply_bound_search_filter(self, query: Query[T_Query], bound_filter: BoundSearchFilter) -> Query[T_Query]: # Implement special handling for the "city" filter if bound_filter.column_name == 'city': return query.filter(bound_filter.get_sqlalchemy_filter(Address.city)) diff --git a/pyproject.toml b/pyproject.toml index 9097f7d..b67f19c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,3 +9,54 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] write_to = "src/validataclass_search_queries/_version.py" version_scheme = "post-release" + +[tool.mypy] +files = ["src/", "tests/"] +mypy_path = "src/" +explicit_package_bases = true + +# Enable validataclass mypy plugin +plugins = [ + "validataclass.mypy.plugin", +] + +# Enable strict type checking +strict = true + +# Enable further checks that are not included in strict mode +disallow_any_unimported = true +strict_equality_for_none = true +warn_unreachable = true +enable_error_code = [ + "deprecated", + "explicit-override", + "ignore-without-code", + "mutable-override", + "possibly-undefined", + "redundant-expr", + "redundant-self", + "truthy-bool", + "truthy-iterable", + "unused-awaitable", +] + +[[tool.mypy.overrides]] +module = 'tests.*' + +# Don't enforce typed definitions in tests, this is a lot of unnecessary work (most parameters would be Any anyway). +allow_untyped_defs = true + +[tool.validataclass_mypy] +# Allow incompatible overrides for fields in validataclass sub classes (this is the default, but the default might be +# changed in the future). +allow_incompatible_field_overrides = true + +# Declare @search_query_dataclass as a decorator that creates validataclasses +custom_validataclass_decorators = [ + "validataclass_search_queries.search_queries.search_query_dataclass.search_query_dataclass", +] + +# Ignore SearchParam objects in validataclass field definitions +ignore_custom_types_in_fields = [ + "validataclass_search_queries.filters.search_params.base_search_param.SearchParam", +] diff --git a/setup.cfg b/setup.cfg index 642db67..41bdb86 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,16 +30,19 @@ packages = find: python_requires = ~=3.10 install_requires = typing-extensions ~= 4.15 - # Allow validataclass 0.10.* and 0.11.* - validataclass >= 0.10.0, < 0.12.0 + # Allow validataclass 0.12.* + validataclass >= 0.12.0, < 0.13.0 sqlalchemy >= 1.4, < 2.1 [options.packages.find] where = src [options.extras_require] +# Set minimum versions but allow patch level updates with "~= x.y.z" to avoid breaking tests or similar, e.g. when mypy +# changes how error messages look in a minor version update. testing = - pytest ~= 9.0 - pytest-cov ~= 7.0 - coverage ~= 7.13 - flake8 ~= 7.3 + pytest ~= 9.0.3 + pytest-cov ~= 7.0.0 + coverage ~= 7.13.5 + flake8 ~= 7.3.0 + mypy ~= 1.19.1 diff --git a/src/validataclass_search_queries/filters/__init__.py b/src/validataclass_search_queries/filters/__init__.py index 0abf26e..c450530 100644 --- a/src/validataclass_search_queries/filters/__init__.py +++ b/src/validataclass_search_queries/filters/__init__.py @@ -24,3 +24,24 @@ SearchParamStartsWith, SearchParamEndsWith, ) + +__all__ = [ + 'BoundSearchFilter', + 'SearchParam', + 'SearchParamBoolean', + 'SearchParamIsNone', + 'SearchParamIsNotNone', + 'SearchParamTernary', + 'SearchParamCustom', + 'SearchParamEquals', + 'SearchParamGreaterThan', + 'SearchParamGreaterOrEqual', + 'SearchParamLessThan', + 'SearchParamLessOrEqual', + 'SearchParamSince', + 'SearchParamUntil', + 'SearchParamMultiSelect', + 'SearchParamContains', + 'SearchParamStartsWith', + 'SearchParamEndsWith', +] diff --git a/src/validataclass_search_queries/filters/bound_search_filter.py b/src/validataclass_search_queries/filters/bound_search_filter.py index fe2caac..da96042 100644 --- a/src/validataclass_search_queries/filters/bound_search_filter.py +++ b/src/validataclass_search_queries/filters/bound_search_filter.py @@ -69,7 +69,7 @@ def column_name(self) -> str: """ return self.search_param.column_name or self.param_name - def get_sqlalchemy_filter(self, column: ColumnElement) -> ColumnElement: + def get_sqlalchemy_filter(self, column: ColumnElement[Any]) -> ColumnElement[bool]: """ Returns an SQLAlchemy filter for the given column (can be any ColumnElement, i.e. any SQLAlchemy expression that can be used in a WHERE clause) based on the filter function defined by the SearchParam. diff --git a/src/validataclass_search_queries/filters/search_params/__init__.py b/src/validataclass_search_queries/filters/search_params/__init__.py index 1b32121..77ca58f 100644 --- a/src/validataclass_search_queries/filters/search_params/__init__.py +++ b/src/validataclass_search_queries/filters/search_params/__init__.py @@ -27,3 +27,23 @@ SearchParamStartsWith, SearchParamEndsWith, ) + +__all__ = [ + 'SearchParam', + 'SearchParamBoolean', + 'SearchParamIsNone', + 'SearchParamIsNotNone', + 'SearchParamTernary', + 'SearchParamCustom', + 'SearchParamEquals', + 'SearchParamGreaterThan', + 'SearchParamGreaterOrEqual', + 'SearchParamLessThan', + 'SearchParamLessOrEqual', + 'SearchParamSince', + 'SearchParamUntil', + 'SearchParamMultiSelect', + 'SearchParamContains', + 'SearchParamStartsWith', + 'SearchParamEndsWith', +] diff --git a/src/validataclass_search_queries/filters/search_params/base_search_param.py b/src/validataclass_search_queries/filters/search_params/base_search_param.py index 879ac83..05b72cb 100644 --- a/src/validataclass_search_queries/filters/search_params/base_search_param.py +++ b/src/validataclass_search_queries/filters/search_params/base_search_param.py @@ -52,9 +52,8 @@ class MySearchQuery(BaseSearchQuery): def __init__(self, column_name: str | None = None): self.column_name = column_name - @staticmethod # pragma: nocover - @abstractmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @abstractmethod # pragma: nocover + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: """ This abstract method defines the SQLAlchemy filter expression. See existing implementations for examples. """ diff --git a/src/validataclass_search_queries/filters/search_params/search_param_boolean.py b/src/validataclass_search_queries/filters/search_params/search_param_boolean.py index 443bac8..0805e79 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_boolean.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_boolean.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -23,8 +24,8 @@ class SearchParamBoolean(SearchParam): Boolean search parameter to filter a boolean column for true or false. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: return column.is_(bool(value)) @@ -36,21 +37,22 @@ class SearchParamIsNone(SearchParam): If the search parameter is False, only results where the specified column is NOT None will be included. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: return column.is_(None) if value is True else column.is_not(None) class SearchParamIsNotNone(SearchParam): """ - Boolean search parameter to filter a column for values that are None or not None. Inverted version of SearchParamIsNone. + Boolean search parameter to filter a column for values that are None or not None. + Inverted version of SearchParamIsNone. If the search parameter is True, only results where the specified column is NOT None will be included. If the search parameter is False, only results where the specified column is None will be included. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: return column.is_not(None) if value is True else column.is_(None) @@ -74,5 +76,6 @@ def __init__(self, true: Any, false: Any, *, column_name: str | None = None): self.value_true = true self.value_false = false - def sqlalchemy_filter(self, column: ColumnElement, value: Any) -> ColumnElement: - return column == (self.value_true if value else self.value_false) + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__eq__(self.value_true if value else self.value_false) diff --git a/src/validataclass_search_queries/filters/search_params/search_param_custom.py b/src/validataclass_search_queries/filters/search_params/search_param_custom.py index 5461187..134b06d 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_custom.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_custom.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -24,6 +25,6 @@ class SearchParamCustom(SearchParam): overriding either `_apply_bound_search_filter` or `_filter_by_search_query`. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: raise NotImplementedError('Custom search parameter needs to be handled in the repository!') diff --git a/src/validataclass_search_queries/filters/search_params/search_param_equals.py b/src/validataclass_search_queries/filters/search_params/search_param_equals.py index 55ff16f..f1e9bb9 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_equals.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_equals.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -22,6 +23,6 @@ class SearchParamEquals(SearchParam): Note: For strings, this might or might not be case sensitive, depending on your database collations. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - return column == value + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__eq__(value) diff --git a/src/validataclass_search_queries/filters/search_params/search_param_greater_less.py b/src/validataclass_search_queries/filters/search_params/search_param_greater_less.py index e0e6bd0..ec13c58 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_greater_less.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_greater_less.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -25,9 +26,9 @@ class SearchParamGreaterThan(SearchParam): Search parameter to filter for values greater than the filter value (`column > value`). """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - return column > value + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__gt__(value) class SearchParamGreaterOrEqual(SearchParam): @@ -35,9 +36,9 @@ class SearchParamGreaterOrEqual(SearchParam): Search parameter to filter for values greater than or equal to the filter value (`column >= value`). """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - return column >= value + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__ge__(value) class SearchParamLessThan(SearchParam): @@ -45,9 +46,9 @@ class SearchParamLessThan(SearchParam): Search parameter to filter for values less than the filter value (`column < value`). """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - return column < value + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__lt__(value) class SearchParamLessOrEqual(SearchParam): @@ -55,9 +56,9 @@ class SearchParamLessOrEqual(SearchParam): Search parameter to filter for values less than or equal to the filter value (`column <= value`). """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - return column <= value + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.__le__(value) class SearchParamSince(SearchParamGreaterOrEqual): diff --git a/src/validataclass_search_queries/filters/search_params/search_param_multi_select.py b/src/validataclass_search_queries/filters/search_params/search_param_multi_select.py index 29aa2cc..4bc61c2 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_multi_select.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_multi_select.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -26,7 +27,7 @@ class SearchParamMultiSelect(SearchParam): is set to a single value, the filter will be equivalent to an "equals" filter. See the `MultiSelectValidator`. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: value_list = value if isinstance(value, list) else [value] return column.in_(value_list) diff --git a/src/validataclass_search_queries/filters/search_params/search_param_substring.py b/src/validataclass_search_queries/filters/search_params/search_param_substring.py index 5290d29..5ef3dc4 100644 --- a/src/validataclass_search_queries/filters/search_params/search_param_substring.py +++ b/src/validataclass_search_queries/filters/search_params/search_param_substring.py @@ -7,6 +7,7 @@ from typing import Any from sqlalchemy.sql import ColumnElement +from typing_extensions import override from .base_search_param import SearchParam @@ -25,10 +26,9 @@ class SearchParamContains(SearchParam): interpreted as literal characters, not as wildcard characters. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - # Short-circuit if value is empty - return column.contains(value, autoescape=True) if value else column + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.contains(value, autoescape=True) class SearchParamStartsWith(SearchParam): @@ -39,10 +39,9 @@ class SearchParamStartsWith(SearchParam): interpreted as literal characters, not as wildcard characters. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - # Short-circuit if value is empty - return column.startswith(value, autoescape=True) if value else column + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.startswith(value, autoescape=True) class SearchParamEndsWith(SearchParam): @@ -53,7 +52,6 @@ class SearchParamEndsWith(SearchParam): interpreted as literal characters, not as wildcard characters. """ - @staticmethod - def sqlalchemy_filter(column: ColumnElement, value: Any) -> ColumnElement: - # Short-circuit if value is empty - return column.endswith(value, autoescape=True) if value else column + @override + def sqlalchemy_filter(self, column: ColumnElement[Any], value: Any) -> ColumnElement[bool]: + return column.endswith(value, autoescape=True) diff --git a/src/validataclass_search_queries/pagination/__init__.py b/src/validataclass_search_queries/pagination/__init__.py index 49bc070..7255f4b 100644 --- a/src/validataclass_search_queries/pagination/__init__.py +++ b/src/validataclass_search_queries/pagination/__init__.py @@ -10,3 +10,13 @@ from .paginated_result import PaginatedResult from .pagination_limit_validator import PaginationLimitValidator, PaginationLimitRequiredError from .response_helpers import paginated_api_response + +__all__ = [ + 'AbstractPaginationMixin', + 'CursorPaginationMixin', + 'OffsetPaginationMixin', + 'PaginatedResult', + 'PaginationLimitValidator', + 'PaginationLimitRequiredError', + 'paginated_api_response', +] diff --git a/src/validataclass_search_queries/pagination/abstract_pagination_mixin.py b/src/validataclass_search_queries/pagination/abstract_pagination_mixin.py index 65bfeb3..3d0d9ce 100644 --- a/src/validataclass_search_queries/pagination/abstract_pagination_mixin.py +++ b/src/validataclass_search_queries/pagination/abstract_pagination_mixin.py @@ -5,7 +5,7 @@ """ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeVar from sqlalchemy.orm import Query @@ -15,6 +15,8 @@ 'AbstractPaginationMixin', ] +T = TypeVar('T') + class AbstractPaginationMixin(ABC): """ @@ -25,7 +27,7 @@ class AbstractPaginationMixin(ABC): limit: int | None @abstractmethod - def apply_pagination_to_query(self, query: Query, model_cls: Any) -> Query: + def apply_pagination_to_query(self, query: Query[T], model_cls: Any) -> Query[T]: """ Applies the pagination parameters to an SQLAlchemy query and returns the new query. @@ -45,7 +47,7 @@ def get_start_parameter_name(self) -> str: raise NotImplementedError @abstractmethod - def get_next_start_value(self, paginated_result: PaginatedResult) -> int | None: + def get_next_start_value(self, paginated_result: PaginatedResult[Any]) -> int | None: """ Returns the next value for the pagination start parameter (see also: `get_start_parameter_name()`) to retrieve the next page of data, or None if there is no next page. diff --git a/src/validataclass_search_queries/pagination/cursor_pagination_mixin.py b/src/validataclass_search_queries/pagination/cursor_pagination_mixin.py index 55db410..3846e78 100644 --- a/src/validataclass_search_queries/pagination/cursor_pagination_mixin.py +++ b/src/validataclass_search_queries/pagination/cursor_pagination_mixin.py @@ -4,22 +4,25 @@ Use of this source code is governed by an MIT-style license that can be found in the LICENSE file. """ -from typing import Any +from typing import Any, TypeVar, cast from sqlalchemy.orm import Query from sqlalchemy.sql import ColumnElement +from typing_extensions import override from validataclass.dataclasses import validataclass, Default from validataclass.validators import IntegerValidator +from validataclass_search_queries import pagination, sorting from .abstract_pagination_mixin import AbstractPaginationMixin from .paginated_result import PaginatedResult from .pagination_limit_validator import PaginationLimitValidator -from .. import pagination, sorting __all__ = [ 'CursorPaginationMixin', ] +T = TypeVar('T') + @validataclass class CursorPaginationMixin(AbstractPaginationMixin): @@ -92,12 +95,17 @@ class ExampleSearchQuery(CursorPaginationMixin, BaseSearchQuery): # Limit: Number of entries per page limit: int | None = PaginationLimitValidator(max_value=100), Default(20) - def __init_subclass__(cls, **kwargs): + @override + def __init_subclass__(cls, **kwargs: Any): # Pagination mixins are not compatible with each other, only one can be used at the same time - if issubclass(cls, pagination.OffsetPaginationMixin): - raise TypeError(f'Invalid base classes in {cls}: Combining multiple pagination mixins is not allowed') - if issubclass(cls, sorting.SortingMixin): - raise TypeError(f'Invalid base classes in {cls}: CursorPaginationMixin cannot be combined with SortingMixin') + if issubclass(cls, pagination.OffsetPaginationMixin): # type: ignore[unreachable, unused-ignore] + raise TypeError( + f'Invalid base classes in {cls}: Combining multiple pagination mixins is not allowed' + ) + if issubclass(cls, sorting.SortingMixin): # type: ignore[unreachable, unused-ignore] + raise TypeError( + f'Invalid base classes in {cls}: CursorPaginationMixin cannot be combined with SortingMixin' + ) super().__init_subclass__(**kwargs) @@ -111,7 +119,7 @@ def get_cursor_column_name() -> str: """ return 'id' - def get_cursor_column(self, model_cls: Any) -> ColumnElement: + def get_cursor_column(self, model_cls: Any) -> ColumnElement[Any]: """ Returns the column that is used as cursor for cursor pagination. @@ -121,9 +129,12 @@ def get_cursor_column(self, model_cls: Any) -> ColumnElement: THIS method, be sure to also adjust `get_cursor_column_name()` so that it still works with other methods like `get_next_start_value()`. """ - return getattr(model_cls, self.get_cursor_column_name()) + # SQLAlchemy's typing is complicated and we don't know what exact types we have to expect here, so we'll just + # pretend it's always a ColumnElement to make the type checker happy. + return cast(ColumnElement[Any], getattr(model_cls, self.get_cursor_column_name())) - def apply_pagination_to_query(self, query: Query, model_cls: Any) -> Query: + @override + def apply_pagination_to_query(self, query: Query[T], model_cls: Any) -> Query[T]: """ Applies the pagination parameters to an SQLAlchemy query and returns the new query. @@ -135,24 +146,28 @@ def apply_pagination_to_query(self, query: Query, model_cls: Any) -> Query: return query # The start parameter should always be set, but in case it is not, default to 0 - if self.start is None: - self.start = 0 + if self.start is None: # type: ignore[comparison-overlap] + self.start = 0 # type: ignore[unreachable] # Get the cursor column from the model class key_column = self.get_cursor_column(model_cls) # Cursor pagination requires the data to be ordered by the cursor column - return query.order_by(key_column) \ - .filter(key_column >= self.start) \ + return ( + query.order_by(key_column) + .filter(key_column >= self.start) .limit(self.limit) + ) + @override def get_start_parameter_name(self) -> str: """ Returns the name of the pagination start parameter ("start" for cursor pagination). """ return 'start' - def get_next_start_value(self, paginated_result: PaginatedResult) -> int | None: + @override + def get_next_start_value(self, paginated_result: PaginatedResult[Any]) -> int | None: """ Returns the next value for the pagination start parameter to retrieve the next page of data, or None if there is no next page. @@ -170,11 +185,16 @@ def get_next_start_value(self, paginated_result: PaginatedResult) -> int | None: # Get last result in list last_item = paginated_result[-1] - # Get cursor value (e.g. ID) of last result, allowing both objects and dictionaries, and increment by one + # Get cursor value (e.g. ID) of last result, allowing both objects and dictionaries cursor_key = self.get_cursor_column_name() if isinstance(last_item, dict): - return last_item.get(cursor_key) + 1 + last_item_value = last_item.get(cursor_key) elif hasattr(last_item, cursor_key): - return getattr(last_item, cursor_key) + 1 + last_item_value = getattr(last_item, cursor_key) else: - raise Exception(f'Last item of PaginatedResult has neither attribute nor dictionary key "{cursor_key}": {last_item}') + raise Exception( + f'Last item of PaginatedResult has neither attribute nor dictionary key "{cursor_key}": {last_item}' + ) + + # Return last cursor value incremented by one + return last_item_value + 1 if isinstance(last_item_value, int) else None diff --git a/src/validataclass_search_queries/pagination/offset_pagination_mixin.py b/src/validataclass_search_queries/pagination/offset_pagination_mixin.py index f448e91..258f24d 100644 --- a/src/validataclass_search_queries/pagination/offset_pagination_mixin.py +++ b/src/validataclass_search_queries/pagination/offset_pagination_mixin.py @@ -4,21 +4,24 @@ Use of this source code is governed by an MIT-style license that can be found in the LICENSE file. """ -from typing import Any +from typing import Any, TypeVar from sqlalchemy.orm import Query +from typing_extensions import override from validataclass.dataclasses import validataclass, Default from validataclass.validators import IntegerValidator +from validataclass_search_queries import pagination from .abstract_pagination_mixin import AbstractPaginationMixin from .paginated_result import PaginatedResult from .pagination_limit_validator import PaginationLimitValidator -from .. import pagination __all__ = [ 'OffsetPaginationMixin', ] +T = TypeVar('T') + @validataclass class OffsetPaginationMixin(AbstractPaginationMixin): @@ -87,14 +90,16 @@ class ExampleSearchQuery(OffsetPaginationMixin, BaseSearchQuery): # Limit: Number of entries per page limit: int | None = PaginationLimitValidator(max_value=100), Default(20) - def __init_subclass__(cls, **kwargs): + @override + def __init_subclass__(cls, **kwargs: Any): # Pagination mixins are not compatible with each other, only one can be used at the same time - if issubclass(cls, pagination.CursorPaginationMixin): + if issubclass(cls, pagination.CursorPaginationMixin): # type: ignore[unreachable, unused-ignore] raise TypeError(f'Invalid base classes in {cls}: Combining multiple pagination mixins is not allowed') super().__init_subclass__(**kwargs) - def apply_pagination_to_query(self, query: Query, model_cls: Any) -> Query: + @override + def apply_pagination_to_query(self, query: Query[T], model_cls: Any) -> Query[T]: """ Applies the pagination parameters to an SQLAlchemy query and returns the new query. @@ -106,13 +111,15 @@ def apply_pagination_to_query(self, query: Query, model_cls: Any) -> Query: return query.offset(self.offset).limit(self.limit) + @override def get_start_parameter_name(self) -> str: """ Returns the name of the pagination start parameter ("offset" for offset pagination). """ return 'offset' - def get_next_start_value(self, paginated_result: PaginatedResult) -> int | None: + @override + def get_next_start_value(self, paginated_result: PaginatedResult[Any]) -> int | None: """ Returns the next value for the pagination start parameter to retrieve the next page of data, or None if there is no next page. diff --git a/src/validataclass_search_queries/pagination/paginated_result.py b/src/validataclass_search_queries/pagination/paginated_result.py index 03e2923..0e7dcd3 100644 --- a/src/validataclass_search_queries/pagination/paginated_result.py +++ b/src/validataclass_search_queries/pagination/paginated_result.py @@ -5,12 +5,10 @@ """ from collections.abc import Callable, Iterable -from typing import TypeVar +from typing import Any, TypeVar __all__ = [ 'PaginatedResult', - 'T_Result', - 'T_MappedResult', ] T_Result = TypeVar('T_Result') @@ -60,17 +58,20 @@ def map_customers(customer: Customer) -> dict: # This results in a PaginatedResult[dict] containing dictionaries as defined in map_customers above: mapped_customers = paginated_result.map(map_customers) - # Assuming that the Customer class has a similar method `to_dict()` that takes no arguments, we can also do this: + # If the Customer class has a similar method `to_dict()` that takes no arguments, we can also do this: mapped_customers = paginated_result.map(Customers.to_dict) ``` """ - # We use self.__class__() instead of PaginatedResult() to properly support subclassing - return self.__class__( + # Previously, we used self.__class__() instead of PaginatedResult() here to properly support subclassing. + # However, we don't know whether this potential subclass is a Generic too, so it might not even support + # T_MappedResult. In other words, `self.__class__` can only be a subtype of `PaginatedResult[T_Result]`, + # not of `PaginatedResult[T_MappedResult]`. + return PaginatedResult( map(map_func, self), total_count=self.total_count, ) - def to_dict(self, *, recursive: bool = False) -> dict: + def to_dict(self, *, recursive: bool = False) -> dict[str, Any]: """ Returns a dictionary representing the PaginatedResult, consisting of the keys "items" (a list of the items) and "total_count" (the total count as an integer). diff --git a/src/validataclass_search_queries/pagination/pagination_limit_validator.py b/src/validataclass_search_queries/pagination/pagination_limit_validator.py index b14f10a..a090b17 100644 --- a/src/validataclass_search_queries/pagination/pagination_limit_validator.py +++ b/src/validataclass_search_queries/pagination/pagination_limit_validator.py @@ -6,8 +6,9 @@ from typing import Any +from typing_extensions import override from validataclass.exceptions import ValidationError -from validataclass.validators import IntegerValidator +from validataclass.validators import IntegerValidator, Validator __all__ = [ 'PaginationLimitValidator', @@ -15,7 +16,7 @@ ] -class PaginationLimitValidator(IntegerValidator): +class PaginationLimitValidator(Validator[int | None]): """ Validator for the pagination limit, based on an IntegerValidator. @@ -44,6 +45,9 @@ class PaginationLimitValidator(IntegerValidator): ``` """ + # Base validator for integer validation + integer_validator: IntegerValidator + # If true, pagination is optional for the user (set limit=0 to disable pagination) optional: bool @@ -60,17 +64,19 @@ def __init__( meaning that pagination is disabled (i.e. unlimited results). Parameters: - optional: Boolean, whether pagination is optional, i.e. the user can set limit=0 to disable pagination (default: False) - max_value: Integer or None, maximum value for pagination limit (default: IntegerValidator.DEFAULT_MAX_VALUE = 2147483647) + `optional`: bool, whether pagination can be disabled by setting limit to 0 (default: False) + `max_value`: int or None, maximum value for pagination limit (default: `IntegerValidator.DEFAULT_MAX_VALUE`) """ - super().__init__( - min_value=0, # if optional else 1, + # Initialize base integer validator + self.integer_validator = IntegerValidator( + min_value=0, max_value=max_value, allow_strings=True, ) self.optional = optional - def validate(self, input_data: Any, **kwargs) -> int | None: + @override + def validate(self, input_data: Any, **kwargs: Any) -> int | None: """ Validates the input as an integer. Returns the integer or None if the input is 0 or None. """ @@ -79,7 +85,7 @@ def validate(self, input_data: Any, **kwargs) -> int | None: input_data = 0 # Validate input as integer - validated_input = super().validate(input_data, **kwargs) + validated_input = self.integer_validator.validate(input_data, **kwargs) # If pagination is optional, treat 0 as "no limit" (i.e. no pagination) if validated_input == 0: diff --git a/src/validataclass_search_queries/pagination/response_helpers.py b/src/validataclass_search_queries/pagination/response_helpers.py index d30bcad..d44ccde 100644 --- a/src/validataclass_search_queries/pagination/response_helpers.py +++ b/src/validataclass_search_queries/pagination/response_helpers.py @@ -6,9 +6,9 @@ from typing import Any +from validataclass_search_queries.search_queries import BaseSearchQuery from .abstract_pagination_mixin import AbstractPaginationMixin from .paginated_result import PaginatedResult -from ..search_queries import BaseSearchQuery __all__ = [ 'paginated_api_response', @@ -21,8 +21,8 @@ def paginated_api_response( *, recursive_to_dict: bool = True, request_path: str | None = None, - original_params: dict | None = None, -) -> dict: + original_params: dict[str, Any] | None = None, +) -> dict[str, Any]: """ Constructs a REST API response (as a dictionary) for paginated results. @@ -53,9 +53,9 @@ def paginated_api_response( "total_count" is the total number of results before pagination (i.e. if there are 123 results and the page limit is 10, there will be 10 items, but "total_count" will be 123). - If there might be a next page (which is determined e.g. by the number of results and the given pagination parameters), - there will be another field that contains the start value for the next page. In case of offset pagination, this - field will be called "next_offset", in case of cursor pagination, it will be called "next_id". + If it is possible that there is a next page (which is determined e.g. by the number of results and the given + pagination parameters), there will be another field that contains the start value for the next page. For offset + pagination, this field will be called "next_offset", for cursor pagination, it will be called "next_id". Additionally, if the optional parameter `request_path` is set, another field called "next_path" will be added to the response, containing the URL path with query parameters that can be used to retrieve the next page. This string @@ -83,14 +83,15 @@ def paginated_api_response( if next_start_value is None: return response_data - # Write next start parameter to response. For legacy reasons, cursor pagination uses "next_id" instead of "next_start". + # Add next start parameter to response. For compatibility reasons, cursor pagination uses "next_id" instead of + # "next_start". # TODO: This might be changed in the future, but we need to keep compatibility somehow... start_param = search_query.get_start_parameter_name() response_data['next_id' if start_param == 'start' else f'next_{start_param}'] = next_start_value # Only set next_path if a request base path is given if request_path is not None: - # Construct parameters for next page from original request parameters (if given). Ensure limit parameter is set. + # Construct parameters for next page from original request (if given). Ensure limit parameter is set. next_path_params = dict(original_params) if original_params is not None else {} next_path_params.update({ start_param: next_start_value, diff --git a/src/validataclass_search_queries/py.typed b/src/validataclass_search_queries/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/validataclass_search_queries/repositories/__init__.py b/src/validataclass_search_queries/repositories/__init__.py index 4cd9529..6318f19 100644 --- a/src/validataclass_search_queries/repositories/__init__.py +++ b/src/validataclass_search_queries/repositories/__init__.py @@ -5,3 +5,7 @@ """ from .search_query_repository_mixin import SearchQueryRepositoryMixin + +__all__ = [ + 'SearchQueryRepositoryMixin', +] diff --git a/src/validataclass_search_queries/repositories/search_query_repository_mixin.py b/src/validataclass_search_queries/repositories/search_query_repository_mixin.py index f1816f8..261f2fb 100644 --- a/src/validataclass_search_queries/repositories/search_query_repository_mixin.py +++ b/src/validataclass_search_queries/repositories/search_query_repository_mixin.py @@ -5,21 +5,21 @@ """ from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Any from sqlalchemy.orm import Query -from ..filters import BoundSearchFilter -from ..pagination import AbstractPaginationMixin, PaginatedResult -from ..search_queries import BaseSearchQuery -from ..sorting import AbstractSortingMixin +from validataclass_search_queries.filters import BoundSearchFilter +from validataclass_search_queries.pagination import AbstractPaginationMixin, PaginatedResult +from validataclass_search_queries.search_queries import BaseSearchQuery +from validataclass_search_queries.sorting import AbstractSortingMixin __all__ = [ 'SearchQueryRepositoryMixin', - 'T_Model', ] T_Model = TypeVar('T_Model') +T_Query = TypeVar('T_Query') class SearchQueryRepositoryMixin(Generic[T_Model], ABC): @@ -85,8 +85,9 @@ def fetch_examples(self, *, search_query: BaseSearchQuery | None = None) -> Pagi The default implementation of `_apply_bound_search_filter()` first gets the column of the model class with the name specified in the SearchParam (`getattr(self.model_cls, bound_filter.column_name)` and applies the search filter to - this column (`query.filter(bound_filter.get_sqlalchemy_filter(col))`). For example, with a `SearchParamSince('created')` - and the `Example` model, the "since" filter (i.e. `>=`) would be applied to the column `Example.created`. + this column (`query.filter(bound_filter.get_sqlalchemy_filter(col))`). For example, with a search parameter + `SearchParamSince('created')` and the `Example` model, the "since" filter (i.e. `>=`) would be applied to the + column `Example.created`. If you need to, you can override this method, choose a different column (maybe even from a different model, or using some SQL functions) and apply the search filter to this column. @@ -114,7 +115,8 @@ def fetch_customers(self, *, search_query: BaseSearchQuery | None = None) -> Pag return self._search_and_paginate(query, search_query) # Override the default method for applying search filters - def _apply_bound_search_filter(self, query: Query, bound_filter: BoundSearchFilter) -> Query: + @override + def _apply_bound_search_filter(self, query: Query[T_Query], bound_filter: BoundSearchFilter) -> Query[T_Query]: # Only implement a special case for the "modified" column if bound_filter.column_name == 'modified': # Get column objects for both models @@ -145,9 +147,10 @@ def model_cls(self) -> type[T_Model]: """ raise NotImplementedError - def _search_and_paginate(self, query: Query, search_query: BaseSearchQuery | None) -> PaginatedResult[T_Model]: + def _search_and_paginate(self, query: Query[Any], search_query: BaseSearchQuery | None) -> PaginatedResult[T_Model]: """ - Filters a query based on search parameters (usually parsed from HTTP query parameters) and paginates the result. + Apply filters, sorting and pagination to a database query, based on search parameters (usually parsed from + HTTP query parameters), then execute the query and return a paginated list of results. Shortcut method for calling `_filter_by_search_query()`, `_order_by_search_query()` and `_paginate_result()`. """ @@ -155,9 +158,10 @@ def _search_and_paginate(self, query: Query, search_query: BaseSearchQuery | Non query = self._order_by_search_query(query, search_query) return self._paginate_result(query, search_query) - def _filter_by_search_query(self, query: Query, search_query: BaseSearchQuery | None) -> Query: + def _filter_by_search_query(self, query: Query[T_Query], search_query: BaseSearchQuery | None) -> Query[T_Query]: """ - Filters a query based on search parameters (usually parsed from HTTP query parameters), *excluding* pagination. + Apply filters to a database query, based on search parameters (usually parsed from HTTP query parameters). + This does not include sorting or pagination! If no search query is given (or no search parameter is set), the database query is returned unmodified. """ @@ -170,9 +174,10 @@ def _filter_by_search_query(self, query: Query, search_query: BaseSearchQuery | return query - def _apply_bound_search_filter(self, query: Query, bound_filter: BoundSearchFilter) -> Query: + def _apply_bound_search_filter(self, query: Query[T_Query], bound_filter: BoundSearchFilter) -> Query[T_Query]: """ - Filters a query based on a BoundSearchFilter. Called by _filter_by_search_query() for every set search filter. + Apply a single search filter from a `BoundSearchFilter` to a database query with `query.filter(...)`. + Called by `_filter_by_search_query()` for every set search filter. Override this method to implement custom handling for (all or specific) search filters. """ @@ -180,9 +185,9 @@ def _apply_bound_search_filter(self, query: Query, bound_filter: BoundSearchFilt col = getattr(self.model_cls, bound_filter.column_name) return query.filter(bound_filter.get_sqlalchemy_filter(col)) - def _order_by_search_query(self, query: Query, search_query: BaseSearchQuery | None) -> Query: + def _order_by_search_query(self, query: Query[T_Query], search_query: BaseSearchQuery | None) -> Query[T_Query]: """ - Applies sorting (order_by) to a query based on sorting parameters from a search query. + Apply sorting (`query.order_by(...)`) to a database query based on sorting parameters from a search query. If the search query does not implement sorting (i.e. it does not inherit from `AbstractSortingMixin`), the database query is returned unmodified. @@ -192,15 +197,16 @@ def _order_by_search_query(self, query: Query, search_query: BaseSearchQuery | N return query - def _paginate_result(self, query: Query, search_query: BaseSearchQuery | None) -> PaginatedResult[T_Model]: + def _paginate_result(self, query: Query[Any], search_query: BaseSearchQuery | None) -> PaginatedResult[T_Model]: """ - Applies pagination to a query based on search parameters, executes the query and returns a paginated result list. + Apply pagination to a database query based on search parameters, execute the query and return a paginated list + of results. - To define pagination parameters in your search query dataclass, use a pagination mixin like OffsetPaginationMixin - or StablePaginationMixin. + To define pagination parameters in your search query dataclass, use a pagination mixin class like + `OffsetPaginationMixin` or `CursorPaginationMixin`. If the search query does not implement pagination (i.e. it does not inherit from `AbstractPaginationMixin`), - a PaginatedResult with ALL results is returned (as if the pagination limit was set to infinity). + a `PaginatedResult` with ALL results is returned (as if there was no pagination limit). """ # Get total count of search results BEFORE pagination is applied total_count = query.count() diff --git a/src/validataclass_search_queries/search_queries/__init__.py b/src/validataclass_search_queries/search_queries/__init__.py index 7a173b5..ff69981 100644 --- a/src/validataclass_search_queries/search_queries/__init__.py +++ b/src/validataclass_search_queries/search_queries/__init__.py @@ -6,3 +6,8 @@ from .base_search_query import BaseSearchQuery from .search_query_dataclass import search_query_dataclass + +__all__ = [ + 'BaseSearchQuery', + 'search_query_dataclass', +] diff --git a/src/validataclass_search_queries/search_queries/base_search_query.py b/src/validataclass_search_queries/search_queries/base_search_query.py index 67cbe0e..f475aba 100644 --- a/src/validataclass_search_queries/search_queries/base_search_query.py +++ b/src/validataclass_search_queries/search_queries/base_search_query.py @@ -11,13 +11,14 @@ from validataclass.helpers import UnsetValue -from ..filters import BoundSearchFilter +from validataclass_search_queries.filters import BoundSearchFilter __all__ = [ 'BaseSearchQuery', ] +@dataclasses.dataclass class BaseSearchQuery: """ Base class for search query validataclasses, which can be used to validate search parameters (e.g. GET query diff --git a/src/validataclass_search_queries/search_queries/search_query_dataclass.py b/src/validataclass_search_queries/search_queries/search_query_dataclass.py index 00b5010..9e11976 100644 --- a/src/validataclass_search_queries/search_queries/search_query_dataclass.py +++ b/src/validataclass_search_queries/search_queries/search_query_dataclass.py @@ -6,19 +6,28 @@ import dataclasses from collections.abc import Callable +from inspect import get_annotations from typing import Any, TypeVar, overload from typing_extensions import dataclass_transform -from validataclass.dataclasses import validataclass, validataclass_field, Default +from validataclass.dataclasses import validataclass, validataclass_field, BaseDefault, Default from validataclass.exceptions import DataclassValidatorFieldException from validataclass.validators import Validator -from ..filters import SearchParam +from validataclass_search_queries.filters import SearchParam __all__ = [ 'search_query_dataclass', ] + +@dataclasses.dataclass +class _ValidatorField: + validator: Validator[Any] | None = None + default: BaseDefault[Any] | None = None + search_param: SearchParam | None = None + + _T = TypeVar('_T') @@ -31,7 +40,7 @@ def search_query_dataclass(cls: type[_T]) -> type[_T]: @overload -def search_query_dataclass(cls: None = None, /, **kwargs) -> Callable[[type[_T]], type[_T]]: +def search_query_dataclass(cls: None = None, /, **kwargs: Any) -> Callable[[type[_T]], type[_T]]: ... @@ -62,22 +71,25 @@ def search_query_dataclass( """ def decorator(_cls: type[_T]) -> type[_T]: + # Transform class to be a valid validataclass _prepare_search_query_dataclass(_cls) + + # Use @validataclass decorator to transform class into a validataclass return validataclass(_cls, **kwargs) # Allow decorator to be called with and without parenthesis return decorator if cls is None else decorator(cls) -def _prepare_search_query_dataclass(cls) -> None: +def _prepare_search_query_dataclass(cls: type) -> None: """ Internal helper function used by @search_query_dataclass to prepare validataclass fields in a soon-to-be dataclass. """ # In case of a subclassed dataclass, get the already existing fields existing_fields = _get_existing_validator_fields(cls) - # Get class annotations - cls_annotations = cls.__dict__.get('__annotations__', {}) + # Get annotations of this class (ignores base classes) + cls_annotations = get_annotations(cls) # Prepare dataclass fields by checking for validators and setting metadata accordingly for name, field_type in cls_annotations.items(): @@ -88,39 +100,51 @@ def _prepare_search_query_dataclass(cls) -> None: continue # Get current validator etc. if the field is already existing - field_args = existing_fields.get(name, {}) + existing_field = existing_fields.get(name, _ValidatorField()) - # Overwrite existing field arguments with validator etc. from tuple + # Parse field tuple try: - field_args.update(_parse_validator_tuple(value)) + parsed_field = _parse_validator_tuple(value) except Exception as e: raise DataclassValidatorFieldException(f'Dataclass field "{name}": {e}') + # Overwrite existing field arguments with validator etc. from tuple + field = _ValidatorField( + validator=parsed_field.validator or existing_field.validator, + default=parsed_field.default or existing_field.default, + search_param=parsed_field.search_param or existing_field.search_param, + ) + # Ignore all fields without a SearchParam (they will be handled by @validataclass as usual validataclass fields) - if 'search_param' not in field_args.keys(): + if field.search_param is None: continue # Ensure that a validator is set - if not isinstance(field_args.get('validator', None), Validator): - # TODO: Update exception messages to be consistent with validataclass 0.12.0 - raise DataclassValidatorFieldException(f'Dataclass field "{name}" must specify a Validator.') + if not isinstance(field.validator, Validator): + raise DataclassValidatorFieldException(f'Dataclass field "{name}" must specify a validator.') # For SearchParam fields, use Default(None) if no explicit default was set - if field_args.get('default', None) is None: - field_args['default'] = Default(None) + if field.default is None: + field.default = Default(None) - # Create validataclass field (undocumented parameter _name is needed for required fields in Python < 3.10) + # Create validataclass field setattr(cls, name, validataclass_field( - validator=field_args.get('validator'), - default=field_args.get('default'), - metadata={'search_param': field_args.get('search_param')}, - _name=name, + validator=field.validator, + default=field.default, + metadata={'search_param': field.search_param}, )) -def _get_existing_validator_fields(cls) -> dict[str, dict[str, Any]]: +def _get_existing_validator_fields(cls: type) -> dict[str, _ValidatorField]: """ - Internal helper function used by @search_query_dataclass to get all pre-existing validataclass fields from the base classes. + Returns a dictionary containing all fields (as `_ValidatorField` objects) of an existing validataclass that have a + validator set in their metadata, or an empty dictionary if the class is not a dataclass (yet). + + Existing dataclass fields are determined by looking at all direct parent classes that are dataclasses themselves. + If two unrelated base classes define a field with the same name, the most-left class takes precedence (for example, + in `class C(B, A)`, the definitions of B take precendence over A). + + (Internal helper function.) """ existing_fields = {} @@ -129,44 +153,46 @@ def _get_existing_validator_fields(cls) -> dict[str, dict[str, Any]]: continue for field in dataclasses.fields(base_cls): - existing_fields[field.name] = { - 'validator': field.metadata.get('validator', None), - 'default': field.metadata.get('validator_default', None), - 'search_param': field.metadata.get('search_param', None), - } + existing_fields[field.name] = _ValidatorField( + validator=field.metadata.get('validator', None), + default=field.metadata.get('validator_default', None), + search_param=field.metadata.get('search_param', None), + ) return existing_fields -def _parse_validator_tuple(args: Any) -> dict: +def _parse_validator_tuple(args: Any) -> _ValidatorField: """ - Internal helper function used by @search_query_dataclass to parse validataclass-style field tuples to dictionaries. + Parses field arguments (the value of a field in a dataclass that has not been parsed by `@dataclass` yet) to a + `_ValidatorField` object. + + (Internal helper function.) """ if args is None: - return {} + return _ValidatorField() # Ensure args is a tuple if not isinstance(args, tuple): args = (args,) # Find validator, default object and search param in tuple and return them as a dictionary - arg_dict = {} + field = _ValidatorField() - # TODO: Update exception messages to be consistent with validataclass 0.12.0 for arg in args: if isinstance(arg, Validator): - if 'validator' in arg_dict: - raise ValueError('Only one Validator can be specified.') - arg_dict['validator'] = arg - elif isinstance(arg, Default): - if 'default' in arg_dict: - raise ValueError('Only one Default can be specified.') - arg_dict['default'] = arg + if field.validator is not None: + raise ValueError('Only one validator can be specified.') + field.validator = arg + elif isinstance(arg, BaseDefault): + if field.default is not None: + raise ValueError('Only one default can be specified.') + field.default = arg elif isinstance(arg, SearchParam): - if 'search_param' in arg_dict: + if field.search_param is not None: raise ValueError('Only one SearchParam can be specified.') - arg_dict['search_param'] = arg + field.search_param = arg else: raise TypeError('Unexpected type of argument: ' + type(arg).__name__) - return arg_dict + return field diff --git a/src/validataclass_search_queries/sorting/__init__.py b/src/validataclass_search_queries/sorting/__init__.py index 2d06345..4579524 100644 --- a/src/validataclass_search_queries/sorting/__init__.py +++ b/src/validataclass_search_queries/sorting/__init__.py @@ -7,3 +7,10 @@ from .abstract_sorting_mixin import AbstractSortingMixin from .sorting_direction import SortingDirection, SortingDirectionValidator from .sorting_mixin import SortingMixin + +__all__ = [ + 'AbstractSortingMixin', + 'SortingDirection', + 'SortingDirectionValidator', + 'SortingMixin', +] diff --git a/src/validataclass_search_queries/sorting/abstract_sorting_mixin.py b/src/validataclass_search_queries/sorting/abstract_sorting_mixin.py index 491d098..5b35be7 100644 --- a/src/validataclass_search_queries/sorting/abstract_sorting_mixin.py +++ b/src/validataclass_search_queries/sorting/abstract_sorting_mixin.py @@ -5,7 +5,7 @@ """ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeVar from sqlalchemy.orm import Query from sqlalchemy.sql import ColumnElement @@ -16,6 +16,8 @@ 'AbstractSortingMixin', ] +T = TypeVar('T') + class AbstractSortingMixin(ABC): """ @@ -29,7 +31,7 @@ class AbstractSortingMixin(ABC): sorting_direction: SortingDirection @abstractmethod - def get_sorting_column(self, model_cls: Any) -> ColumnElement: + def get_sorting_column(self, model_cls: Any) -> ColumnElement[Any]: """ Returns the column that the query should be ordered by (excluding the sorting direction). @@ -38,7 +40,7 @@ def get_sorting_column(self, model_cls: Any) -> ColumnElement: raise NotImplementedError @abstractmethod - def apply_sorting_direction(self, column: ColumnElement) -> ColumnElement: + def apply_sorting_direction(self, column: ColumnElement[T]) -> ColumnElement[T]: """ Applies the sorting direction to an SQLAlchemy column element, i.e. `column.asc()` or `column.desc()`, and returns the new column element. @@ -46,7 +48,7 @@ def apply_sorting_direction(self, column: ColumnElement) -> ColumnElement: raise NotImplementedError @abstractmethod - def apply_sorting_to_query(self, query: Query, model_cls: Any) -> Query: + def apply_sorting_to_query(self, query: Query[T], model_cls: Any) -> Query[T]: """ Applies the sorting parameters to an SQLAlchemy query (`query.order_by()`) and returns the new query. diff --git a/src/validataclass_search_queries/sorting/sorting_mixin.py b/src/validataclass_search_queries/sorting/sorting_mixin.py index def73ba..2743e60 100644 --- a/src/validataclass_search_queries/sorting/sorting_mixin.py +++ b/src/validataclass_search_queries/sorting/sorting_mixin.py @@ -4,10 +4,11 @@ Use of this source code is governed by an MIT-style license that can be found in the LICENSE file. """ -from typing import Any +from typing import Any, TypeVar, cast from sqlalchemy.orm import Query from sqlalchemy.sql import ColumnElement +from typing_extensions import override from validataclass.dataclasses import validataclass, Default from validataclass.validators import AnyOfValidator @@ -18,6 +19,8 @@ 'SortingMixin', ] +T = TypeVar('T') + @validataclass class SortingMixin(AbstractSortingMixin): @@ -58,31 +61,36 @@ class ExampleSearchQuery(SortingMixin, BaseSearchQuery): # Sorting direction ("ASC" or "DESC", case-insensitive) sorting_direction: SortingDirection = SortingDirectionValidator(), Default(SortingDirection.ASC) - def get_sorting_column(self, model_cls: Any) -> ColumnElement: + @override + def get_sorting_column(self, model_cls: Any) -> ColumnElement[Any]: """ Returns the column that the query should be ordered by (excluding the sorting direction). The "model_cls" parameter should be the class of the database model that is queried. """ - return getattr(model_cls, self.sorted_by) + # SQLAlchemy's typing is complicated and we don't know what exact types we have to expect here, so we'll just + # pretend it's always a ColumnElement to make the type checker happy. + return cast(ColumnElement[Any], getattr(model_cls, self.sorted_by)) - def apply_sorting_direction(self, column: ColumnElement) -> ColumnElement: + @override + def apply_sorting_direction(self, column: ColumnElement[T]) -> ColumnElement[T]: """ Applies the sorting direction to an SQLAlchemy column element, i.e. `column.asc()` or `column.desc()`, and returns the new column element. """ return column.desc() if self.sorting_direction is SortingDirection.DESC else column.asc() - def apply_sorting_to_query(self, query: Query, model_cls: Any) -> Query: + @override + def apply_sorting_to_query(self, query: Query[T], model_cls: Any) -> Query[T]: """ Applies the sorting parameters to an SQLAlchemy query (`query.order_by()`) and returns the new query. The "model_cls" parameter should be the class of the database model that is queried. It is needed to get the sorting column from the model. """ - # If we want to disable sorting for some reason - if self.sorted_by is None: - return query + # If someone wants to disable sorting for some reason + if self.sorted_by is None: # type: ignore[comparison-overlap] + return query # type: ignore[unreachable] sorting_column = self.get_sorting_column(model_cls) return query.order_by(self.apply_sorting_direction(sorting_column)) diff --git a/src/validataclass_search_queries/validators/__init__.py b/src/validataclass_search_queries/validators/__init__.py index 8914157..7d7844d 100644 --- a/src/validataclass_search_queries/validators/__init__.py +++ b/src/validataclass_search_queries/validators/__init__.py @@ -8,3 +8,10 @@ from .multi_select_enum_validator import MultiSelectEnumValidator from .multi_select_integer_validator import MultiSelectIntegerValidator from .multi_select_validator import MultiSelectValidator + +__all__ = [ + 'MultiSelectAnyOfValidator', + 'MultiSelectEnumValidator', + 'MultiSelectIntegerValidator', + 'MultiSelectValidator', +] diff --git a/src/validataclass_search_queries/validators/multi_select_any_of_validator.py b/src/validataclass_search_queries/validators/multi_select_any_of_validator.py index 1ea57af..d08ec93 100644 --- a/src/validataclass_search_queries/validators/multi_select_any_of_validator.py +++ b/src/validataclass_search_queries/validators/multi_select_any_of_validator.py @@ -5,7 +5,7 @@ """ from collections.abc import Iterable -from typing import Any +from typing import TypeVar from validataclass.validators import AnyOfValidator @@ -15,8 +15,10 @@ 'MultiSelectAnyOfValidator', ] +T_AnyOfValues = TypeVar('T_AnyOfValues') -class MultiSelectAnyOfValidator(MultiSelectValidator): + +class MultiSelectAnyOfValidator(MultiSelectValidator[T_AnyOfValues]): """ Validator for multi-select search parameters that only allows a specified set of values. @@ -26,7 +28,7 @@ class MultiSelectAnyOfValidator(MultiSelectValidator): def __init__( self, # AnyOfValidator settings - allowed_values: Iterable[Any], + allowed_values: Iterable[T_AnyOfValues], # TODO: case_insensitive is deprecated in validataclass and must be removed in a future version. case_sensitive: bool | None = None, case_insensitive: bool | None = None, diff --git a/src/validataclass_search_queries/validators/multi_select_enum_validator.py b/src/validataclass_search_queries/validators/multi_select_enum_validator.py index 689de06..98ac81a 100644 --- a/src/validataclass_search_queries/validators/multi_select_enum_validator.py +++ b/src/validataclass_search_queries/validators/multi_select_enum_validator.py @@ -29,7 +29,7 @@ class MultiSelectEnumValidator(MultiSelectValidator[T_Enum]): def __init__( self, # EnumValidator settings - enum_cls: type[Enum], + enum_cls: type[T_Enum], *, allowed_values: Iterable[Any] | None = None, # TODO: case_insensitive is deprecated in validataclass and must be removed in a future version. diff --git a/src/validataclass_search_queries/validators/multi_select_validator.py b/src/validataclass_search_queries/validators/multi_select_validator.py index 7467e53..aba7ab9 100644 --- a/src/validataclass_search_queries/validators/multi_select_validator.py +++ b/src/validataclass_search_queries/validators/multi_select_validator.py @@ -6,6 +6,7 @@ from typing import Any, TypeVar +from typing_extensions import override from validataclass.validators import ListValidator, Validator __all__ = [ @@ -43,7 +44,7 @@ class MultiSelectValidator(ListValidator[T_ListItem]): def __init__( self, - item_validator: Validator, + item_validator: Validator[T_ListItem], *, delimiter: str = ',', max_length: int | None = None, @@ -64,7 +65,8 @@ def __init__( ) self.delimiter = delimiter - def validate(self, input_data: Any, **kwargs) -> list[T_ListItem]: + @override + def validate(self, input_data: Any, **kwargs: Any) -> list[T_ListItem]: """ Validate input data as string. Returns a validated list. """ diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index be44628..f98815c 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -6,3 +6,8 @@ from .assertions import assert_column_element from .mocks import UnitTestEnum + +__all__ = [ + 'assert_column_element', + 'UnitTestEnum', +] diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py index b7b52f8..59af22c 100644 --- a/tests/helpers/assertions.py +++ b/tests/helpers/assertions.py @@ -9,7 +9,7 @@ from sqlalchemy.sql import ColumnElement -def assert_column_element(element: Any, expected_string: str, *expected_params) -> None: +def assert_column_element(element: Any, expected_string: str, *expected_params: Any) -> None: """ Helper function to check the SQL and bound parameters of a generated ColumnElement. """ assert isinstance(element, ColumnElement) compiled_expr = element.compile() diff --git a/tests/unit/filters/search_params/conftest.py b/tests/unit/filters/search_params/conftest.py index fc12de3..96daf05 100644 --- a/tests/unit/filters/search_params/conftest.py +++ b/tests/unit/filters/search_params/conftest.py @@ -4,11 +4,13 @@ Use of this source code is governed by an MIT-style license that can be found in the LICENSE file. """ +from typing import Any + import pytest import sqlalchemy from sqlalchemy.sql import ColumnElement @pytest.fixture -def sqlalchemy_column() -> ColumnElement: +def sqlalchemy_column() -> ColumnElement[Any]: return sqlalchemy.column('unit_test_column') diff --git a/tests/unit/filters/search_params/search_param_boolean_test.py b/tests/unit/filters/search_params/search_param_boolean_test.py index 1f5f1a5..fd5b0a5 100644 --- a/tests/unit/filters/search_params/search_param_boolean_test.py +++ b/tests/unit/filters/search_params/search_param_boolean_test.py @@ -48,5 +48,13 @@ def test_search_param_is_not_none(sqlalchemy_column): def test_search_param_ternary(sqlalchemy_column): """ Test the SearchParamTernary search parameter. """ param = SearchParamTernary('yes', 'no') - assert_column_element(param.sqlalchemy_filter(sqlalchemy_column, True), 'unit_test_column = :unit_test_column_1', 'yes') - assert_column_element(param.sqlalchemy_filter(sqlalchemy_column, False), 'unit_test_column = :unit_test_column_1', 'no') + assert_column_element( + param.sqlalchemy_filter(sqlalchemy_column, True), + 'unit_test_column = :unit_test_column_1', + 'yes', + ) + assert_column_element( + param.sqlalchemy_filter(sqlalchemy_column, False), + 'unit_test_column = :unit_test_column_1', + 'no', + ) diff --git a/tests/unit/filters/search_params/search_param_substring_test.py b/tests/unit/filters/search_params/search_param_substring_test.py index e704f95..5b47e5e 100644 --- a/tests/unit/filters/search_params/search_param_substring_test.py +++ b/tests/unit/filters/search_params/search_param_substring_test.py @@ -12,6 +12,9 @@ # Test data for SearchParamContains, SearchParamStartsWidth, SearchParamEndsWidth (substring matching search filters) # (Parameters: input_value, expected_param) test_data_substring_matches = [ + # Empty string + ('', ''), + # Simple string ('banana', 'banana'), @@ -25,24 +28,30 @@ def test_search_param_contains(sqlalchemy_column, input_value, expected_param): """ Test the SearchParamContains search parameter. """ search_filter = SearchParamContains().sqlalchemy_filter(sqlalchemy_column, input_value) - assert_column_element(search_filter, "unit_test_column LIKE '%' || :unit_test_column_1 || '%' ESCAPE '/'", expected_param) + assert_column_element( + search_filter, + "unit_test_column LIKE '%' || :unit_test_column_1 || '%' ESCAPE '/'", + expected_param, + ) @pytest.mark.parametrize('input_value, expected_param', test_data_substring_matches) def test_search_param_starts_with(sqlalchemy_column, input_value, expected_param): """ Test the SearchParamStartsWith search parameter. """ search_filter = SearchParamStartsWith().sqlalchemy_filter(sqlalchemy_column, input_value) - assert_column_element(search_filter, "unit_test_column LIKE :unit_test_column_1 || '%' ESCAPE '/'", expected_param) + assert_column_element( + search_filter, + "unit_test_column LIKE :unit_test_column_1 || '%' ESCAPE '/'", + expected_param, + ) @pytest.mark.parametrize('input_value, expected_param', test_data_substring_matches) def test_search_param_ends_with(sqlalchemy_column, input_value, expected_param): """ Test the SearchParamEndsWith search parameter. """ search_filter = SearchParamEndsWith().sqlalchemy_filter(sqlalchemy_column, input_value) - assert_column_element(search_filter, "unit_test_column LIKE '%' || :unit_test_column_1 ESCAPE '/'", expected_param) - - -@pytest.mark.parametrize('search_param_cls', [SearchParamContains, SearchParamStartsWith, SearchParamEndsWith]) -def test_search_param_substring_matching_shortcircuit(sqlalchemy_column, search_param_cls): - """ Test that SearchParamContains, SearchParamStartsWith and SearchParamEndsWith short-circuit if the value is empty. """ - assert search_param_cls().sqlalchemy_filter(sqlalchemy_column, '') is sqlalchemy_column + assert_column_element( + search_filter, + "unit_test_column LIKE '%' || :unit_test_column_1 ESCAPE '/'", + expected_param, + ) diff --git a/tests/unit/pagination/paginated_result_test.py b/tests/unit/pagination/paginated_result_test.py index 7170371..8893b65 100644 --- a/tests/unit/pagination/paginated_result_test.py +++ b/tests/unit/pagination/paginated_result_test.py @@ -4,7 +4,10 @@ Use of this source code is governed by an MIT-style license that can be found in the LICENSE file. """ +from typing import Any + import pytest +from typing_extensions import override from validataclass_search_queries.pagination import PaginatedResult @@ -15,6 +18,7 @@ class MockItem: def __init__(self, name: str): self.name = name + @override def __eq__(self, other): return type(self) is type(other) and self.name == other.name @@ -22,7 +26,7 @@ def __eq__(self, other): class MockItemToDictable(MockItem): """ Variation of MockItem that has a to_dict() method. """ - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: return {'name': self.name} @@ -33,14 +37,14 @@ def __init__(self, name: str): self.name = name @staticmethod - def map_static(item) -> str: + def map_static(item: MockItem) -> str: return str(item) @classmethod - def map_class(cls, item) -> str: + def map_class(cls, item: MockItem) -> str: return f'[{cls.__name__}] {item}' - def map_instance(self, item) -> str: + def map_instance(self, item: MockItem) -> str: return f'[{self.name}] {item}' @@ -83,7 +87,9 @@ def test_paginated_result(input_list, total_count): ] ) def test_paginated_result_to_dict_basic_types(paginated_result, expected_dict): - """ Test PaginatedResult.to_dict() with basic types and objects without to_dict() method (recursive and non-recursive). """ + """ + Test PaginatedResult.to_dict() with basic types and objects without to_dict() method (recursive and non-recursive). + """ assert paginated_result.to_dict() == expected_dict assert paginated_result.to_dict(recursive=False) == expected_dict assert paginated_result.to_dict(recursive=True) == expected_dict diff --git a/tests/unit/pagination/response_helpers_test.py b/tests/unit/pagination/response_helpers_test.py index 5785443..5709d74 100644 --- a/tests/unit/pagination/response_helpers_test.py +++ b/tests/unit/pagination/response_helpers_test.py @@ -5,10 +5,16 @@ """ from dataclasses import dataclass +from typing import Any import pytest -from validataclass_search_queries.pagination import CursorPaginationMixin, OffsetPaginationMixin, PaginatedResult, paginated_api_response +from validataclass_search_queries.pagination import ( + CursorPaginationMixin, + OffsetPaginationMixin, + PaginatedResult, + paginated_api_response, +) from validataclass_search_queries.search_queries import BaseSearchQuery, search_query_dataclass @@ -17,7 +23,7 @@ class MockItem: """ Object used to test the response helper functions. """ id: int - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: return {'id': self.id} @@ -41,7 +47,7 @@ class ExampleQueryOffsetPagination(OffsetPaginationMixin, BaseSearchQuery): { 'items': [1, 3, 1, 2], 'total_count': 10, - } + }, ), ( # Empty result (implies last page) @@ -50,7 +56,7 @@ class ExampleQueryOffsetPagination(OffsetPaginationMixin, BaseSearchQuery): { 'items': [], 'total_count': 10, - } + }, ), ( # Full page with cursor pagination @@ -60,7 +66,7 @@ class ExampleQueryOffsetPagination(OffsetPaginationMixin, BaseSearchQuery): 'items': [{'id': 13}, {'id': 37}, {'id': 41}], 'total_count': 10, 'next_id': 42, - } + }, ), ( # Full page with offset pagination @@ -70,7 +76,7 @@ class ExampleQueryOffsetPagination(OffsetPaginationMixin, BaseSearchQuery): 'items': [{'id': 13}, {'id': 37}, {'id': 41}], 'total_count': 10, 'next_offset': 9, - } + }, ), ( # Non-full page (implies last page) @@ -79,9 +85,9 @@ class ExampleQueryOffsetPagination(OffsetPaginationMixin, BaseSearchQuery): { 'items': [{'id': 99}], 'total_count': 10, - } + }, ), - ] + ], ) def test_paginated_api_response(paginated_result, search_query, expected_response): """ Test paginated_api_response() with different search queries and results. """ @@ -104,7 +110,7 @@ def test_paginated_api_response(paginated_result, search_query, expected_respons 'total_count': 10, 'next_id': 42, 'next_path': '/unit/test?start=42&limit=3', - } + }, ), ( # Full page with offset pagination @@ -116,7 +122,7 @@ def test_paginated_api_response(paginated_result, search_query, expected_respons 'total_count': 10, 'next_offset': 9, 'next_path': '/unit/test?offset=9&limit=3', - } + }, ), ( # With original parameters (numbers are strings here, like in HTTP query parameters) @@ -128,9 +134,9 @@ def test_paginated_api_response(paginated_result, search_query, expected_respons 'total_count': 10, 'next_offset': 9, 'next_path': '/unit/test?foo=bar&limit=3&offset=9&something=else', - } + }, ), - ] + ], ) def test_paginated_api_response_with_next_path(paginated_result, search_query, original_params, expected_response): """ Test paginated_api_response() with request_path and original_params to generate the "next_path" field. """ diff --git a/tests/unit/search_queries/search_query_dataclass_test.py b/tests/unit/search_queries/search_query_dataclass_test.py index 128998a..e6cc16c 100644 --- a/tests/unit/search_queries/search_query_dataclass_test.py +++ b/tests/unit/search_queries/search_query_dataclass_test.py @@ -329,37 +329,37 @@ def test_search_query_dataclass_with_invalid_values(): with pytest.raises(DataclassValidatorFieldException) as exception_info: @search_query_dataclass class InvalidSearchQueryDataclass: - foo: int + foo: int # type: ignore[validataclass] - assert str(exception_info.value) == 'Dataclass field "foo" must specify a Validator.' + assert str(exception_info.value) == 'Dataclass field "foo" must specify a validator.' @pytest.mark.parametrize( 'field_tuple, expected_exception_msg', [ # Missing validator - (None, 'Dataclass field "foo" must specify a Validator.'), - ((Default(3)), 'Dataclass field "foo" must specify a Validator.'), - ((SearchParamEquals(), Default(0)), 'Dataclass field "foo" must specify a Validator.'), + (None, 'Dataclass field "foo" must specify a validator.'), + ((Default(3)), 'Dataclass field "foo" must specify a validator.'), + ((SearchParamEquals(), Default(0)), 'Dataclass field "foo" must specify a validator.'), # Too many validators ( (IntegerValidator(), StringValidator()), - 'Dataclass field "foo": Only one Validator can be specified.', + 'Dataclass field "foo": Only one validator can be specified.', ), ( (SearchParamEquals(), IntegerValidator(), IntegerValidator()), - 'Dataclass field "foo": Only one Validator can be specified.', + 'Dataclass field "foo": Only one validator can be specified.', ), # Too many defaults ( (Default(1), IntegerValidator(), Default(2)), - 'Dataclass field "foo": Only one Default can be specified.', + 'Dataclass field "foo": Only one default can be specified.', ), ( (Default(1), SearchParamEquals(), IntegerValidator(), Default(2)), - 'Dataclass field "foo": Only one Default can be specified.', + 'Dataclass field "foo": Only one default can be specified.', ), # Too many SearchParams @@ -378,6 +378,6 @@ def test_search_query_dataclass_with_invalid_field_tuples(field_tuple, expected_ with pytest.raises(DataclassValidatorFieldException) as exception_info: @search_query_dataclass class InvalidSearchQueryDataclass: - foo: int = field_tuple + foo: int = field_tuple # type: ignore[validataclass] assert str(exception_info.value) == expected_exception_msg diff --git a/tests/unit/sorting/sorting_mixin_test.py b/tests/unit/sorting/sorting_mixin_test.py index 1898063..a119fa5 100644 --- a/tests/unit/sorting/sorting_mixin_test.py +++ b/tests/unit/sorting/sorting_mixin_test.py @@ -7,6 +7,7 @@ import pytest import sqlalchemy from sqlalchemy.sql import ColumnElement +from sqlalchemy.sql.elements import ColumnClause from validataclass.dataclasses import validataclass, Default from validataclass.exceptions import DictFieldsValidationError from validataclass.validators import DataclassValidator, AnyOfValidator @@ -16,16 +17,16 @@ class MockModelCls: """ This class is used as a mock for a database model class. """ - id = sqlalchemy.column('id') - unit_test_field = sqlalchemy.column('unit_test_field') + id: ColumnClause[int] = sqlalchemy.column('id') + test_field: ColumnClause[str] = sqlalchemy.column('test_field') def test_sorting_mixin_get_sorting_column(): """ Test SortingMixin.get_sorting_column() on its own. """ # It's supposed to be used as a mixin, but it should function on its own, too. # (Also, we bypass the validators here by creating the object directly, so we can test with any sorted_by key.) - sorting_mixin = SortingMixin(sorted_by='unit_test_field', sorting_direction=SortingDirection.DESC) - assert sorting_mixin.get_sorting_column(MockModelCls) is MockModelCls.unit_test_field + sorting_mixin = SortingMixin(sorted_by='test_field', sorting_direction=SortingDirection.DESC) + assert sorting_mixin.get_sorting_column(MockModelCls) is MockModelCls.test_field @pytest.mark.parametrize( @@ -37,11 +38,11 @@ def test_sorting_mixin_get_sorting_column(): ) def test_sorting_mixin_apply_sorting_direction(sorting_direction, expected_dir_str): """ Test SortingMixin.apply_sorting_direction() on its own. """ - sorting_mixin = SortingMixin(sorted_by='unit_test_field', sorting_direction=sorting_direction) - order_column = sorting_mixin.apply_sorting_direction(sqlalchemy.column('custom_column')) + sorting_mixin = SortingMixin(sorted_by='test_field', sorting_direction=sorting_direction) + order_column = sorting_mixin.apply_sorting_direction(MockModelCls.test_field) assert isinstance(order_column, ColumnElement) - assert str(order_column) == f'custom_column {expected_dir_str}' + assert str(order_column) == f'test_field {expected_dir_str}' # TODO: Tests for SortingMixin.apply_sorting_to_query() @@ -108,23 +109,23 @@ def test_sorting_mixin_with_validation_invalid(): 'query_input, expected_column, expected_order_column_str', [ # Defaults - ({}, MockModelCls.unit_test_field, 'unit_test_field DESC'), + ({}, MockModelCls.test_field, 'test_field DESC'), ({'sorted_by': 'id'}, MockModelCls.id, 'id DESC'), - ({'sorting_direction': 'ASC'}, MockModelCls.unit_test_field, 'unit_test_field ASC'), + ({'sorting_direction': 'ASC'}, MockModelCls.test_field, 'test_field ASC'), # Explicit values ({'sorted_by': 'id', 'sorting_direction': 'ASC'}, MockModelCls.id, 'id ASC'), ({'sorted_by': 'id', 'sorting_direction': 'DESC'}, MockModelCls.id, 'id DESC'), - ({'sorted_by': 'unit_test_field', 'sorting_direction': 'ASC'}, MockModelCls.unit_test_field, 'unit_test_field ASC'), - ({'sorted_by': 'unit_test_field', 'sorting_direction': 'DESC'}, MockModelCls.unit_test_field, 'unit_test_field DESC'), - ] + ({'sorted_by': 'test_field', 'sorting_direction': 'ASC'}, MockModelCls.test_field, 'test_field ASC'), + ({'sorted_by': 'test_field', 'sorting_direction': 'DESC'}, MockModelCls.test_field, 'test_field DESC'), + ], ) def test_dataclass_with_sorting_mixin_with_validation(query_input, expected_column, expected_order_column_str): """ Test a dataclass that uses and customizes the SortingMixin with validation. """ @validataclass class UnitTestSortingQuery(SortingMixin): - sorted_by: str = AnyOfValidator(['id', 'unit_test_field']), Default('unit_test_field') + sorted_by: str = AnyOfValidator(['id', 'test_field']), Default('test_field') sorting_direction: SortingDirection = Default(SortingDirection.DESC) query_validator = DataclassValidator(UnitTestSortingQuery) diff --git a/tox.ini b/tox.ini index b8c56ea..e31e24c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,15 +1,12 @@ [tox] minversion = 4.5.1 -envlist = clean,py{310,311,312}-sqlalchemy{1.4,2.0},report,flake8 +envlist = clean,py{310,311,312,313,314}-sqlalchemy{1.4,2.0},report,flake8,mypy skip_missing_interpreters = true [flake8] -max-line-length = 140 +max-line-length = 120 exclude = _version.py ignore = -per-file-ignores = - # False positives for "unused imports" in __init__.py - __init__.py: F401 [testenv] extras = testing @@ -18,10 +15,17 @@ commands = python -m pytest --cov --cov-append {posargs} [testenv:flake8] commands = flake8 src/ tests/ +[testenv:mypy,py{310,311,312,313,314}-mypy] +commands = mypy {posargs} + +[testenv:mypy-debug] +# Use no-incremental to disable mypy caching when developing the mypy plugin +commands = mypy --show-traceback --no-incremental {posargs} + [testenv:clean] commands = coverage erase -[testenv:report,py{310,311,312}-report] +[testenv:report,py{310,311,312,313,314}-report] commands = coverage html coverage xml