diff --git a/agentex/Dockerfile b/agentex/Dockerfile index 0adebd52..998e5d2b 100644 --- a/agentex/Dockerfile +++ b/agentex/Dockerfile @@ -36,6 +36,9 @@ RUN uv sync --frozen --group dev --package agentex-backend COPY ${SOURCE_DIR}/src/ ./src/ EXPOSE 5003 ENV PYTHONPATH=/app +# Dev / single-worker stage. `--reload` doesn't work with `--workers >1`. +# Production multi-worker config is set in the final stage at the bottom of +# this file. CMD ["ddtrace-run", "uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "5003", "--reload"] # Docs builder stage @@ -90,4 +93,20 @@ USER nonroot EXPOSE 5003 ENV PYTHONPATH=/app -CMD ["ddtrace-run", "uvicorn", "src.api.app:app", "--host", "0.0.0.0", "--port", "5003"] +# Run uvicorn with multiple workers so a single pod isn't bottlenecked on one +# in-flight request at a time. Default of 4 is a safe match for the typical +# 1 CPU / 2Gi pod limits; bump UVICORN_WORKERS up to ~cpu_count when those +# limits are increased. +# +# IMPORTANT: each worker is a separate Python process with its OWN DB pools. +# Per-pod aggregates with the current defaults: +# - Postgres: POSTGRES_POOL_SIZE (10) × workers + max_overflow (20) × workers +# = 40 base + 80 overflow per pod at workers=4 (vs 10 + 20 at workers=1). +# - MongoDB: MONGODB_MAX_POOL_SIZE per worker; 4× the aggregate at workers=4. +# When changing UVICORN_WORKERS, review POSTGRES_POOL_SIZE / +# POSTGRES_MIDDLEWARE_POOL_SIZE / MONGODB_MAX_POOL_SIZE against the upstream +# DB's max_connections so you don't exhaust the pool. A rough rule of thumb is +# to keep `workers × POSTGRES_POOL_SIZE` below ~half of the upstream cap so +# overflow has headroom. +ENV UVICORN_WORKERS=4 +CMD ["sh", "-c", "exec ddtrace-run uvicorn src.api.app:app --host 0.0.0.0 --port 5003 --workers ${UVICORN_WORKERS}"] diff --git a/agentex/pyproject.toml b/agentex/pyproject.toml index b1e2a476..40858d3e 100644 --- a/agentex/pyproject.toml +++ b/agentex/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "kubernetes-asyncio>=31.1.0,<32", "aiohttp>=3.10.9,<4", "websockets~=14.2", - "pymongo>=4.11.2,<5", + "pymongo>=4.13.0,<5", "httpx[http2]>=0.27.2", "ddtrace>=3.13.0", "json_log_formatter>=1.1.1", diff --git a/agentex/src/adapters/crud_store/adapter_mongodb.py b/agentex/src/adapters/crud_store/adapter_mongodb.py index aab286fe..7a1eb6e0 100644 --- a/agentex/src/adapters/crud_store/adapter_mongodb.py +++ b/agentex/src/adapters/crud_store/adapter_mongodb.py @@ -7,7 +7,7 @@ import pymongo from bson import ObjectId -from pymongo.collection import Collection +from pymongo.asynchronous.collection import AsyncCollection from src.adapters.crud_store.exceptions import DuplicateItemError, ItemDoesNotExist from src.adapters.crud_store.port import CRUDRepository @@ -98,7 +98,7 @@ def __init__( model_class: type[T], ): self.db = db - self.collection: Collection = db[collection_name] + self.collection: AsyncCollection = db[collection_name] self.model_class = model_class # MongoDB already enforces uniqueness on _id @@ -216,7 +216,7 @@ async def create(self, item: T) -> T: if data.get("updated_at") is None: data["updated_at"] = now - result = self.collection.insert_one(data) + result = await self.collection.insert_one(data) # Update item with generated ID (as string) # Set the .id field with the string representation of _id @@ -274,7 +274,7 @@ async def batch_create(self, items: list[T]) -> list[T]: data_list.append(data) - result = self.collection.insert_many(data_list) + result = await self.collection.insert_many(data_list) # Update items with generated IDs (as strings) for idx, inserted_id in enumerate(result.inserted_ids): @@ -320,7 +320,7 @@ async def get(self, id: str | None = None, name: str | None = None) -> T | None: else: query = {"name": name} - document = self.collection.find_one(query) + document = await self.collection.find_one(query) if document is None: msg = ( f"Item with {'id' if id else 'name'} '{id or name}' does not exist." @@ -362,7 +362,9 @@ async def get_by_field(self, field_name: str, field_value: Any) -> T | None: except Exception: pass - document = self.collection.find_one({mongo_field_name: mongo_field_value}) + document = await self.collection.find_one( + {mongo_field_name: mongo_field_value} + ) if document is None: raise ItemDoesNotExist( f"Item with {field_name} '{field_value}' does not exist." @@ -392,8 +394,7 @@ async def batch_get( elif names: query["name"] = {"$in": names} - cursor = self.collection.find(query) - results = list(cursor) + results = await self.collection.find(query).to_list(length=None) if not results: key = "ids" if ids else "names" msg = f"No items found with {key} '{ids or names}'." @@ -429,14 +430,14 @@ async def update(self, item: T) -> T: # Add updated_at timestamp update_data["updated_at"] = datetime.now(UTC) - result = self.collection.update_one( + result = await self.collection.update_one( {"_id": id_value}, {"$set": update_data} ) if result.matched_count == 0: raise ItemDoesNotExist(f"Item with id '{id_value}' does not exist.") - updated_doc = self.collection.find_one({"_id": id_value}) + updated_doc = await self.collection.find_one({"_id": id_value}) return self._deserialize(updated_doc) except ItemDoesNotExist: raise @@ -476,7 +477,7 @@ async def delete(self, id: str | None = None, name: str | None = None) -> None: else: query = {"name": name} - result = self.collection.delete_one(query) + result = await self.collection.delete_one(query) if result.deleted_count == 0: msg = ( f"Item with {'id' if id else 'name'} '{id or name}' does not exist." @@ -507,7 +508,7 @@ async def batch_delete( elif names: query["name"] = {"$in": names} - result = self.collection.delete_many(query) + result = await self.collection.delete_many(query) if result.deleted_count == 0: key = "ids" if ids else "names" msg = f"No items found with {key} '{ids or names}'." @@ -536,11 +537,6 @@ async def list( raise ClientError("Page number must be greater than 0") page_number = page_number or 1 skip = (page_number - 1) * limit - if filters: - cursor = self.collection.find(filters) - else: - cursor = self.collection.find() - cursor = cursor.skip(skip).limit(limit) sort_list = [] if order_by: @@ -553,10 +549,14 @@ async def list( # Always use _id as tiebreaker sort_list.append(("_id", pymongo.ASCENDING)) - cursor = cursor.sort(sort_list) try: - return [self._deserialize(doc) for doc in cursor] + cursor = ( + self.collection.find(filters) if filters else self.collection.find() + ) + cursor = cursor.skip(skip).limit(limit).sort(sort_list) + docs = await cursor.to_list(length=None) + return [self._deserialize(doc) for doc in docs] except Exception as e: raise ServiceError( message=f"Failed to list items from MongoDB: {e}", detail=str(e) @@ -600,25 +600,24 @@ async def find_by_field( if filters: query.update(filters) - cursor = self.collection.find(query) - # Apply sorting sort_by_items = list(sort_by.items()) if sort_by else [] # Use ID for tiebreaking sort_by_items.append(("_id", 1)) - cursor = cursor.sort(sort_by_items) # Apply limit if specified limit = limit or DEFAULT_PAGE_LIMIT - cursor = cursor.limit(limit) # Apply page number if specified if page_number is not None and page_number < 1: raise ClientError("Page number must be greater than 0") - if page_number is not None: - cursor = cursor.skip((page_number - 1) * limit) + skip = (page_number - 1) * limit if page_number is not None else 0 - return [self._deserialize(doc) for doc in cursor] + cursor = self.collection.find(query).sort(sort_by_items).limit(limit) + if skip: + cursor = cursor.skip(skip) + docs = await cursor.to_list(length=None) + return [self._deserialize(doc) for doc in docs] except Exception as e: raise ServiceError( message=f"Failed to find items by field in MongoDB: {e}", detail=str(e) @@ -677,7 +676,7 @@ async def find_by_field_with_cursor( except Exception: cursor_object_id = cursor_id - cursor_doc = self.collection.find_one({"_id": cursor_object_id}) + cursor_doc = await self.collection.find_one({"_id": cursor_object_id}) if cursor_doc and "created_at" in cursor_doc: cursor_timestamp = cursor_doc["created_at"] if before_id: @@ -705,20 +704,17 @@ async def find_by_field_with_cursor( }, ] - # Create a cursor - db_cursor = self.collection.find(query) - # Apply sorting sort_by_items = list(sort_by.items()) if sort_by else [] # Use ID for tiebreaking sort_by_items.append(("_id", 1)) - db_cursor = db_cursor.sort(sort_by_items) # Apply limit if specified limit = limit or DEFAULT_PAGE_LIMIT - db_cursor = db_cursor.limit(limit) - return [self._deserialize(doc) for doc in db_cursor] + db_cursor = self.collection.find(query).sort(sort_by_items).limit(limit) + docs = await db_cursor.to_list(length=None) + return [self._deserialize(doc) for doc in docs] except Exception as e: raise ServiceError( message=f"Failed to find items by field with cursor in MongoDB: {e}", @@ -745,7 +741,9 @@ async def delete_by_field(self, field_name: str, field_value: Any) -> int: except Exception: pass - result = self.collection.delete_many({mongo_field_name: mongo_field_value}) + result = await self.collection.delete_many( + {mongo_field_name: mongo_field_value} + ) return result.deleted_count except Exception as e: raise ServiceError( diff --git a/agentex/src/api/health_interceptor.py b/agentex/src/api/health_interceptor.py index 784fe8f5..e2eaaa53 100644 --- a/agentex/src/api/health_interceptor.py +++ b/agentex/src/api/health_interceptor.py @@ -199,9 +199,8 @@ async def _check_mongodb(self, deps: Any) -> dict[str, Any]: if client is None: return {"healthy": False, "error": "Client not initialized"} - # MongoDB client is synchronous, run in thread pool async with asyncio.timeout(DEPENDENCY_CHECK_TIMEOUT): - await asyncio.to_thread(client.admin.command, "ping") + await client.admin.command("ping") return {"healthy": True} except TimeoutError: diff --git a/agentex/src/config/dependencies.py b/agentex/src/config/dependencies.py index 65c042cc..f37cf043 100644 --- a/agentex/src/config/dependencies.py +++ b/agentex/src/config/dependencies.py @@ -3,12 +3,12 @@ from typing import Annotated import httpx -import pymongo import redis.asyncio as redis from docker import DockerClient from fastapi import Depends from kubernetes_asyncio import config as k8s_config -from pymongo.database import Database as MongoDBDatabase +from pymongo import AsyncMongoClient +from pymongo.asynchronous.database import AsyncDatabase from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -43,8 +43,8 @@ def __init__(self): self.database_async_read_write_engine: AsyncEngine | None = None self.database_async_middleware_read_write_engine: AsyncEngine | None = None self.docker_client = None - self.mongodb_client: pymongo.MongoClient | None = None - self.mongodb_database: MongoDBDatabase | None = None + self.mongodb_client: AsyncMongoClient | None = None + self.mongodb_database: AsyncDatabase | None = None self.httpx_client: httpx.AsyncClient | None = None self.redis_pool: redis.ConnectionPool | None = None self.database_async_read_only_engine: AsyncEngine | None = None @@ -122,7 +122,7 @@ async def load(self): logger.info("Connecting to MongoDB") - self.mongodb_client = pymongo.MongoClient( + self.mongodb_client = AsyncMongoClient( mongodb_uri, serverSelectionTimeoutMS=20000, connectTimeoutMS=20000, @@ -136,7 +136,7 @@ async def load(self): self.mongodb_database = self.mongodb_client[mongodb_database_name] # Ping the database to verify connection - self.mongodb_client.admin.command("ping") + await self.mongodb_client.admin.command("ping") logger.info( f"Successfully connected to MongoDB database '{mongodb_database_name}'" ) @@ -146,7 +146,7 @@ async def load(self): from src.config.mongodb_indexes import ensure_mongodb_indexes try: - ensure_mongodb_indexes(self.mongodb_database) + await ensure_mongodb_indexes(self.mongodb_database) logger.info("MongoDB indexes ensured successfully") except Exception as index_error: # Don't fail startup if index creation fails @@ -242,7 +242,7 @@ async def force_reload(self): if self.database_async_read_only_engine: await self.database_async_read_only_engine.dispose() if self.mongodb_client: - self.mongodb_client.close() + await self.mongodb_client.close() # Reset state self._loaded = False @@ -293,7 +293,7 @@ async def async_shutdown(): # Close MongoDB connection if global_dependencies.mongodb_client: - global_dependencies.mongodb_client.close() + await global_dependencies.mongodb_client.close() # Close HTTPX client if global_dependencies.httpx_client: @@ -375,7 +375,7 @@ def middleware_async_read_only_session_maker() -> async_sessionmaker[AsyncSessio DockerClient, Depends(lambda: GlobalDependencies().docker_client) ] DMongoDBDatabase = Annotated[ - MongoDBDatabase, Depends(lambda: GlobalDependencies().mongodb_database) + AsyncDatabase, Depends(lambda: GlobalDependencies().mongodb_database) ] DHttpxClient = Annotated[ httpx.AsyncClient, Depends(lambda: GlobalDependencies().httpx_client) diff --git a/agentex/src/config/mongodb_indexes.py b/agentex/src/config/mongodb_indexes.py index d87231c3..4e8a3c2c 100644 --- a/agentex/src/config/mongodb_indexes.py +++ b/agentex/src/config/mongodb_indexes.py @@ -8,7 +8,7 @@ from typing import Any -from pymongo.database import Database as MongoDBDatabase +from pymongo.asynchronous.database import AsyncDatabase from pymongo.errors import OperationFailure from src.utils.logging import make_logger @@ -16,7 +16,7 @@ logger = make_logger(__name__) -def ensure_mongodb_indexes(mongodb_database: MongoDBDatabase) -> None: +async def ensure_mongodb_indexes(mongodb_database: AsyncDatabase) -> None: """ Create all MongoDB indexes defined in repository classes. @@ -76,7 +76,7 @@ def ensure_mongodb_indexes(mongodb_database: MongoDBDatabase) -> None: index_kwargs[key] = index_spec[key] # Create the index - result = collection.create_index(keys, **index_kwargs) + result = await collection.create_index(keys, **index_kwargs) if description: logger.info(f" ✓ Created index '{name or result}': {description}") @@ -100,7 +100,7 @@ def ensure_mongodb_indexes(mongodb_database: MongoDBDatabase) -> None: logger.info("MongoDB index creation completed.") -def drop_all_indexes(mongodb_database: MongoDBDatabase) -> None: +async def drop_all_indexes(mongodb_database: AsyncDatabase) -> None: """ Drop all non-_id indexes from MongoDB collections. @@ -129,7 +129,7 @@ def drop_all_indexes(mongodb_database: MongoDBDatabase) -> None: try: # Drop all indexes except _id - collection.drop_indexes() + await collection.drop_indexes() logger.info(f" ✓ Dropped all indexes from collection '{collection_name}'") except Exception as e: logger.error(f" ✗ Failed to drop indexes from '{collection_name}': {e}") @@ -137,7 +137,7 @@ def drop_all_indexes(mongodb_database: MongoDBDatabase) -> None: logger.warning("Index dropping completed.") -def get_index_stats(mongodb_database: MongoDBDatabase) -> dict[str, Any]: +async def get_index_stats(mongodb_database: AsyncDatabase) -> dict[str, Any]: """ Get statistics about indexes for all collections. @@ -165,7 +165,7 @@ def get_index_stats(mongodb_database: MongoDBDatabase) -> dict[str, Any]: collection = mongodb_database[collection_name] try: - indexes = list(collection.list_indexes()) + indexes = [idx async for idx in collection.list_indexes()] stats[collection_name] = { "count": len(indexes), "indexes": [ diff --git a/agentex/src/domain/repositories/task_state_repository.py b/agentex/src/domain/repositories/task_state_repository.py index 16e43147..9cc0381e 100644 --- a/agentex/src/domain/repositories/task_state_repository.py +++ b/agentex/src/domain/repositories/task_state_repository.py @@ -43,7 +43,7 @@ def __init__(self, db: DMongoDBDatabase): async def get_by_task_and_agent( self, task_id: str, agent_id: str ) -> StateEntity | None: - doc = self.collection.find_one({"task_id": task_id, "agent_id": agent_id}) + doc = await self.collection.find_one({"task_id": task_id, "agent_id": agent_id}) return self._deserialize(doc) if doc else None diff --git a/agentex/tests/fixtures/containers.py b/agentex/tests/fixtures/containers.py index 35259361..5e801ec2 100644 --- a/agentex/tests/fixtures/containers.py +++ b/agentex/tests/fixtures/containers.py @@ -66,26 +66,26 @@ def mongodb_connection_string(mongodb_container): return mongodb_container.get_connection_url() -@pytest.fixture -def mongodb_database(mongodb_connection_string): +@pytest_asyncio.fixture +async def mongodb_database(mongodb_connection_string): """ Function-scoped MongoDB database instance. Creates a fresh database for each test to ensure isolation. + Returns an AsyncDatabase (pymongo native async API) so repositories + consuming it can `await` collection operations. """ import time - from pymongo import MongoClient + from pymongo import AsyncMongoClient - # Create a unique database name for this test db_name = f"test_agentex_{int(time.time() * 1000)}" - client = MongoClient(mongodb_connection_string) + client = AsyncMongoClient(mongodb_connection_string) db = client[db_name] yield db - # Cleanup: Drop the database after the test - client.drop_database(db_name) - client.close() + await client.drop_database(db_name) + await client.close() @pytest.fixture(scope="session") diff --git a/agentex/tests/fixtures/database.py b/agentex/tests/fixtures/database.py index e1fbbfb9..3abfdf50 100644 --- a/agentex/tests/fixtures/database.py +++ b/agentex/tests/fixtures/database.py @@ -3,7 +3,7 @@ import pytest import redis.asyncio as redis -from pymongo import MongoClient +from pymongo import AsyncMongoClient from sqlalchemy.ext.asyncio import create_async_engine @@ -39,22 +39,24 @@ async def base_postgres_session(postgres_session_maker): @pytest.fixture -def base_mongodb_database(mongodb_connection_string): +async def base_mongodb_database(mongodb_connection_string): """ Base MongoDB database with cleanup. Creates a unique database per test and cleans up afterward. This is the shared foundation for both unit and integration tests. + Returns an AsyncDatabase (pymongo native async API) so repositories + consuming it can `await` collection operations. """ # Create unique database name with process ID for extra uniqueness db_name = f"test_agentex_{int(time.time() * 1000)}_{os.getpid()}" - client = MongoClient(mongodb_connection_string) + client = AsyncMongoClient(mongodb_connection_string) db = client[db_name] yield db # Cleanup: Drop the database after the test - client.drop_database(db_name) - client.close() + await client.drop_database(db_name) + await client.close() @pytest.fixture @@ -88,7 +90,7 @@ async def unit_db_session(base_postgres_session): @pytest.fixture -def unit_mongodb_database(base_mongodb_database): +async def unit_mongodb_database(base_mongodb_database): """ MongoDB database for unit tests. This is an alias for base_mongodb_database for consistency. diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index 388da467..b715d223 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -109,7 +109,7 @@ async def isolated_test_schema(integration_test_db_urls): ) # Create MongoDB client for database management - mongodb_client = pymongo.MongoClient( + mongodb_client = pymongo.AsyncMongoClient( integration_test_db_urls["mongodb_url"], serverSelectionTimeoutMS=15000, # Longer timeout for container startup connectTimeoutMS=10000, @@ -183,7 +183,7 @@ async def drop_schema(): try: # Drop MongoDB database - mongodb_client.drop_database(mongodb_db_name) + await mongodb_client.drop_database(mongodb_db_name) except Exception as e: print(f"Warning: Failed to drop MongoDB database {mongodb_db_name}: {e}") @@ -195,7 +195,7 @@ async def drop_schema(): try: await asyncio.gather(*cleanup_tasks, return_exceptions=True) - mongodb_client.close() + await mongodb_client.close() await redis_client.aclose() except Exception as e: print(f"Warning: Failed to cleanup connections: {e}") diff --git a/agentex/tests/unit/adapters/__init__.py b/agentex/tests/unit/adapters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentex/tests/unit/adapters/test_mongodb_adapter.py b/agentex/tests/unit/adapters/test_mongodb_adapter.py new file mode 100644 index 00000000..1e130dfb --- /dev/null +++ b/agentex/tests/unit/adapters/test_mongodb_adapter.py @@ -0,0 +1,179 @@ +"""Unit tests for the generic MongoDB CRUD adapter. + +These tests exercise the adapter directly against a real MongoDB +(via the testcontainers `mongodb_database` fixture, which yields an +`AsyncDatabase` from pymongo's native async API). The focus is the +code paths touched by the motor → pymongo-async migration: + + - direct collection ops: insert_one / insert_many / find_one / + update_one / delete_one / delete_many + - cursor materialization via `await cursor.to_list(length=None)` for + list(), find_by_field(), find_by_field_with_cursor() + - timestamp + id round-tripping +""" + +from datetime import datetime +from typing import Any + +import pytest +from pydantic import BaseModel +from src.adapters.crud_store.adapter_mongodb import MongoDBCRUDRepository +from src.adapters.crud_store.exceptions import ItemDoesNotExist + + +def _naive(dt: datetime) -> datetime: + """Strip tzinfo so we can compare adapter-supplied (tz-aware) vs + Mongo-roundtripped (tz-naive UTC) timestamps without TypeErrors.""" + return dt.replace(tzinfo=None) if dt.tzinfo is not None else dt + + +class _Item(BaseModel): + id: str | None = None + name: str + group: str | None = None + payload: dict[str, Any] | None = None + created_at: Any | None = None + updated_at: Any | None = None + + +def _make_repo(mongodb_database) -> MongoDBCRUDRepository[_Item]: + return MongoDBCRUDRepository( + db=mongodb_database, collection_name="items", model_class=_Item + ) + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_create_get_update_delete_roundtrip(mongodb_database): + repo = _make_repo(mongodb_database) + + created = await repo.create(_Item(name="alpha", group="g1")) + assert created.id is not None + assert created.created_at is not None + assert created.updated_at is not None + + fetched = await repo.get(id=created.id) + assert fetched is not None + assert fetched.name == "alpha" + assert fetched.group == "g1" + + fetched.name = "alpha-renamed" + updated = await repo.update(fetched) + assert updated.name == "alpha-renamed" + assert _naive(updated.updated_at) >= _naive(created.updated_at) + + await repo.delete(id=created.id) + with pytest.raises(ItemDoesNotExist): + await repo.get(id=created.id) + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_batch_create_and_batch_get(mongodb_database): + repo = _make_repo(mongodb_database) + + items = [_Item(name=f"b{i}", group="batch") for i in range(5)] + created = await repo.batch_create(items) + ids = [c.id for c in created] + assert len(ids) == 5 + assert all(ids) + + fetched = await repo.batch_get(ids=ids) + assert {f.name for f in fetched} == {"b0", "b1", "b2", "b3", "b4"} + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_list_pagination_and_ordering(mongodb_database): + repo = _make_repo(mongodb_database) + + await repo.batch_create( + [_Item(name=f"item-{i:02d}", group="list") for i in range(15)] + ) + + page1 = await repo.list( + filters={"group": "list"}, limit=5, page_number=1, order_by="name" + ) + page2 = await repo.list( + filters={"group": "list"}, limit=5, page_number=2, order_by="name" + ) + page3 = await repo.list( + filters={"group": "list"}, limit=5, page_number=3, order_by="name" + ) + + assert [i.name for i in page1] == [f"item-{n:02d}" for n in range(0, 5)] + assert [i.name for i in page2] == [f"item-{n:02d}" for n in range(5, 10)] + assert [i.name for i in page3] == [f"item-{n:02d}" for n in range(10, 15)] + + page1_desc = await repo.list( + filters={"group": "list"}, + limit=5, + page_number=1, + order_by="name", + order_direction="desc", + ) + assert [i.name for i in page1_desc] == [f"item-{n:02d}" for n in range(14, 9, -1)] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_find_by_field_filters_and_limit(mongodb_database): + repo = _make_repo(mongodb_database) + + await repo.batch_create( + [_Item(name=f"f{i}", group="A" if i % 2 == 0 else "B") for i in range(10)] + ) + + group_a = await repo.find_by_field("group", "A", limit=10) + group_b = await repo.find_by_field("group", "B", limit=10) + assert len(group_a) == 5 + assert len(group_b) == 5 + assert all(item.group == "A" for item in group_a) + assert all(item.group == "B" for item in group_b) + + page1 = await repo.find_by_field("group", "A", limit=2, page_number=1) + page2 = await repo.find_by_field("group", "A", limit=2, page_number=2) + assert len(page1) == 2 + assert len(page2) == 2 + assert {p.id for p in page1}.isdisjoint({p.id for p in page2}) + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_find_by_field_with_cursor_before_after(mongodb_database): + repo = _make_repo(mongodb_database) + + created = await repo.batch_create( + [_Item(name=f"c{i}", group="cursor") for i in range(8)] + ) + middle_id = created[4].id + + after = await repo.find_by_field_with_cursor( + "group", "cursor", limit=10, after_id=middle_id + ) + before = await repo.find_by_field_with_cursor( + "group", "cursor", limit=10, before_id=middle_id + ) + + assert middle_id not in {i.id for i in after} + assert middle_id not in {i.id for i in before} + # Every result is from the same group + assert all(i.group == "cursor" for i in after + before) + # Combined coverage minus the cursor doc equals the rest of the set + assert {i.id for i in after} | {i.id for i in before} == { + i.id for i in created if i.id != middle_id + } + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_delete_by_field(mongodb_database): + repo = _make_repo(mongodb_database) + + await repo.batch_create([_Item(name=f"d{i}", group="del") for i in range(6)]) + + deleted = await repo.delete_by_field("group", "del") + assert deleted == 6 + + with pytest.raises(ItemDoesNotExist): + await repo.batch_get(names=[f"d{i}" for i in range(6)]) diff --git a/agentex/tests/unit/api/test_health_interceptor.py b/agentex/tests/unit/api/test_health_interceptor.py index d58b9190..6d4527dd 100644 --- a/agentex/tests/unit/api/test_health_interceptor.py +++ b/agentex/tests/unit/api/test_health_interceptor.py @@ -106,7 +106,9 @@ def should_not_be_called(request): mock_redis_pool = MagicMock() mock_mongodb_client = MagicMock() - mock_mongodb_client.admin.command = MagicMock(return_value={"ok": 1}) + # AsyncMongoClient.admin.command(...) is awaitable on the native async + # driver, so the mock must return a coroutine — AsyncMock does that. + mock_mongodb_client.admin.command = AsyncMock(return_value={"ok": 1}) mock_deps = MagicMock() mock_deps.database_async_read_write_engine = mock_engine diff --git a/uv.lock b/uv.lock index ec8d4d6d..82187257 100644 --- a/uv.lock +++ b/uv.lock @@ -138,7 +138,7 @@ requires-dist = [ { name = "opentelemetry-exporter-otlp", specifier = ">=1.28.0" }, { name = "opentelemetry-sdk", specifier = ">=1.28.0" }, { name = "psycopg2-binary", specifier = ">=2.9.9,<3" }, - { name = "pymongo", specifier = ">=4.11.2,<5" }, + { name = "pymongo", specifier = ">=4.13.0,<5" }, { name = "python-dotenv", specifier = ">=1.2.2,<2" }, { name = "python-multipart", specifier = ">=0.0.27" }, { name = "pyyaml", specifier = ">=6.0,<7" },