From ac19cd557fef1721c6e32f8479e6fbc207cf80d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20G=C3=B3recki?= Date: Fri, 20 Feb 2026 23:31:25 +0100 Subject: [PATCH] Bump to 0.13.1, add comprehensive type annotations (issue #5) Add mypy config with disallow_untyped_defs, annotate all functions across message.py, utils.py, compositon.py, dependency_provider.py, transaction_context.py, application_module.py, application.py, and testing.py. Add @overload for handler() decorator. Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 8 ++++ lato/__init__.py | 3 +- lato/application.py | 36 ++++++++------ lato/application_module.py | 37 ++++++++++----- lato/compositon.py | 6 +-- lato/dependency_provider.py | 34 ++++++++------ lato/message.py | 2 +- lato/testing.py | 19 +++++--- lato/transaction_context.py | 94 +++++++++++++++++++++---------------- lato/utils.py | 13 ++--- pyproject.toml | 14 +++++- 11 files changed, 164 insertions(+), 102 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0b66d5..addf2e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Change Log +## [0.13.1] - 2026-02-20 + +### Added + +- Comprehensive type annotations across all modules (issue #5) +- `[tool.mypy]` configuration in `pyproject.toml` with `disallow_untyped_defs`, `disallow_incomplete_defs`, `check_untyped_defs`, `no_implicit_optional`, `warn_unused_ignores`, `warn_redundant_casts` +- `@overload` signatures for `ApplicationModule.handler()` decorator to preserve decorated function types + ## [0.13.0] - 2025-02-20 ### Breaking Changes diff --git a/lato/__init__.py b/lato/__init__.py index f809349..eef9b70 100644 --- a/lato/__init__.py +++ b/lato/__init__.py @@ -1,5 +1,4 @@ import logging -import typing from logging import NullHandler from .application import Application @@ -14,7 +13,7 @@ from .message import Command, Event, Query from .transaction_context import TransactionContext -__version__ = "0.13.0" +__version__ = "0.13.1" __all__ = [ "Application", "ApplicationModule", diff --git a/lato/application.py b/lato/application.py index 606c721..175a30a 100644 --- a/lato/application.py +++ b/lato/application.py @@ -1,6 +1,6 @@ import logging from collections.abc import Awaitable, Callable -from typing import Any, Optional, Union +from typing import Any, Optional, TypeVar, Union from lato.application_module import ApplicationModule from lato.dependency_provider import BasicDependencyProvider, DependencyProvider @@ -18,6 +18,8 @@ log = logging.getLogger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) + class Application(ApplicationModule): """Core Application class. @@ -32,10 +34,10 @@ class Application(ApplicationModule): def __init__( self, - name=__name__, + name: str = __name__, dependency_provider: Optional[DependencyProvider] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: """Initialize the application instance. :param name: Name of the application @@ -48,7 +50,9 @@ def __init__( self.dependency_provider = ( dependency_provider or self.dependency_provider_factory(**kwargs) ) - self._transaction_context_factory: Optional[Callable] = None + self._transaction_context_factory: Optional[ + Callable[..., TransactionContext] + ] = None self._on_enter_transaction_context: Optional[ OnEnterTransactionContextCallback ] = None @@ -56,7 +60,7 @@ def __init__( OnExitTransactionContextCallback ] = None self._transaction_middlewares: list[MiddlewareFunction] = [] - self._composers: dict[Union[Message, str], ComposerFunction] = {} + self._composers: dict[Union[type[Message], str], ComposerFunction] = {} def get_dependency(self, identifier: DependencyIdentifier) -> Any: """Gets a dependency from the dependency provider. Dependencies can be resolved either by name or by type. @@ -70,7 +74,9 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any: def __getitem__(self, identifier: DependencyIdentifier) -> Any: return self.get_dependency(identifier) - def call(self, func: Union[Callable[..., Any], str], *args, **kwargs) -> Any: + def call( + self, func: Union[Callable[..., Any], str], *args: Any, **kwargs: Any + ) -> Any: """Invokes a function with `args` and `kwargs` within the :class:`TransactionContext`. If `func` is a string, then it is an alias, and the corresponding handler for the alias is retrieved. Any missing arguments are provided by the dependency provider of a transaction context, @@ -96,7 +102,7 @@ def call(self, func: Union[Callable[..., Any], str], *args, **kwargs) -> Any: return result async def call_async( - self, func: Union[Callable[..., Awaitable[Any]], str], *args, **kwargs + self, func: Union[Callable[..., Awaitable[Any]], str], *args: Any, **kwargs: Any ) -> Any: """Invokes an async function with `args` and `kwargs` within the :class:`TransactionContext`. If `func` is a string, then it is an alias, and the corresponding handler for the alias is retrieved. @@ -178,7 +184,7 @@ async def publish_async(self, event: Event) -> dict[MessageHandler, Any]: result = await ctx.publish_async(event) return result - def on_enter_transaction_context(self, func): + def on_enter_transaction_context(self, func: F) -> F: """ Decorator for registering a function to be called when entering a transaction context @@ -201,7 +207,7 @@ def on_enter_transaction_context(self, func): self._on_enter_transaction_context = func return func - def on_exit_transaction_context(self, func): + def on_exit_transaction_context(self, func: F) -> F: """ Decorator for registering a function to be called when exiting a transaction context @@ -223,7 +229,7 @@ def on_exit_transaction_context(self, func): self._on_exit_transaction_context = func return func - def on_create_transaction_context(self, func): + def on_create_transaction_context(self, func: F) -> F: """ Decorator for overriding default transaction context creation @@ -248,7 +254,7 @@ def on_create_transaction_context(self, func): self._transaction_context_factory = func return func - def transaction_middleware(self, middleware_func): + def transaction_middleware(self, middleware_func: F) -> F: """ Decorator for registering a middleware function to be called when executing a function in a transaction context :param middleware_func: @@ -269,7 +275,7 @@ def transaction_middleware(self, middleware_func): self._transaction_middlewares.append(middleware_func) return middleware_func - def compose(self, alias): + def compose(self, alias: Any) -> Callable[[F], F]: """ Decorator for composing results of handlers identified by an alias. @@ -287,13 +293,13 @@ def compose(self, alias): ... ... """ - def decorator(func): + def decorator(func: F) -> F: self._composers[alias] = func return func return decorator - def transaction_context(self, **dependencies) -> TransactionContext: + def transaction_context(self, **dependencies: Any) -> TransactionContext: """Creates a transaction context for the app. The lifecycle of a transaction context is controlled by :func:`transaction_middleware`, diff --git a/lato/application_module.py b/lato/application_module.py index c569951..2b36305 100644 --- a/lato/application_module.py +++ b/lato/application_module.py @@ -1,6 +1,7 @@ import logging from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Iterator +from typing import Any, TypeVar, overload from lato.exceptions import DuplicateHandlerError from lato.message import Command, Message, Query @@ -10,24 +11,26 @@ log = logging.getLogger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) + class ApplicationModule: - def __init__(self, name: str): + def __init__(self, name: str) -> None: """Initialize the application module instance. :param name: Name of the module """ self.name: str = name - self._handlers: defaultdict[HandlerAlias, OrderedSet[Callable]] = defaultdict( - OrderedSet - ) - self._submodules: OrderedSet[ApplicationModule] = OrderedSet() + self._handlers: defaultdict[ + HandlerAlias, OrderedSet[Callable[..., Any]] + ] = defaultdict(OrderedSet) + self._submodules: OrderedSet["ApplicationModule"] = OrderedSet() @property - def identifier(self): + def identifier(self) -> str: return string_to_kwarg_name(self.name) - def include_submodule(self, a_module: "ApplicationModule"): + def include_submodule(self, a_module: "ApplicationModule") -> None: """Adds a child submodule to this module. :param a_module: child module to add @@ -37,7 +40,15 @@ def include_submodule(self, a_module: "ApplicationModule"): ), f"Can only include {ApplicationModule} instances, got {a_module}" self._submodules.add(a_module) - def handler(self, alias: HandlerAlias) -> Callable: + @overload + def handler(self, alias: F) -> F: + ... + + @overload + def handler(self, alias: HandlerAlias) -> Callable[[F], F]: + ... + + def handler(self, alias: Any) -> Any: """ Decorator for registering a handler. Handler can be aliased by a name or by a message type. @@ -89,7 +100,7 @@ def handler(self, alias: HandlerAlias) -> Callable: # decorator was called with argument # @my_module.handle("my_function") # @my_module.handle(MyCommand) - def decorator(func): + def decorator(func: F) -> F: """ Decorator for registering tasks by name """ @@ -107,7 +118,7 @@ def decorator(func): return decorator - def iterate_handlers_for(self, alias: str): + def iterate_handlers_for(self, alias: HandlerAlias) -> Iterator[MessageHandler]: if alias in self._handlers: for handler in self._handlers[alias]: yield MessageHandler(source=self.identifier, message=alias, fn=handler) @@ -117,8 +128,8 @@ def iterate_handlers_for(self, alias: str): except KeyError: pass - def get_handlers_for(self, alias: str): + def get_handlers_for(self, alias: HandlerAlias) -> list[MessageHandler]: return list(self.iterate_handlers_for(alias)) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.name} {object.__repr__(self)}>" diff --git a/lato/compositon.py b/lato/compositon.py index c3900bf..d083251 100644 --- a/lato/compositon.py +++ b/lato/compositon.py @@ -1,14 +1,14 @@ from collections.abc import Callable from functools import partial, reduce from operator import add, or_ -from typing import Optional +from typing import Any, Optional -from mergedeep import Strategy, merge # type: ignore +from mergedeep import Strategy, merge additive_merge = partial(merge, strategy=Strategy.TYPESAFE_ADDITIVE) -def compose(compose_operator: Optional[Callable] = None, **kwargs): +def compose(compose_operator: Optional[Callable] = None, **kwargs: Any) -> Any: values = tuple(value for module_name, value in kwargs.items() if value is not None) if len(values) == 0: diff --git a/lato/dependency_provider.py b/lato/dependency_provider.py index acf3590..988040e 100644 --- a/lato/dependency_provider.py +++ b/lato/dependency_provider.py @@ -1,15 +1,15 @@ import inspect from abc import ABC, abstractmethod +from collections import OrderedDict from collections.abc import Callable from typing import Any from lato.exceptions import UnknownDependencyError from lato.types import DependencyIdentifier -from lato.utils import OrderedDict class TypedDependency: - def __init__(self, value, a_type): + def __init__(self, value: Any, a_type: type) -> None: self.value = value self.a_type = a_type @@ -18,7 +18,7 @@ def as_type(obj: Any, cls: type) -> TypedDependency: return TypedDependency(obj, cls) -def get_function_parameters(func) -> OrderedDict: +def get_function_parameters(func: Callable[..., Any]) -> OrderedDict[str, Any]: """ Retrieve the function's parameters and their annotations. @@ -27,7 +27,7 @@ def get_function_parameters(func) -> OrderedDict: """ handler_signature = inspect.signature(func) kwargs_iterator = iter(handler_signature.parameters.items()) - parameters = OrderedDict() + parameters: OrderedDict[str, Any] = OrderedDict() for name, param in kwargs_iterator: parameters[name] = param.annotation return parameters @@ -55,7 +55,9 @@ def has_dependency(self, identifier: DependencyIdentifier) -> bool: raise NotImplementedError() @abstractmethod - def register_dependency(self, identifier: DependencyIdentifier, dependency: Any): + def register_dependency( + self, identifier: DependencyIdentifier, dependency: Any + ) -> None: """ Register a dependency with a given identifier (name or type). @@ -74,7 +76,7 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any: raise NotImplementedError() @abstractmethod - def copy(self, *args, **kwargs) -> "DependencyProvider": + def copy(self, *args: Any, **kwargs: Any) -> "DependencyProvider": """Creates a copy of self with updated dependencies. :param args: dependencies to update, identified by type. @@ -82,7 +84,7 @@ def copy(self, *args, **kwargs) -> "DependencyProvider": :return: A copy of the dependency provider. """ - def update(self, *args, **kwargs): + def update(self, *args: Any, **kwargs: Any) -> None: """ Updates the dependency provider with new dependencies. @@ -101,7 +103,7 @@ def update(self, *args, **kwargs): if self.allow_types: self.register_dependency(t, v) - def _get_type_and_value(self, value): + def _get_type_and_value(self, value: Any) -> tuple[type, Any]: if isinstance(value, TypedDependency): return value.a_type, value.value return type(value), value @@ -128,7 +130,7 @@ def resolve_func_params( func_kwargs = {} func_parameters = get_function_parameters(func) - resolved_kwargs = OrderedDict() + resolved_kwargs: OrderedDict[str, Any] = OrderedDict() arg_idx = 0 for param_name, param_type in func_parameters.items(): if arg_idx < len(func_args): @@ -151,10 +153,10 @@ def resolve_func_params( return resolved_kwargs - def __getitem__(self, key): + def __getitem__(self, key: DependencyIdentifier) -> Any: return self.get_dependency(key) - def __setitem__(self, key, value): + def __setitem__(self, key: DependencyIdentifier, value: Any) -> None: self.register_dependency(key, value) @@ -164,16 +166,18 @@ class BasicDependencyProvider(DependencyProvider): dependency injection based on type or parameter name. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize the DependencyProvider. :param args: Class instances to be registered by types :param kwargs: Dependencies to be registered by types and with explicit names """ - self._dependencies = {} + self._dependencies: dict[DependencyIdentifier, Any] = {} self.update(*args, **kwargs) - def register_dependency(self, identifier: DependencyIdentifier, dependency: Any): + def register_dependency( + self, identifier: DependencyIdentifier, dependency: Any + ) -> None: """ Register a dependency with a given identifier (name or type). @@ -206,7 +210,7 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any: except KeyError as e: raise UnknownDependencyError(identifier) - def copy(self, *args, **kwargs) -> DependencyProvider: + def copy(self, *args: Any, **kwargs: Any) -> "BasicDependencyProvider": """ Create a copy of self with updated dependencies. :param args: typed overrides diff --git a/lato/message.py b/lato/message.py index db7a815..6e11002 100644 --- a/lato/message.py +++ b/lato/message.py @@ -6,7 +6,7 @@ class Message(BaseModel): id: UUID = Field(default_factory=uuid4, alias="id") - def get_alias(self): + def get_alias(self) -> type["Message"]: return self.__class__ diff --git a/lato/testing.py b/lato/testing.py index d091df1..488ad85 100644 --- a/lato/testing.py +++ b/lato/testing.py @@ -1,11 +1,14 @@ import contextlib from collections.abc import Iterator +from typing import Any from lato import Application @contextlib.contextmanager -def override_app(application: Application, *args, **kwargs) -> Iterator[Application]: +def override_app( + application: Application, *args: Any, **kwargs: Any +) -> Iterator[Application]: original_dependency_provider = application.dependency_provider overridden_dependency_provider = original_dependency_provider.copy(*args, **kwargs) @@ -16,22 +19,26 @@ def override_app(application: Application, *args, **kwargs) -> Iterator[Applicat @contextlib.contextmanager -def override_ctx(application: Application, *args, **kwargs) -> Iterator[Application]: +def override_ctx( + application: Application, *args: Any, **kwargs: Any +) -> Iterator[Application]: original_transaction_context = application.transaction_context - def overriden_transaction_context(**dependencies): + def overriden_transaction_context(**dependencies: Any) -> Any: ctx = original_transaction_context(**dependencies) ctx.dependency_provider = ctx.dependency_provider.copy(*args, **kwargs) return ctx - application.transaction_context = overriden_transaction_context # type: ignore + application.transaction_context = overriden_transaction_context # type: ignore[method-assign] yield application - application.transaction_context = original_transaction_context # type: ignore + application.transaction_context = original_transaction_context # type: ignore[method-assign] @contextlib.contextmanager -def override(application: Application, *args, **kwargs) -> Iterator[Application]: +def override( + application: Application, *args: Any, **kwargs: Any +) -> Iterator[Application]: with override_app(application, **kwargs) as overridden1: with override_ctx(overridden1, **kwargs) as overridden2: yield overridden2 diff --git a/lato/transaction_context.py b/lato/transaction_context.py index 417ade9..d7e6cd9 100644 --- a/lato/transaction_context.py +++ b/lato/transaction_context.py @@ -4,6 +4,7 @@ from collections.abc import Awaitable, Callable, Iterator from dataclasses import dataclass from functools import partial +from types import TracebackType from typing import Any, Optional, Union from lato.compositon import compose @@ -14,7 +15,7 @@ ) from lato.exceptions import HandlerNotFoundError from lato.message import Message -from lato.types import HandlerAlias +from lato.types import DependencyIdentifier, HandlerAlias from lato.utils import maybe_await log = logging.getLogger(__name__) @@ -24,18 +25,16 @@ class MessageHandler: source: str message: HandlerAlias - fn: Callable + fn: Callable[..., Any] - def __hash__(self): + def __hash__(self) -> int: return hash((self.source, self.fn)) -OnEnterTransactionContextCallback = Callable[["TransactionContext"], Awaitable[None]] -OnExitTransactionContextCallback = Callable[ - ["TransactionContext", Optional[Exception]], Awaitable[None] -] -MiddlewareFunction = Callable[["TransactionContext", Callable], Awaitable[Any]] -ComposerFunction = Callable[..., Callable] +OnEnterTransactionContextCallback = Callable[..., Any] +OnExitTransactionContextCallback = Callable[..., Any] +MiddlewareFunction = Callable[["TransactionContext", Callable[..., Any]], Any] +ComposerFunction = Callable[..., Any] HandlersIterator = Callable[[HandlerAlias], Iterator[MessageHandler]] @@ -59,8 +58,11 @@ class TransactionContext: dependency_provider_factory = BasicDependencyProvider def __init__( - self, dependency_provider: Optional[DependencyProvider] = None, *args, **kwargs - ): + self, + dependency_provider: Optional[DependencyProvider] = None, + *args: Any, + **kwargs: Any, + ) -> None: """Initialize the transaction context instance. :param dependency_provider: dependency provider :class:`DependencyProvider` instance. @@ -93,7 +95,7 @@ def configure( middlewares: Optional[list[MiddlewareFunction]] = None, composers: Optional[dict[HandlerAlias, ComposerFunction]] = None, handlers_iterator: Optional[HandlersIterator] = None, - ): + ) -> None: """Customize the behavior of the transaction context with callbacks, middlewares, and composers. :param on_enter_transaction_context: Optional; Function to be called when entering a transaction context. @@ -113,7 +115,7 @@ def configure( if handlers_iterator: self._handlers_iterator = handlers_iterator - def begin(self): + def begin(self) -> None: """Starts a transaction by calling `on_enter_transaction_context` callback. The callback could be used to set up the transaction-level dependencies (i.e. current time, transaction id), @@ -127,7 +129,7 @@ def begin(self): ) self._on_enter_transaction_context(self) - async def begin_async(self): + async def begin_async(self) -> None: """Asynchronously starts a transaction by calling async `on_enter_transaction_context` callback. The callback could be used to set up the transaction-level dependencies (i.e. current time, transaction id), @@ -139,7 +141,7 @@ async def begin_async(self): if asyncio.iscoroutine(result): await result - def end(self, exception: Optional[Exception] = None): + def end(self, exception: Optional[Exception] = None) -> None: """Ends the transaction context by calling `on_exit_transaction_context` callback, optionally passing an exception. @@ -159,7 +161,7 @@ def end(self, exception: Optional[Exception] = None): else: log.debug("Ended transaction") - async def end_async(self, exception: Optional[Exception] = None): + async def end_async(self, exception: Optional[Exception] = None) -> None: """Ends the transaction context by calling `on_exit_transaction_context` callback, optionally passing an exception. @@ -191,28 +193,40 @@ def is_async_context_manager(self) -> bool: + [asyncio.iscoroutinefunction(self._on_exit_transaction_context)] ) - def iterate_handlers_for(self, alias: str): + def iterate_handlers_for(self, alias: HandlerAlias) -> Iterator[MessageHandler]: yield from self._handlers_iterator(alias) - def __enter__(self): + def __enter__(self) -> "TransactionContext": self.begin() return self - def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): - self.end(exc_val) + def __exit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + self.end(exc_val) # type: ignore[arg-type] - async def __aenter__(self): + async def __aenter__(self) -> "TransactionContext": result = self.begin_async() if asyncio.iscoroutine(result): await result return self - async def __aexit__(self, exc_type=None, exc_val=None, exc_tb=None): - result = self.end_async(exc_val) + async def __aexit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + result = self.end_async(exc_val) # type: ignore[arg-type] if asyncio.iscoroutine(result): await result - def call(self, func: Callable, *func_args: Any, **func_kwargs: Any) -> Any: + def call( + self, func: Callable[..., Any], *func_args: Any, **func_kwargs: Any + ) -> Any: """Call a function with the arguments and keyword arguments. Missing arguments will be resolved with the dependency provider. @@ -294,7 +308,7 @@ async def call_async( else: return call_next() - def execute(self, message: Message) -> tuple[Any, ...]: + def execute(self, message: Message) -> Any: """Executes all handlers bound to the message. Returns a tuple of handlers' return values. :param message: The message to be executed. @@ -309,7 +323,7 @@ def execute(self, message: Message) -> tuple[Any, ...]: composed_result = self._compose_results(message, results) return composed_result - async def execute_async(self, message: Message) -> tuple[Any, ...]: + async def execute_async(self, message: Message) -> Any: """Executes all async handlers bound to the message. Returns a tuple of handlers' return values. :param message: The message to be executed. @@ -325,13 +339,13 @@ async def execute_async(self, message: Message) -> tuple[Any, ...]: return composed_result def emit( - self, message: Union[str, Message], *args, **kwargs + self, message: Union[str, Message], *args: Any, **kwargs: Any ) -> dict[MessageHandler, Any]: # TODO: mark as obsolete return self.publish(message, *args, **kwargs) def publish( - self, message: Union[str, Message], *args, **kwargs + self, message: Union[str, Message], *args: Any, **kwargs: Any ) -> dict[MessageHandler, Any]: """ Publish a message by calling all handlers for that message. @@ -346,8 +360,8 @@ def publish( if isinstance(message, Message): args = (message, *args) - all_results = OrderedDict() - for handler in self._handlers_iterator(message_type): # type: ignore + all_results: OrderedDict[MessageHandler, Any] = OrderedDict() + for handler in self._handlers_iterator(message_type): # type: ignore[arg-type] self.set_dependency("message", message) # FIXME: push and pop current action instead of setting it self.current_handler = handler @@ -356,8 +370,8 @@ def publish( return all_results async def publish_async( - self, message: Union[str, Message], *args, **kwargs - ) -> dict[MessageHandler, Awaitable[Any]]: + self, message: Union[str, Message], *args: Any, **kwargs: Any + ) -> dict[MessageHandler, Any]: """ Asynchronously publish a message by calling all handlers for that message. @@ -371,9 +385,9 @@ async def publish_async( if isinstance(message, Message): args = (message, *args) - all_results = OrderedDict() + all_results: OrderedDict[MessageHandler, Any] = OrderedDict() # TODO: use asyncio.gather() - for handler in self._handlers_iterator(message_type): # type: ignore + for handler in self._handlers_iterator(message_type): # type: ignore[arg-type] self.set_dependency("message", message) # FIXME: push and pop current action instead of setting it self.current_handler = ( @@ -383,20 +397,20 @@ async def publish_async( all_results[handler] = result return all_results - def get_dependency(self, identifier: Any) -> Any: + def get_dependency(self, identifier: DependencyIdentifier) -> Any: """Gets a dependency from the dependency provider""" return self.dependency_provider.get_dependency(identifier) - def set_dependency(self, identifier: Any, dependency: Any) -> None: + def set_dependency(self, identifier: DependencyIdentifier, dependency: Any) -> None: """Sets a dependency in the dependency provider""" self.dependency_provider.register_dependency(identifier, dependency) - def set_dependencies(self, **kwargs): + def set_dependencies(self, **kwargs: Any) -> None: # TODO: add *args """Sets multiple dependencies at once""" self.dependency_provider.update(**kwargs) - def __getitem__(self, item) -> Any: + def __getitem__(self, item: DependencyIdentifier) -> Any: return self.get_dependency(item) def _compose_results( @@ -410,6 +424,6 @@ def _compose_results( return composer(**kwargs) @property - def current_action(self) -> tuple[Message, Callable]: + def current_action(self) -> tuple[Message, Callable[..., Any]]: """Returns current message and handler being executed""" - return self.get_dependency("message"), self.current_handler # type: ignore + return self.get_dependency("message"), self.current_handler # type: ignore[return-value] diff --git a/lato/utils.py b/lato/utils.py index d9b2bf0..66bdb81 100644 --- a/lato/utils.py +++ b/lato/utils.py @@ -1,27 +1,28 @@ import asyncio import re from collections import OrderedDict -from typing import TypeVar +from collections.abc import Callable, Iterable +from typing import Any, Optional, TypeVar T = TypeVar("T") class OrderedSet(OrderedDict[T, None]): - def __init__(self, iterable=None): + def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: super().__init__() if iterable: for item in iterable: self.add(item) - def add(self, item: T): + def add(self, item: T) -> None: self[item] = None - def update(self, iterable): + def update(self, iterable: Iterable[T]) -> None: # type: ignore[override] for item in iterable: self.add(item) -def string_to_kwarg_name(string): +def string_to_kwarg_name(string: str) -> str: # Remove invalid characters and replace them with underscores valid_string = re.sub(r"[^a-zA-Z0-9_]", "_", string) @@ -32,7 +33,7 @@ def string_to_kwarg_name(string): return valid_string -async def maybe_await(func, *args, **kwargs): +async def maybe_await(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: if asyncio.iscoroutinefunction(func): return await func(*args, **kwargs) else: diff --git a/pyproject.toml b/pyproject.toml index 45ac13b..5b2e052 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lato" -version = "0.13.0" +version = "0.13.1" description = "Lato is a Python microframework designed for building modular monoliths and loosely coupled applications." authors = ["Przemysław Górecki "] readme = "README.md" @@ -38,6 +38,18 @@ sphinx-rtd-theme = "^1.3.0" dependency-injector = "^4.41.0" lagom = "^2.5.0" +[tool.mypy] +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +no_implicit_optional = true +warn_unused_ignores = true +warn_redundant_casts = true + +[[tool.mypy.overrides]] +module = "mergedeep.*" +ignore_missing_imports = true + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api"