diff --git a/src/leadr/common/domain/models.py b/src/leadr/common/domain/models.py index 131108f..ae34df6 100644 --- a/src/leadr/common/domain/models.py +++ b/src/leadr/common/domain/models.py @@ -7,19 +7,13 @@ from pydantic import BaseModel, ConfigDict, Field -class ImmutableEntity(BaseModel): - """Base class for immutable domain entities (append-only, no updates/deletes). +class _EntityBase(BaseModel): + """Private shared base for Entity and ImmutableEntity. - Provides common functionality for event-sourced entities including: + Provides common fields and behaviour: - Auto-generated UUID primary key (or typed prefixed ID in subclasses) - Created timestamp (UTC) - Equality and hashing based on ID - - Used for entities that are never updated or deleted after creation, - such as ScoreEvent in event-sourcing patterns. - - Subclasses can override the `id` field with a typed PrefixedID for better - type safety and API clarity. """ model_config = ConfigDict(validate_assignment=True) @@ -61,7 +55,23 @@ def __hash__(self) -> int: return hash(self.id) -class Entity(BaseModel): +class ImmutableEntity(_EntityBase): + """Base class for immutable domain entities (append-only, no updates/deletes). + + Provides common functionality for event-sourced entities including: + - Auto-generated UUID primary key (or typed prefixed ID in subclasses) + - Created timestamp (UTC) + - Equality and hashing based on ID + + Used for entities that are never updated or deleted after creation, + such as ScoreEvent in event-sourcing patterns. + + Subclasses can override the `id` field with a typed PrefixedID for better + type safety and API clarity. + """ + + +class Entity(_EntityBase): """Base class for all domain entities with ID and timestamps. Provides common functionality for domain entities including: @@ -78,17 +88,6 @@ class Entity(BaseModel): type safety and API clarity. """ - model_config = ConfigDict(validate_assignment=True) - - id: Any = Field( - frozen=True, - default_factory=uuid4, - description="Unique identifier (auto-generated UUID or typed ID)", - ) - created_at: datetime = Field( - default_factory=lambda: datetime.now(UTC), - description="Timestamp when entity was created (UTC)", - ) updated_at: datetime = Field( default_factory=lambda: datetime.now(UTC), description="Timestamp of last update (UTC)", @@ -131,29 +130,3 @@ def restore(self) -> None: >>> assert account.is_deleted is False """ self.deleted_at = None - - def __eq__(self, other: object) -> bool: - """Check equality based on ID. - - Two entities are considered equal if they have the same ID and are - of the same class. - - Args: - other: Object to compare with. - - Returns: - True if both entities have the same ID and class. - """ - if not isinstance(other, self.__class__): - return False - return self.id == other.id - - def __hash__(self) -> int: - """Return hash based on ID. - - Allows entities to be used in sets and as dictionary keys. - - Returns: - Hash of the entity's ID. - """ - return hash(self.id) diff --git a/src/leadr/common/repositories.py b/src/leadr/common/repositories.py index 5d3932c..cecf5c8 100644 --- a/src/leadr/common/repositories.py +++ b/src/leadr/common/repositories.py @@ -13,7 +13,7 @@ from leadr.common.domain.cursor import Cursor from leadr.common.domain.exceptions import EntityNotFoundError from leadr.common.domain.ids import AccountID, PrefixedID -from leadr.common.domain.models import Entity, ImmutableEntity +from leadr.common.domain.models import Entity, ImmutableEntity, _EntityBase from leadr.common.domain.pagination import ( CursorPosition, PaginationDirection, @@ -23,7 +23,11 @@ from leadr.common.domain.pagination_result import PaginatedResult from leadr.common.orm import Base, ImmutableBase -# Type variables for generic repository +# Type variables for the shared base +_EntityBaseT = TypeVar("_EntityBaseT", bound=_EntityBase) +_ORMBaseT = TypeVar("_ORMBaseT") + +# Type variables for mutable repository DomainEntityT = TypeVar("DomainEntityT", bound=Entity) ORMModelT = TypeVar("ORMModelT", bound=Base) @@ -32,13 +36,11 @@ ImmutableORMT = TypeVar("ImmutableORMT", bound=ImmutableBase) -class BaseRepository(ABC, Generic[DomainEntityT, ORMModelT]): - """Abstract base repository providing common CRUD operations. - - All repositories should extend this class and implement the abstract methods - for converting between domain entities and ORM models. +class _RepositoryBase(ABC, Generic[_EntityBaseT, _ORMBaseT]): + """Private shared base for BaseRepository and ImmutableBaseRepository. - All delete operations are soft deletes by default, setting deleted_at timestamp. + Contains common infrastructure: session management, ID extraction, + abstract conversion methods, create, and all pagination helpers. """ def __init__(self, session: AsyncSession): @@ -68,7 +70,7 @@ def _extract_uuid(id_value: UUID | PrefixedID | UUID4 | str) -> UUID: return UUID(str(id_value)) @abstractmethod - def _to_domain(self, orm: ORMModelT) -> DomainEntityT: + def _to_domain(self, orm: _ORMBaseT) -> _EntityBaseT: """Convert ORM model to domain entity. Args: @@ -79,7 +81,7 @@ def _to_domain(self, orm: ORMModelT) -> DomainEntityT: """ @abstractmethod - def _to_orm(self, entity: DomainEntityT) -> ORMModelT: + def _to_orm(self, entity: _EntityBaseT) -> _ORMBaseT: """Convert domain entity to ORM model. Args: @@ -90,14 +92,14 @@ def _to_orm(self, entity: DomainEntityT) -> ORMModelT: """ @abstractmethod - def _get_orm_class(self) -> type[ORMModelT]: + def _get_orm_class(self) -> type[_ORMBaseT]: """Get the ORM model class for this repository. Returns: ORM model class """ - async def create(self, entity: DomainEntityT) -> DomainEntityT: + async def create(self, entity: _EntityBaseT) -> _EntityBaseT: """Create a new entity in the database. Args: @@ -112,90 +114,6 @@ async def create(self, entity: DomainEntityT) -> DomainEntityT: await self.session.refresh(orm) return self._to_domain(orm) - async def get_by_id( - self, entity_id: UUID4 | PrefixedID, include_deleted: bool = False - ) -> DomainEntityT | None: - """Get an entity by its ID. - - Args: - entity_id: Entity ID to retrieve - include_deleted: If True, include soft-deleted entities. Defaults to False. - - Returns: - Domain entity if found, None otherwise - """ - orm_class = self._get_orm_class() - query = select(orm_class).where(orm_class.id == self._extract_uuid(entity_id)) - - if not include_deleted: - query = query.where(orm_class.deleted_at.is_(None)) - - result = await self.session.execute(query) - orm = result.scalar_one_or_none() - - return self._to_domain(orm) if orm else None - - async def update(self, entity: DomainEntityT) -> DomainEntityT: - """Update an existing entity in the database. - - Args: - entity: Domain entity with updated data - - Returns: - Updated domain entity with refreshed data - - Raises: - EntityNotFoundError: If entity is not found - """ - orm_class = self._get_orm_class() - entity_uuid = self._extract_uuid(entity.id) - result = await self.session.execute(select(orm_class).where(orm_class.id == entity_uuid)) - orm = result.scalar_one_or_none() - - if not orm: - # Get entity type name from ORM class - entity_type = orm_class.__name__.replace("ORM", "") - raise EntityNotFoundError(entity_type, str(entity.id)) - - # Update ORM from entity - updated_orm = self._to_orm(entity) - for key, value in updated_orm.__dict__.items(): - if not key.startswith("_"): - setattr(orm, key, value) - - await self.session.commit() - await self.session.refresh(orm) - return self._to_domain(orm) - - async def delete(self, entity_id: UUID4 | PrefixedID) -> None: - """Soft delete an entity by setting its deleted_at timestamp. - - Args: - entity_id: ID of entity to delete - - Raises: - EntityNotFoundError: If entity is not found - """ - orm_class = self._get_orm_class() - entity_uuid = self._extract_uuid(entity_id) - - # Verify entity exists - result = await self.session.execute(select(orm_class).where(orm_class.id == entity_uuid)) - orm = result.scalar_one_or_none() - - if not orm: - # Get entity type name from ORM class - entity_type = orm_class.__name__.replace("ORM", "") - raise EntityNotFoundError(entity_type, str(entity_id)) - - # Perform soft delete - await self.session.execute( - update(orm_class) - .where(orm_class.id == entity_uuid) - .values(deleted_at=datetime.now(UTC)) - ) - await self.session.commit() - @abstractmethod async def filter( self, @@ -203,19 +121,11 @@ async def filter( *, pagination: PaginationParams, **kwargs: Any, - ) -> PaginatedResult[DomainEntityT]: + ) -> PaginatedResult[_EntityBaseT]: """Filter entities based on criteria with pagination. - All filter operations return paginated results. The pagination parameter - is required to enforce consistent API behavior across the codebase. - - For multi-tenant entities, implementations should make account_id required - (no default). For top-level entities like Account, account_id can remain - optional and unused. - Args: - account_id: Optional account ID for filtering. Multi-tenant entities - should override to make this required. + account_id: Optional account ID for filtering. pagination: Required pagination parameters (cursor, limit, sort). **kwargs: Additional filter parameters specific to the entity type. @@ -223,121 +133,6 @@ async def filter( PaginatedResult containing matching entities and pagination metadata. """ - async def _list_all_unfiltered(self, include_deleted: bool = False) -> list[DomainEntityT]: - """List all entities without filtering by account. - - PRIVATE METHOD - Use filter() in application code for multi-tenant safety. - This method is for internal use and testing only. - - Args: - include_deleted: If True, include soft-deleted entities. Defaults to False. - - Returns: - List of domain entities - """ - orm_class = self._get_orm_class() - query = select(orm_class) - - if not include_deleted: - query = query.where(orm_class.deleted_at.is_(None)) - - result = await self.session.execute(query) - orms = result.scalars().all() - - return [self._to_domain(orm) for orm in orms] - - # Helper methods for common repository patterns - - async def _get_by_field(self, field_name: str, value: Any) -> DomainEntityT | None: - """Get an entity by a specific field value. - - This is a helper method that reduces boilerplate for get_by_ patterns - like get_by_slug, get_by_email, get_by_prefix, etc. - - Args: - field_name: Name of the ORM field to query - value: Value to match - - Returns: - Domain entity if found, None otherwise - - Example: - async def get_by_slug(self, slug: str) -> Account | None: - return await self._get_by_field("slug", slug) - """ - orm_class = self._get_orm_class() - field = getattr(orm_class, field_name) - query = select(orm_class).where(field == value, orm_class.deleted_at.is_(None)) - result = await self.session.execute(query) - orm = result.scalar_one_or_none() - return self._to_domain(orm) if orm else None - - async def _list_by_account( - self, - account_id: UUID4, - additional_filters: list[Any] | None = None, - ) -> list[DomainEntityT]: - """List entities for a specific account. - - This is a helper method that reduces boilerplate for list_by_account patterns. - - Args: - account_id: Account ID to filter by - additional_filters: Optional list of additional SQLAlchemy filter expressions - - Returns: - List of domain entities belonging to the account - - Example: - async def list_by_account(self, account_id: UUID, active_only: bool = False): - filters = [] - if active_only: - filters.append(UserORM.status == UserStatusEnum.ACTIVE) - return await self._list_by_account(account_id, filters) - """ - orm_class = self._get_orm_class() - account_uuid = self._extract_uuid(account_id) - query = select(orm_class).where( - orm_class.account_id == account_uuid, # type: ignore[attr-defined] - orm_class.deleted_at.is_(None), - ) - - if additional_filters: - for filter_expr in additional_filters: - query = query.where(filter_expr) - - result = await self.session.execute(query) - orms = result.scalars().all() - return [self._to_domain(orm) for orm in orms] - - async def _count_where(self, *conditions: Any) -> int: - """Count entities matching given conditions. - - This is a helper method that reduces boilerplate for count operations. - - Args: - *conditions: SQLAlchemy filter expressions to apply - - Returns: - Count of entities matching the conditions - - Example: - async def count_active_by_account(self, account_id: UUID) -> int: - return await self._count_where( - APIKeyORM.account_id == account_id, - APIKeyORM.status == APIKeyStatusEnum.ACTIVE, - APIKeyORM.deleted_at.is_(None), - ) - """ - orm_class = self._get_orm_class() - query = select(func.count()).select_from(orm_class) - - for condition in conditions: - query = query.where(condition) - - result = await self.session.execute(query) - return result.scalar_one() - # Pagination support methods def _get_orm_column(self, field_name: str) -> Any: @@ -469,7 +264,7 @@ def _apply_sort(self, query: Any, sort_fields: list[SortField]) -> Any: def _extract_cursor_position( self, - orm: ORMModelT, + orm: _ORMBaseT, sort_fields: list[SortField], ) -> CursorPosition: """Extract cursor position from an ORM model. @@ -486,7 +281,7 @@ def _extract_cursor_position( value = getattr(orm, sort_field.name) values.append(value) - entity_id = str(orm.id) + entity_id = str(orm.id) # type: ignore[union-attr] return CursorPosition(values=tuple(values), entity_id=entity_id) async def _execute_paginated_query( @@ -495,7 +290,7 @@ async def _execute_paginated_query( sort_fields: list[SortField], cursor: Cursor | None, limit: int, - ) -> PaginatedResult[DomainEntityT]: + ) -> PaginatedResult[_EntityBaseT]: """Execute a paginated query and return results with metadata. Fetches limit+1 records to determine has_next efficiently. @@ -563,307 +358,247 @@ async def _execute_paginated_query( ) -class ImmutableBaseRepository(ABC, Generic[ImmutableEntityT, ImmutableORMT]): - """Abstract base repository for immutable (append-only) entities. +class BaseRepository(_RepositoryBase[DomainEntityT, ORMModelT]): + """Abstract base repository providing common CRUD operations. - Used for event-sourced entities that are never updated or deleted. - Provides only create, get, and filter operations. + All repositories should extend this class and implement the abstract methods + for converting between domain entities and ORM models. + + All delete operations are soft deletes by default, setting deleted_at timestamp. """ - def __init__(self, session: AsyncSession): - """Initialize repository with database session. + async def get_by_id( + self, entity_id: UUID4 | PrefixedID, include_deleted: bool = False + ) -> DomainEntityT | None: + """Get an entity by its ID. Args: - session: SQLAlchemy async session + entity_id: Entity ID to retrieve + include_deleted: If True, include soft-deleted entities. Defaults to False. + + Returns: + Domain entity if found, None otherwise """ - self.session = session + orm_class = self._get_orm_class() + query = select(orm_class).where(orm_class.id == self._extract_uuid(entity_id)) - @staticmethod - def _extract_uuid(id_value: UUID | PrefixedID | UUID4 | str) -> UUID: - """Extract UUID from various ID types. + if not include_deleted: + query = query.where(orm_class.deleted_at.is_(None)) - Args: - id_value: ID value that can be UUID, PrefixedID, UUID4, or string + result = await self.session.execute(query) + orm = result.scalar_one_or_none() - Returns: - UUID instance for database querying - """ - if isinstance(id_value, PrefixedID): - return id_value.uuid - if isinstance(id_value, UUID): - return id_value - return UUID(str(id_value)) + return self._to_domain(orm) if orm else None - @abstractmethod - def _to_domain(self, orm: ImmutableORMT) -> ImmutableEntityT: - """Convert ORM model to domain entity. + async def update(self, entity: DomainEntityT) -> DomainEntityT: + """Update an existing entity in the database. Args: - orm: ORM model instance + entity: Domain entity with updated data Returns: - Domain entity instance - """ - - @abstractmethod - def _to_orm(self, entity: ImmutableEntityT) -> ImmutableORMT: - """Convert domain entity to ORM model. - - Args: - entity: Domain entity instance - - Returns: - ORM model instance - """ - - @abstractmethod - def _get_orm_class(self) -> type[ImmutableORMT]: - """Get the ORM model class for this repository. + Updated domain entity with refreshed data - Returns: - ORM model class + Raises: + EntityNotFoundError: If entity is not found """ + orm_class = self._get_orm_class() + entity_uuid = self._extract_uuid(entity.id) + result = await self.session.execute(select(orm_class).where(orm_class.id == entity_uuid)) + orm = result.scalar_one_or_none() - async def create(self, entity: ImmutableEntityT) -> ImmutableEntityT: - """Create a new immutable entity in the database. + if not orm: + # Get entity type name from ORM class + entity_type = orm_class.__name__.replace("ORM", "") + raise EntityNotFoundError(entity_type, str(entity.id)) - Args: - entity: Domain entity to create + # Update ORM from entity + updated_orm = self._to_orm(entity) + for key, value in updated_orm.__dict__.items(): + if not key.startswith("_"): + setattr(orm, key, value) - Returns: - Created domain entity with refreshed data - """ - orm = self._to_orm(entity) - self.session.add(orm) await self.session.commit() await self.session.refresh(orm) return self._to_domain(orm) - async def get_by_id(self, entity_id: UUID4 | PrefixedID) -> ImmutableEntityT | None: - """Get an immutable entity by its ID. + async def delete(self, entity_id: UUID4 | PrefixedID) -> None: + """Soft delete an entity by setting its deleted_at timestamp. Args: - entity_id: Entity ID to retrieve + entity_id: ID of entity to delete - Returns: - Domain entity if found, None otherwise + Raises: + EntityNotFoundError: If entity is not found """ orm_class = self._get_orm_class() - query = select(orm_class).where(orm_class.id == self._extract_uuid(entity_id)) - result = await self.session.execute(query) + entity_uuid = self._extract_uuid(entity_id) + + # Verify entity exists + result = await self.session.execute(select(orm_class).where(orm_class.id == entity_uuid)) orm = result.scalar_one_or_none() - return self._to_domain(orm) if orm else None - @abstractmethod - async def filter( - self, - account_id: AccountID | None = None, - *, - pagination: PaginationParams, - **kwargs: Any, - ) -> PaginatedResult[ImmutableEntityT]: - """Filter immutable entities based on criteria with pagination. + if not orm: + # Get entity type name from ORM class + entity_type = orm_class.__name__.replace("ORM", "") + raise EntityNotFoundError(entity_type, str(entity_id)) - Args: - account_id: Optional account ID for filtering. - pagination: Required pagination parameters (cursor, limit, sort). - **kwargs: Additional filter parameters specific to the entity type. + # Perform soft delete + await self.session.execute( + update(orm_class) + .where(orm_class.id == entity_uuid) + .values(deleted_at=datetime.now(UTC)) + ) + await self.session.commit() - Returns: - PaginatedResult containing matching entities and pagination metadata. - """ + async def _list_all_unfiltered(self, include_deleted: bool = False) -> list[DomainEntityT]: + """List all entities without filtering by account. - async def _list_all_unfiltered(self) -> list[ImmutableEntityT]: - """List all immutable entities without filtering. + PRIVATE METHOD - Use filter() in application code for multi-tenant safety. + This method is for internal use and testing only. - PRIVATE METHOD - Use filter() in application code. + Args: + include_deleted: If True, include soft-deleted entities. Defaults to False. Returns: List of domain entities """ orm_class = self._get_orm_class() query = select(orm_class) + + if not include_deleted: + query = query.where(orm_class.deleted_at.is_(None)) + result = await self.session.execute(query) orms = result.scalars().all() + return [self._to_domain(orm) for orm in orms] - # Pagination support methods + # Helper methods for common repository patterns - def _get_orm_column(self, field_name: str) -> Any: - """Get ORM column by field name. + async def _get_by_field(self, field_name: str, value: Any) -> DomainEntityT | None: + """Get an entity by a specific field value. + + This is a helper method that reduces boilerplate for get_by_ patterns + like get_by_slug, get_by_email, get_by_prefix, etc. Args: - field_name: Name of the field + field_name: Name of the ORM field to query + value: Value to match Returns: - SQLAlchemy column object + Domain entity if found, None otherwise - Raises: - ValueError: If field doesn't exist on ORM model + Example: + async def get_by_slug(self, slug: str) -> Account | None: + return await self._get_by_field("slug", slug) """ orm_class = self._get_orm_class() - if not hasattr(orm_class, field_name): - raise ValueError(f"Unknown sort field: {field_name}") - return getattr(orm_class, field_name) - - def _convert_cursor_value(self, value: Any, column: Any) -> Any: - """Convert cursor value to match ORM column type. - - Args: - value: Cursor value (JSON primitive) - column: SQLAlchemy column object - - Returns: - Value converted to match column's Python type - """ - column_type = column.type - - if isinstance(column_type, DateTime) and isinstance(value, str): - return datetime.fromisoformat(value) - elif isinstance(column_type, Uuid) and isinstance(value, str): - return UUID(value) - else: - return value + field = getattr(orm_class, field_name) + query = select(orm_class).where(field == value, orm_class.deleted_at.is_(None)) + result = await self.session.execute(query) + orm = result.scalar_one_or_none() + return self._to_domain(orm) if orm else None - def _build_cursor_where_clause( + async def _list_by_account( self, - cursor: Cursor, - sort_fields: list[SortField], - ) -> Any: - """Build WHERE clause for cursor-based pagination. + account_id: UUID4, + additional_filters: list[Any] | None = None, + ) -> list[DomainEntityT]: + """List entities for a specific account. + + This is a helper method that reduces boilerplate for list_by_account patterns. Args: - cursor: Cursor containing position and sort information - sort_fields: List of sort fields + account_id: Account ID to filter by + additional_filters: Optional list of additional SQLAlchemy filter expressions Returns: - SQLAlchemy WHERE clause condition - """ - position_values = cursor.position.values - is_backward = cursor.direction == PaginationDirection.BACKWARD - - or_conditions = [] - - for i, sort_field in enumerate(sort_fields): - if is_backward: - comp_op = "__gt__" if sort_field.direction == SortDirection.DESC else "__lt__" - else: - comp_op = "__lt__" if sort_field.direction == SortDirection.DESC else "__gt__" + List of domain entities belonging to the account - equality_conditions = [] - for j in range(i): - prev_field = sort_fields[j] - prev_column = self._get_orm_column(prev_field.name) - prev_value = self._convert_cursor_value(position_values[j], prev_column) - equality_conditions.append(prev_column == prev_value) + Example: + async def list_by_account(self, account_id: UUID, active_only: bool = False): + filters = [] + if active_only: + filters.append(UserORM.status == UserStatusEnum.ACTIVE) + return await self._list_by_account(account_id, filters) + """ + orm_class = self._get_orm_class() + account_uuid = self._extract_uuid(account_id) + query = select(orm_class).where( + orm_class.account_id == account_uuid, # type: ignore[attr-defined] + orm_class.deleted_at.is_(None), + ) - current_column = self._get_orm_column(sort_field.name) - current_value = self._convert_cursor_value(position_values[i], current_column) - comparison = getattr(current_column, comp_op)(current_value) + if additional_filters: + for filter_expr in additional_filters: + query = query.where(filter_expr) - if equality_conditions: - or_conditions.append(and_(*equality_conditions, comparison)) - else: - or_conditions.append(comparison) + result = await self.session.execute(query) + orms = result.scalars().all() + return [self._to_domain(orm) for orm in orms] - return or_(*or_conditions) + async def _count_where(self, *conditions: Any) -> int: + """Count entities matching given conditions. - def _apply_sort(self, query: Any, sort_fields: list[SortField]) -> Any: - """Apply sorting to a query. + This is a helper method that reduces boilerplate for count operations. Args: - query: SQLAlchemy query to sort - sort_fields: List of sort fields + *conditions: SQLAlchemy filter expressions to apply Returns: - Query with sorting applied + Count of entities matching the conditions + + Example: + async def count_active_by_account(self, account_id: UUID) -> int: + return await self._count_where( + APIKeyORM.account_id == account_id, + APIKeyORM.status == APIKeyStatusEnum.ACTIVE, + APIKeyORM.deleted_at.is_(None), + ) """ - for sort_field in sort_fields: - column = self._get_orm_column(sort_field.name) - if sort_field.direction == SortDirection.DESC: - query = query.order_by(column.desc()) - else: - query = query.order_by(column.asc()) - return query + orm_class = self._get_orm_class() + query = select(func.count()).select_from(orm_class) - def _extract_cursor_position( - self, - orm: ImmutableORMT, - sort_fields: list[SortField], - ) -> CursorPosition: - """Extract cursor position from an ORM model. + for condition in conditions: + query = query.where(condition) - Args: - orm: ORM model instance - sort_fields: List of sort fields to extract values for + result = await self.session.execute(query) + return result.scalar_one() - Returns: - CursorPosition with values for each sort field - """ - values = [] - for sort_field in sort_fields: - value = getattr(orm, sort_field.name) - values.append(value) - entity_id = str(orm.id) - return CursorPosition(values=tuple(values), entity_id=entity_id) +class ImmutableBaseRepository(_RepositoryBase[ImmutableEntityT, ImmutableORMT]): + """Abstract base repository for immutable (append-only) entities. - async def _execute_paginated_query( - self, - query: Any, - sort_fields: list[SortField], - cursor: Cursor | None, - limit: int, - ) -> PaginatedResult[ImmutableEntityT]: - """Execute a paginated query and return results with metadata. + Used for event-sourced entities that are never updated or deleted. + Provides only create, get, and filter operations. + """ + + async def get_by_id(self, entity_id: UUID4 | PrefixedID) -> ImmutableEntityT | None: + """Get an immutable entity by its ID. Args: - query: Base SQLAlchemy query (with filters applied) - sort_fields: List of sort fields - cursor: Optional cursor for pagination - limit: Number of items to return + entity_id: Entity ID to retrieve Returns: - PaginatedResult with items and pagination metadata + Domain entity if found, None otherwise """ - if cursor is not None: - cursor_where = self._build_cursor_where_clause(cursor, sort_fields) - query = query.where(cursor_where) - - query = self._apply_sort(query, sort_fields) - query = query.limit(limit + 1) - + orm_class = self._get_orm_class() + query = select(orm_class).where(orm_class.id == self._extract_uuid(entity_id)) result = await self.session.execute(query) - orms = list(result.scalars().all()) - - has_more = len(orms) > limit - - if has_more: - orms = orms[:limit] + orm = result.scalar_one_or_none() + return self._to_domain(orm) if orm else None - items = [self._to_domain(orm) for orm in orms] + async def _list_all_unfiltered(self) -> list[ImmutableEntityT]: + """List all immutable entities without filtering. - if cursor is not None and cursor.direction == PaginationDirection.BACKWARD: - has_next = True - has_prev = has_more - next_position = self._extract_cursor_position(orms[-1], sort_fields) if orms else None - prev_position = ( - self._extract_cursor_position(orms[0], sort_fields) if orms and has_prev else None - ) - else: - has_next = has_more - has_prev = cursor is not None - next_position = ( - self._extract_cursor_position(orms[-1], sort_fields) if orms and has_next else None - ) - prev_position = ( - self._extract_cursor_position(orms[0], sort_fields) if orms and has_prev else None - ) + PRIVATE METHOD - Use filter() in application code. - return PaginatedResult( - items=items, - has_next=has_next, - has_prev=has_prev, - next_position=next_position, - prev_position=prev_position, - ) + Returns: + List of domain entities + """ + orm_class = self._get_orm_class() + query = select(orm_class) + result = await self.session.execute(query) + orms = result.scalars().all() + return [self._to_domain(orm) for orm in orms] diff --git a/tests/leadr/common/domain/test_models.py b/tests/leadr/common/domain/test_models.py index 325c858..d9c2bd5 100644 --- a/tests/leadr/common/domain/test_models.py +++ b/tests/leadr/common/domain/test_models.py @@ -3,7 +3,9 @@ from datetime import UTC, datetime from uuid import uuid4 -from leadr.common.domain.models import Entity +import pytest + +from leadr.common.domain.models import Entity, ImmutableEntity class TestEntity: @@ -82,3 +84,66 @@ def test_entity_restore_method(self): entity.restore() assert entity.is_deleted is False assert entity.deleted_at is None + + +class TestImmutableEntity: + """Tests for ImmutableEntity base class.""" + + def test_auto_generates_id_and_created_at(self): + """Test that ImmutableEntity auto-generates id and created_at.""" + entity = ImmutableEntity() + assert entity.id is not None + assert entity.created_at is not None + assert isinstance(entity.created_at, datetime) + + def test_id_is_frozen(self): + """Test that id cannot be reassigned after creation.""" + entity = ImmutableEntity() + with pytest.raises(Exception): # noqa: B017 + entity.id = uuid4() + + def test_equality_by_id(self): + """Test that equality is based on id.""" + shared_id = uuid4() + entity1 = ImmutableEntity(id=shared_id) + entity2 = ImmutableEntity(id=shared_id) + assert entity1 == entity2 + + def test_inequality_different_ids(self): + """Test that entities with different ids are not equal.""" + entity1 = ImmutableEntity() + entity2 = ImmutableEntity() + assert entity1 != entity2 + + def test_inequality_different_types(self): + """Test that an entity is not equal to a non-entity.""" + entity = ImmutableEntity() + assert entity != "not an entity" + + def test_hash_by_id(self): + """Test that hash is based on id, allowing use in sets.""" + shared_id = uuid4() + entity1 = ImmutableEntity(id=shared_id) + entity2 = ImmutableEntity(id=shared_id) + assert hash(entity1) == hash(entity2) + assert len({entity1, entity2}) == 1 # type: ignore[reportUnhashable] + + def test_has_no_updated_at(self): + """Test that ImmutableEntity has no updated_at field.""" + entity = ImmutableEntity() + assert not hasattr(entity, "updated_at") + + def test_has_no_deleted_at(self): + """Test that ImmutableEntity has no deleted_at field.""" + entity = ImmutableEntity() + assert not hasattr(entity, "deleted_at") + + def test_has_no_soft_delete(self): + """Test that ImmutableEntity has no soft_delete method.""" + entity = ImmutableEntity() + assert not hasattr(entity, "soft_delete") + + def test_has_no_restore(self): + """Test that ImmutableEntity has no restore method.""" + entity = ImmutableEntity() + assert not hasattr(entity, "restore") diff --git a/tests/leadr/common/test_immutable_base_repository.py b/tests/leadr/common/test_immutable_base_repository.py new file mode 100644 index 0000000..bea3b31 --- /dev/null +++ b/tests/leadr/common/test_immutable_base_repository.py @@ -0,0 +1,136 @@ +"""Tests for ImmutableBaseRepository abstraction.""" + +from datetime import UTC, datetime +from typing import Any +from uuid import uuid4 + +import pytest +import pytest_asyncio +from pydantic import UUID4 +from sqlalchemy import String, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped, mapped_column + +from leadr.common.api.pagination import PaginationParams +from leadr.common.domain.ids import PrefixedID +from leadr.common.domain.models import ImmutableEntity +from leadr.common.domain.pagination_result import PaginatedResult +from leadr.common.orm import ImmutableBase +from leadr.common.repositories import ImmutableBaseRepository + +# Test fixtures - Domain Entity + + +class MockImmutableEntity(ImmutableEntity): + """Test domain entity for ImmutableBaseRepository testing.""" + + name: str + + +# Test fixtures - ORM Model + + +class MockImmutableEntityORM(ImmutableBase): + """Test ORM model for ImmutableBaseRepository testing.""" + + __tablename__ = "test_immutable_entities" + + name: Mapped[str] = mapped_column(String, nullable=False) + + +# Test Repository Implementation + + +class MockImmutableRepository(ImmutableBaseRepository[MockImmutableEntity, MockImmutableEntityORM]): + """Concrete test repository for testing ImmutableBaseRepository.""" + + def _to_domain(self, orm: MockImmutableEntityORM) -> MockImmutableEntity: + return MockImmutableEntity( + id=orm.id, + name=orm.name, + created_at=orm.created_at, + ) + + def _to_orm(self, entity: MockImmutableEntity) -> MockImmutableEntityORM: + return MockImmutableEntityORM( + id=entity.id, + name=entity.name, + created_at=entity.created_at, + ) + + def _get_orm_class(self) -> type[MockImmutableEntityORM]: + return MockImmutableEntityORM + + async def filter( + self, + account_id: UUID4 | PrefixedID | None = None, + *, + pagination: PaginationParams, + **kwargs: Any, + ) -> PaginatedResult[MockImmutableEntity]: + query = select(MockImmutableEntityORM) + sort_fields = pagination.sort_spec + cursor = pagination.decode_cursor() + return await self._execute_paginated_query(query, sort_fields, cursor, pagination.limit) + + +@pytest.mark.asyncio +class TestImmutableBaseRepository: + """Test suite for ImmutableBaseRepository common functionality.""" + + @pytest_asyncio.fixture(autouse=True) + async def setup_test_table(self, test_engine): + """Create test table before each test.""" + async with test_engine.begin() as conn: + await conn.run_sync( + MockImmutableEntityORM.__table__.create, # type: ignore[attr-defined] + checkfirst=True, + ) + + async def test_create(self, db_session: AsyncSession): + """Test creating an immutable entity via repository.""" + repo = MockImmutableRepository(db_session) + entity_id = uuid4() + now = datetime.now(UTC) + + entity = MockImmutableEntity(id=entity_id, name="Test Event", created_at=now) + created = await repo.create(entity) + + assert created.id == entity_id + assert created.name == "Test Event" + + async def test_get_by_id_found(self, db_session: AsyncSession): + """Test retrieving an immutable entity by ID when it exists.""" + repo = MockImmutableRepository(db_session) + entity_id = uuid4() + now = datetime.now(UTC) + + entity = MockImmutableEntity(id=entity_id, name="Test Event", created_at=now) + await repo.create(entity) + + retrieved = await repo.get_by_id(entity_id) + assert retrieved is not None + assert retrieved.id == entity_id + assert retrieved.name == "Test Event" + + async def test_get_by_id_not_found(self, db_session: AsyncSession): + """Test retrieving a non-existent immutable entity returns None.""" + repo = MockImmutableRepository(db_session) + + result = await repo.get_by_id(uuid4()) + assert result is None + + async def test_list_all_unfiltered(self, db_session: AsyncSession): + """Test listing all immutable entities.""" + repo = MockImmutableRepository(db_session) + now = datetime.now(UTC) + + entity1 = MockImmutableEntity(name="Event 1", created_at=now) + entity2 = MockImmutableEntity(name="Event 2", created_at=now) + await repo.create(entity1) + await repo.create(entity2) + + all_entities = await repo._list_all_unfiltered() + assert len(all_entities) == 2 + names = {e.name for e in all_entities} + assert names == {"Event 1", "Event 2"}