Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Change Log

## [0.13.1] - 2026-02-20

### Added

- Comprehensive type annotations across all modules (issue #5)
- `[tool.mypy]` configuration in `pyproject.toml` with `disallow_untyped_defs`, `disallow_incomplete_defs`, `check_untyped_defs`, `no_implicit_optional`, `warn_unused_ignores`, `warn_redundant_casts`
- `@overload` signatures for `ApplicationModule.handler()` decorator to preserve decorated function types

## [0.13.0] - 2025-02-20

### Breaking Changes
Expand Down
3 changes: 1 addition & 2 deletions lato/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import typing
from logging import NullHandler

from .application import Application
Expand All @@ -14,7 +13,7 @@
from .message import Command, Event, Query
from .transaction_context import TransactionContext

__version__ = "0.13.0"
__version__ = "0.13.1"
__all__ = [
"Application",
"ApplicationModule",
Expand Down
36 changes: 21 additions & 15 deletions lato/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Optional, Union
from typing import Any, Optional, TypeVar, Union

from lato.application_module import ApplicationModule
from lato.dependency_provider import BasicDependencyProvider, DependencyProvider
Expand All @@ -18,6 +18,8 @@

log = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


class Application(ApplicationModule):
"""Core Application class.
Expand All @@ -32,10 +34,10 @@ class Application(ApplicationModule):

def __init__(
self,
name=__name__,
name: str = __name__,
dependency_provider: Optional[DependencyProvider] = None,
**kwargs,
):
**kwargs: Any,
) -> None:
"""Initialize the application instance.

:param name: Name of the application
Expand All @@ -48,15 +50,17 @@ def __init__(
self.dependency_provider = (
dependency_provider or self.dependency_provider_factory(**kwargs)
)
self._transaction_context_factory: Optional[Callable] = None
self._transaction_context_factory: Optional[
Callable[..., TransactionContext]
] = None
self._on_enter_transaction_context: Optional[
OnEnterTransactionContextCallback
] = None
self._on_exit_transaction_context: Optional[
OnExitTransactionContextCallback
] = None
self._transaction_middlewares: list[MiddlewareFunction] = []
self._composers: dict[Union[Message, str], ComposerFunction] = {}
self._composers: dict[Union[type[Message], str], ComposerFunction] = {}

def get_dependency(self, identifier: DependencyIdentifier) -> Any:
"""Gets a dependency from the dependency provider. Dependencies can be resolved either by name or by type.
Expand All @@ -70,7 +74,9 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any:
def __getitem__(self, identifier: DependencyIdentifier) -> Any:
return self.get_dependency(identifier)

def call(self, func: Union[Callable[..., Any], str], *args, **kwargs) -> Any:
def call(
self, func: Union[Callable[..., Any], str], *args: Any, **kwargs: Any
) -> Any:
"""Invokes a function with `args` and `kwargs` within the :class:`TransactionContext`.
If `func` is a string, then it is an alias, and the corresponding handler for the alias is retrieved.
Any missing arguments are provided by the dependency provider of a transaction context,
Expand All @@ -96,7 +102,7 @@ def call(self, func: Union[Callable[..., Any], str], *args, **kwargs) -> Any:
return result

async def call_async(
self, func: Union[Callable[..., Awaitable[Any]], str], *args, **kwargs
self, func: Union[Callable[..., Awaitable[Any]], str], *args: Any, **kwargs: Any
) -> Any:
"""Invokes an async function with `args` and `kwargs` within the :class:`TransactionContext`.
If `func` is a string, then it is an alias, and the corresponding handler for the alias is retrieved.
Expand Down Expand Up @@ -178,7 +184,7 @@ async def publish_async(self, event: Event) -> dict[MessageHandler, Any]:
result = await ctx.publish_async(event)
return result

def on_enter_transaction_context(self, func):
def on_enter_transaction_context(self, func: F) -> F:
"""
Decorator for registering a function to be called when entering a transaction context

Expand All @@ -201,7 +207,7 @@ def on_enter_transaction_context(self, func):
self._on_enter_transaction_context = func
return func

def on_exit_transaction_context(self, func):
def on_exit_transaction_context(self, func: F) -> F:
"""
Decorator for registering a function to be called when exiting a transaction context

Expand All @@ -223,7 +229,7 @@ def on_exit_transaction_context(self, func):
self._on_exit_transaction_context = func
return func

def on_create_transaction_context(self, func):
def on_create_transaction_context(self, func: F) -> F:
"""
Decorator for overriding default transaction context creation

Expand All @@ -248,7 +254,7 @@ def on_create_transaction_context(self, func):
self._transaction_context_factory = func
return func

def transaction_middleware(self, middleware_func):
def transaction_middleware(self, middleware_func: F) -> F:
"""
Decorator for registering a middleware function to be called when executing a function in a transaction context
:param middleware_func:
Expand All @@ -269,7 +275,7 @@ def transaction_middleware(self, middleware_func):
self._transaction_middlewares.append(middleware_func)
return middleware_func

def compose(self, alias):
def compose(self, alias: Any) -> Callable[[F], F]:
"""
Decorator for composing results of handlers identified by an alias.

Expand All @@ -287,13 +293,13 @@ def compose(self, alias):
... ...
"""

def decorator(func):
def decorator(func: F) -> F:
self._composers[alias] = func
return func

return decorator

def transaction_context(self, **dependencies) -> TransactionContext:
def transaction_context(self, **dependencies: Any) -> TransactionContext:
"""Creates a transaction context for the app.

The lifecycle of a transaction context is controlled by :func:`transaction_middleware`,
Expand Down
37 changes: 24 additions & 13 deletions lato/application_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Callable, Iterator
from typing import Any, TypeVar, overload

from lato.exceptions import DuplicateHandlerError
from lato.message import Command, Message, Query
Expand All @@ -10,24 +11,26 @@

log = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


class ApplicationModule:
def __init__(self, name: str):
def __init__(self, name: str) -> None:
"""Initialize the application module instance.

:param name: Name of the module
"""
self.name: str = name
self._handlers: defaultdict[HandlerAlias, OrderedSet[Callable]] = defaultdict(
OrderedSet
)
self._submodules: OrderedSet[ApplicationModule] = OrderedSet()
self._handlers: defaultdict[
HandlerAlias, OrderedSet[Callable[..., Any]]
] = defaultdict(OrderedSet)
self._submodules: OrderedSet["ApplicationModule"] = OrderedSet()

@property
def identifier(self):
def identifier(self) -> str:
return string_to_kwarg_name(self.name)

def include_submodule(self, a_module: "ApplicationModule"):
def include_submodule(self, a_module: "ApplicationModule") -> None:
"""Adds a child submodule to this module.

:param a_module: child module to add
Expand All @@ -37,7 +40,15 @@ def include_submodule(self, a_module: "ApplicationModule"):
), f"Can only include {ApplicationModule} instances, got {a_module}"
self._submodules.add(a_module)

def handler(self, alias: HandlerAlias) -> Callable:
@overload
def handler(self, alias: F) -> F:
...

@overload
def handler(self, alias: HandlerAlias) -> Callable[[F], F]:
...

def handler(self, alias: Any) -> Any:
"""
Decorator for registering a handler. Handler can be aliased by a name or by a message type.

Expand Down Expand Up @@ -89,7 +100,7 @@ def handler(self, alias: HandlerAlias) -> Callable:
# decorator was called with argument
# @my_module.handle("my_function")
# @my_module.handle(MyCommand)
def decorator(func):
def decorator(func: F) -> F:
"""
Decorator for registering tasks by name
"""
Expand All @@ -107,7 +118,7 @@ def decorator(func):

return decorator

def iterate_handlers_for(self, alias: str):
def iterate_handlers_for(self, alias: HandlerAlias) -> Iterator[MessageHandler]:
if alias in self._handlers:
for handler in self._handlers[alias]:
yield MessageHandler(source=self.identifier, message=alias, fn=handler)
Expand All @@ -117,8 +128,8 @@ def iterate_handlers_for(self, alias: str):
except KeyError:
pass

def get_handlers_for(self, alias: str):
def get_handlers_for(self, alias: HandlerAlias) -> list[MessageHandler]:
return list(self.iterate_handlers_for(alias))

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.name} {object.__repr__(self)}>"
6 changes: 3 additions & 3 deletions lato/compositon.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from collections.abc import Callable
from functools import partial, reduce
from operator import add, or_
from typing import Optional
from typing import Any, Optional

from mergedeep import Strategy, merge # type: ignore
from mergedeep import Strategy, merge

additive_merge = partial(merge, strategy=Strategy.TYPESAFE_ADDITIVE)


def compose(compose_operator: Optional[Callable] = None, **kwargs):
def compose(compose_operator: Optional[Callable] = None, **kwargs: Any) -> Any:
values = tuple(value for module_name, value in kwargs.items() if value is not None)

if len(values) == 0:
Expand Down
34 changes: 19 additions & 15 deletions lato/dependency_provider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import inspect
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from typing import Any

from lato.exceptions import UnknownDependencyError
from lato.types import DependencyIdentifier
from lato.utils import OrderedDict


class TypedDependency:
def __init__(self, value, a_type):
def __init__(self, value: Any, a_type: type) -> None:
self.value = value
self.a_type = a_type

Expand All @@ -18,7 +18,7 @@ def as_type(obj: Any, cls: type) -> TypedDependency:
return TypedDependency(obj, cls)


def get_function_parameters(func) -> OrderedDict:
def get_function_parameters(func: Callable[..., Any]) -> OrderedDict[str, Any]:
"""
Retrieve the function's parameters and their annotations.

Expand All @@ -27,7 +27,7 @@ def get_function_parameters(func) -> OrderedDict:
"""
handler_signature = inspect.signature(func)
kwargs_iterator = iter(handler_signature.parameters.items())
parameters = OrderedDict()
parameters: OrderedDict[str, Any] = OrderedDict()
for name, param in kwargs_iterator:
parameters[name] = param.annotation
return parameters
Expand Down Expand Up @@ -55,7 +55,9 @@ def has_dependency(self, identifier: DependencyIdentifier) -> bool:
raise NotImplementedError()

@abstractmethod
def register_dependency(self, identifier: DependencyIdentifier, dependency: Any):
def register_dependency(
self, identifier: DependencyIdentifier, dependency: Any
) -> None:
"""
Register a dependency with a given identifier (name or type).

Expand All @@ -74,15 +76,15 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any:
raise NotImplementedError()

@abstractmethod
def copy(self, *args, **kwargs) -> "DependencyProvider":
def copy(self, *args: Any, **kwargs: Any) -> "DependencyProvider":
"""Creates a copy of self with updated dependencies.

:param args: dependencies to update, identified by type.
:param kwargs: dependencies to update, identified by name and type.
:return: A copy of the dependency provider.
"""

def update(self, *args, **kwargs):
def update(self, *args: Any, **kwargs: Any) -> None:
"""
Updates the dependency provider with new dependencies.

Expand All @@ -101,7 +103,7 @@ def update(self, *args, **kwargs):
if self.allow_types:
self.register_dependency(t, v)

def _get_type_and_value(self, value):
def _get_type_and_value(self, value: Any) -> tuple[type, Any]:
if isinstance(value, TypedDependency):
return value.a_type, value.value
return type(value), value
Expand All @@ -128,7 +130,7 @@ def resolve_func_params(
func_kwargs = {}

func_parameters = get_function_parameters(func)
resolved_kwargs = OrderedDict()
resolved_kwargs: OrderedDict[str, Any] = OrderedDict()
arg_idx = 0
for param_name, param_type in func_parameters.items():
if arg_idx < len(func_args):
Expand All @@ -151,10 +153,10 @@ def resolve_func_params(

return resolved_kwargs

def __getitem__(self, key):
def __getitem__(self, key: DependencyIdentifier) -> Any:
return self.get_dependency(key)

def __setitem__(self, key, value):
def __setitem__(self, key: DependencyIdentifier, value: Any) -> None:
self.register_dependency(key, value)


Expand All @@ -164,16 +166,18 @@ class BasicDependencyProvider(DependencyProvider):
dependency injection based on type or parameter name.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the DependencyProvider.

:param args: Class instances to be registered by types
:param kwargs: Dependencies to be registered by types and with explicit names
"""
self._dependencies = {}
self._dependencies: dict[DependencyIdentifier, Any] = {}
self.update(*args, **kwargs)

def register_dependency(self, identifier: DependencyIdentifier, dependency: Any):
def register_dependency(
self, identifier: DependencyIdentifier, dependency: Any
) -> None:
"""
Register a dependency with a given identifier (name or type).

Expand Down Expand Up @@ -206,7 +210,7 @@ def get_dependency(self, identifier: DependencyIdentifier) -> Any:
except KeyError as e:
raise UnknownDependencyError(identifier)

def copy(self, *args, **kwargs) -> DependencyProvider:
def copy(self, *args: Any, **kwargs: Any) -> "BasicDependencyProvider":
"""
Create a copy of self with updated dependencies.
:param args: typed overrides
Expand Down
Loading