From 96d1dca612efee4fe3cba52311b92382a2b25772 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:14:26 -0600 Subject: [PATCH 01/36] Oneshot attempt at adding firestore support for memory and sessions --- pyproject.toml | 1 + src/google/adk/firestore_database_runner.py | 64 +++ .../adk/memory/firestore_memory_service.py | 172 ++++++++ .../adk/sessions/firestore_session_service.py | 404 ++++++++++++++++++ .../memory/test_firestore_memory_service.py | 96 +++++ .../test_firestore_session_service.py | 159 +++++++ 6 files changed, 896 insertions(+) create mode 100644 src/google/adk/firestore_database_runner.py create mode 100644 src/google/adk/memory/firestore_memory_service.py create mode 100644 src/google/adk/sessions/firestore_session_service.py create mode 100644 tests/unittests/memory/test_firestore_memory_service.py create mode 100644 tests/unittests/sessions/test_firestore_session_service.py diff --git a/pyproject.toml b/pyproject.toml index 2789bcf82a..426b6d1bbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,6 +157,7 @@ extensions = [ "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ "docker>=7.0.0", # For ContainerCodeExecutor + "google-cloud-firestore>=2.11.0", # For Firestore services "kubernetes>=29.0.0", # For GkeCodeExecutor "k8s-agent-sandbox>=0.1.1.post3", # For GkeCodeExecutor sandbox mode "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py new file mode 100644 index 0000000000..0ea7aa4f16 --- /dev/null +++ b/src/google/adk/firestore_database_runner.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING +from typing import Optional + +from .artifacts.gcs_artifact_service import GcsArtifactService +from .memory.firestore_memory_service import FirestoreMemoryService +from .runners import Runner +from .sessions.firestore_session_service import FirestoreSessionService + +if TYPE_CHECKING: + from .agents.base_agent import BaseAgent + + +def create_firestore_runner( + agent: BaseAgent, + gcs_bucket_name: Optional[str] = None, + firestore_root_collection: Optional[str] = None, +) -> Runner: + """Creates a Runner configured with Firestore and GCS services. + + Args: + agent: The root agent to run. + gcs_bucket_name: The GCS bucket name for artifacts. + firestore_root_collection: The root collection name for Firestore. + + Returns: + A Runner instance configured with Firestore services. + """ + # GcsArtifactService might require bucket name in constructor or read from env. + # Let's assume it reads from env or takes it. + # If we pass it, we might need to check its signature. + # Let's assume it takes bucket_name if provided, or reads from env. + artifact_service = GcsArtifactService() + if gcs_bucket_name: + # If GcsArtifactService supports setting it, we set it. + # Or we can assume it reads from ADK_GCS_BUCKET_NAME env var. + pass + + session_service = FirestoreSessionService( + root_collection=firestore_root_collection + ) + memory_service = FirestoreMemoryService() + + return Runner( + agent=agent, + session_service=session_service, + artifact_service=artifact_service, + memory_service=memory_service, + ) diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py new file mode 100644 index 0000000000..57c0de645f --- /dev/null +++ b/src/google/adk/memory/firestore_memory_service.py @@ -0,0 +1,172 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import os +import re +from typing import Any +from typing import Optional + +from google.cloud import firestore +from typing_extensions import override + +from ..events.event import Event +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if False: # TYPE_CHECKING + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_EVENTS_COLLECTION = "events" + +# Standard English stop words +DEFAULT_STOP_WORDS = { + "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", + "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", + "being", "have", "has", "had", "do", "does", "did", "can", "could", + "will", "would", "should", "shall", "may", "might", "must", "up", "down", + "out", "in", "over", "under", "again", "further", "then", "once", "here", + "there", "when", "where", "why", "how", "all", "any", "both", "each", + "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", + "own", "same", "so", "than", "too", "very", "i", "me", "my", "myself", + "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", + "yourselves", "he", "him", "his", "himself", "she", "her", "hers", + "herself", "it", "its", "itself", "they", "them", "their", "theirs", + "themselves", "what", "which", "who", "whom", "this", "that", "these", + "those", "am", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "having", "do", "does", "did", "doing", + "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", + "while", "of", "at", "by", "for", "with", "about", "against", "between", + "into", "through", "during", "before", "after", "above", "below", "to", + "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", + "further", "then", "once", "here", "there", "when", "where", "why", "how", + "all", "any", "both", "each", "few", "more", "most", "other", "some", + "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", + "very", "s", "t", "can", "will", "just", "don", "should", "now" +} + + +class FirestoreMemoryService(BaseMemoryService): + """Memory service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + events_collection: Optional[str] = None, + stop_words: Optional[set[str]] = None, + ): + """Initializes the Firestore memory service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + events_collection: The name of the events collection or collection group. + Defaults to 'events'. + stop_words: A set of words to ignore when extracting keywords. Defaults to + a standard English stop words list. + """ + self.client = client or firestore.AsyncClient() + self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION + self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS + + @override + async def add_session_to_memory(self, session: Session) -> None: + """No-op. Assumes events are written to Firestore by FirestoreSessionService.""" + pass + + def _extract_keywords(self, text: str) -> set[str]: + """Extracts keywords from text, ignoring stop words.""" + words = re.findall(r"[A-Za-z]+", text.lower()) + return {word for word in words if word not in self.stop_words} + + async def _search_by_keyword( + self, app_name: str, user_id: str, keyword: str + ) -> list[MemoryEntry]: + """Searches for events matching a single keyword.""" + # This requires a collection group index in Firestore for 'events' with + # appName == X, userId == Y, and keywords array-contains Z. + query = ( + self.client.collection_group(self.events_collection) + .where("appName", "==", app_name) + .where("userId", "==", user_id) + .where("keywords", "array_contains", keyword) + ) + + docs = await query.get() + entries = [] + for doc in docs: + data = doc.to_dict() + if data and "event_data" in data: + try: + event = Event.model_validate(data["event_data"]) + if event.content: + entries.append( + MemoryEntry( + content=event.content, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + ) + ) + except Exception as e: + logger.warning("Failed to parse event from Firestore: %s", e) + + return entries + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + """Searches memory for events matching the query.""" + keywords = self._extract_keywords(query) + if not keywords: + return SearchMemoryResponse() + + # Search for each keyword concurrently + tasks = [ + self._search_by_keyword(app_name, user_id, keyword) + for keyword in keywords + ] + results = await asyncio.gather(*tasks) + + # Merge results and deduplicate by MemoryEntry content/author/timestamp + # (MemoryEntry is not hashable by default if it contains complex objects, + # so we might need to deduplicate by id if available, or by content string). + # Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry + # unless we add it. The Java code use custom hash/equals for MemoryEntry. + # In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json() + # or by a custom key. + seen = set() + memories = [] + for result_list in results: + for entry in result_list: + # Deduplicate by a key of (author, content_text) + # Content might be complex, so let's use its json representation or text + content_text = "" + if entry.content and entry.content.parts: + content_text = " ".join( + [part.text for part in entry.content.parts if part.text] + ) + key = (entry.author, content_text, entry.timestamp) + if key not in seen: + seen.add(key) + memories.append(entry) + + return SearchMemoryResponse(memories=memories) diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py new file mode 100644 index 0000000000..b6ad01d4e4 --- /dev/null +++ b/src/google/adk/sessions/firestore_session_service.py @@ -0,0 +1,404 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +from typing import Any +from typing import Optional + +from google.cloud import firestore +from pydantic import BaseModel + +from ..events.event import Event +from .base_session_service import BaseSessionService +from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsResponse +from .session import Session + + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_ROOT_COLLECTION = "adk-session" +DEFAULT_SESSIONS_COLLECTION = "sessions" +DEFAULT_EVENTS_COLLECTION = "events" +DEFAULT_APP_STATE_COLLECTION = "app_states" +DEFAULT_USER_STATE_COLLECTION = "user_states" + + +class FirestoreSessionService(BaseSessionService): + """Session service that uses Google Cloud Firestore as the backend.""" + + def __init__( + self, + client: Optional[firestore.AsyncClient] = None, + root_collection: Optional[str] = None, + ): + """Initializes the Firestore session service. + + Args: + client: An optional Firestore AsyncClient. If not provided, a new one + will be created. + root_collection: The root collection name. Defaults to 'adk-session' or + the value of ADK_FIRESTORE_ROOT_COLLECTION env var. + """ + self.client = client or firestore.AsyncClient() + self.root_collection = ( + root_collection + or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") + or DEFAULT_ROOT_COLLECTION + ) + self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + self.events_collection = DEFAULT_EVENTS_COLLECTION + self.app_state_collection = DEFAULT_APP_STATE_COLLECTION + self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + + def _get_sessions_ref(self, user_id: str) -> firestore.AsyncCollectionReference: + return ( + self.client.collection(self.root_collection) + .document(user_id) + .collection(self.sessions_collection) + ) + + async def create_session( + self, + *, + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> Session: + """Creates a new session in Firestore.""" + if not session_id: + from google.adk.platform import uuid as platform_uuid + session_id = platform_uuid.new_uuid() + + initial_state = state or {} + now = firestore.SERVER_TIMESTAMP + + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Check if session already exists + doc = await session_ref.get() + if doc.exists: + from ..errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") + + session_data = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": initial_state, + "createTime": now, + "updateTime": now, + } + + await session_ref.set(session_data) + + # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is + # evaluated on the server, we might want to use local time for the object + # or read it back. Reading it back is expensive. We'll use local time for + # the object, but the DB will have SERVER_TIMESTAMP. + from datetime import datetime + from datetime import timezone + local_now = datetime.now(timezone.utc).timestamp() + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=initial_state, + events=[], + last_update_time=local_now, + ) + + async def get_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: Optional[GetSessionConfig] = None, + ) -> Optional[Session]: + """Gets a session from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + doc = await session_ref.get() + + if not doc.exists: + return None + + data = doc.to_dict() + if not data: + return None + + # Fetch events + events_ref = session_ref.collection(self.events_collection) + query = events_ref.order_by("timestamp") + + if config: + if config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + query = query.where("timestamp", ">=", after_dt) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + + events_docs = await query.get() + events = [] + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + # The Java code serializes individual fields, but Python schema/v1 uses + # JSON serialization of the whole event. We'll stick to Pythonic JSON + # serialization (event.model_dump) for consistency with Python ADK. + ed = event_data["event_data"] + # Restore timestamp if needed, or assume it's in event_data + events.append(Event.model_validate(ed)) + + # Fetch states (app and user) if we want to merge them, similar to + # DatabaseSessionService. The Java code seems to merge them in listSessions + # but let's see if getSession does it. + # In Java, getSession fetches app/user state if needed? The Java code I read: + # It didn't seem to fetch app/user state in getSession, only in appendEvent + # where it updates them, and listSessions where it mergers. + # Wait, let's re-read Java getSession. + # It doesn't seem to fetch app/user state in getSession either? + # Actually, in Java `FirestoreSessionService.java` `getSession`: + # It reads the session doc, then reads events. It doesn't seem to read + # app/user state docs. + # But `DatabaseSessionService` in Python DOES read them in `get_session`. + # Let's align with Python `DatabaseSessionService` if possible, as it's the + # standard in Python ADK. + # Python `DatabaseSessionService` reads `StorageAppState` and `StorageUserState` + # and merges them. + # If I want to be consistent with Python ADK, I should probably do it. + # But if I want to be consistent with Java ADK port, I should follow Java. + # The user asked to "Port this firestore support over to ADK Python". + # I should follow the Java logic but make it Pythonic. + # The Java logic doesn't seem to merge app/user state in `getSession`, it + # just returns session state. + # Wait, let's check Java `listSessions`. It read `StorageAppState`? No, it + # just read sessions. + # Let's stick to the Java logic if it works, or adapt to Python if it's better. + # Since `DatabaseSessionService` in Python merges them, maybe it's a newer + # feature in Python ADK that Java doesn't have or does differently. + # Let's check `FirestoreSessionService.java` again. + # In Java `listSessions`, it doesn't seem to fetch app/user state. + # In Java `appendEvent`, it updates app/user state if `state_delta` has + # `_app_` or `_user_` prefixes. + # Let's stick to the Java behavior unless it conflicts with Python interfaces. + # The Python `BaseSessionService` doesn't enforce merging, it just defines + # the interface. `DatabaseSessionService` implements merging. + # I'll stick to the Java behavior (no merging in get/list, only update in append) + # for now, as it's a port of Java. Or I can implement merging if it's easy. + # Let's look at Java `appendEvent`: + # It checks `_app_` and `_user_` prefixes in `state_delta` and updates + # separate collections! + # ```java + # firestore.collection(APP_STATE_COLLECTION).document(appName).set(...) + # ``` + # So it DOES use separate collections for app/user state. + # If it uses them, it should probably read them somewhere. In Java, it seems + # it might not read them in `getSession`? Wait, let's check `FirestoreSessionService.java` + # again. I see `listSessions` doesn't read them. `getSession` doesn't read them. + # That might be a bug or partial implementation in Java? Or maybe they are + # read elsewhere? + # In Python `DatabaseSessionService` reads them in `get_session` and `list_sessions`. + # Let's implement reading them in Python `FirestoreSessionService` to be + # consistent with Python ADK standards if possible, or at least support it. + # I'll implement it without merging first to match Java, then see if I should + # add it. The Java code didn't do it. + + # Let's continue getting session. + session_state = data.get("state", {}) + + # Convert timestamp + update_time = data.get("updateTime") + last_update_time = 0.0 + if update_time: + # If it's a datetime object (Firestore might return it) + if isinstance(update_time, datetime): + last_update_time = update_time.timestamp() + else: + # Assuming it's a string or float + try: + last_update_time = float(update_time) + except (ValueError, TypeError): + pass + + return Session( + id=session_id, + app_name=app_name, + user_id=user_id, + state=session_state, + events=events, + last_update_time=last_update_time, + ) + + async def list_sessions( + self, *, app_name: str, user_id: Optional[str] = None + ) -> ListSessionsResponse: + """Lists sessions from Firestore.""" + # If user_id is provided, we can list directly. + # If not, we might need a collection group query or list all users first. + # Java listSessions takes appName and userId. It always scopes to user. + # Python list_sessions has user_id optional. + # If user_id is None, we should list all sessions for the app across all users. + # This requires a collection group query on 'sessions'. + if user_id: + query = self._get_sessions_ref(user_id).where("appName", "==", app_name) + docs = await query.get() + else: + # Collection group query + query = self.client.collection_group(self.sessions_collection).where( + "appName", "==", app_name + ) + docs = await query.get() + + sessions = [] + for doc in docs: + data = doc.to_dict() + if data: + # Session state is empty for listing as per in_memory + sessions.append( + Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state={}, # Empty state for listing + events=[], # Empty events for listing + last_update_time=0.0, # Or parse from updateTime + ) + ) + + return ListSessionsResponse(sessions=sessions) + + async def delete_session( + self, *, app_name: str, user_id: str, session_id: str + ) -> None: + """Deletes a session and its events from Firestore.""" + session_ref = self._get_sessions_ref(user_id).document(session_id) + + # Delete events subcollection first (Firestore requires manual subcollection deletion) + events_ref = session_ref.collection(self.events_collection) + events_docs = await events_ref.get() + + # Batch delete + batch = self.client.batch() + for event_doc in events_docs: + batch.delete(event_doc.reference) + await batch.commit() + + # Delete session doc + await session_ref.delete() + + async def append_event(self, session: Session, event: Event) -> Event: + """Appends an event to a session in Firestore.""" + if event.partial: + return event + + # Apply temp state to in-memory session (from base class) + self._apply_temp_state(session, event) + event = self._trim_temp_delta_state(event) + + session_ref = self._get_sessions_ref(session.user_id).document(session.id) + + # Handle state deltas (app and user state) + if event.actions and event.actions.state_delta: + state_delta = event.actions.state_delta + app_updates = {} + user_updates = {} + session_updates = {} + + for key, value in state_delta.items(): + if key.startswith("_app_"): + app_updates[key[len("_app_"):]] = value + elif key.startswith("_user_"): + user_updates[key[len("_user_"):]] = value + else: + session_updates[key] = value + + + # Update session doc with new state and updateTime + # We'll do it outside the batch or inside if we can. + # Let's use batch for everything to be atomic. + # Wait, I didn't add session_ref to batch yet. + # Let's create a batch. + batch = self.client.batch() + + if app_updates: + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) + batch.set(app_ref, app_updates, merge=True) + + if user_updates: + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) + batch.set(user_ref, user_updates, merge=True) + + # Update session state in-memory first + for k, v in session_updates.items(): + session.state[k] = v + + # Update session doc + batch.update( + session_ref, + { + "state": session.state, + "updateTime": firestore.SERVER_TIMESTAMP, + }, + ) + + # Add event + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document(event_id) + # Store event data as JSON serialized string or dict + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + + await batch.commit() + else: + # No state delta, just add event and update session timestamp + batch = self.client.batch() + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document(event_id) + event_data = event.model_dump(exclude_none=True, mode="json") + batch.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) + batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP}) + await batch.commit() + + # Also update the in-memory session (adds event to list) + await super().append_event(session, event) + return event diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py new file mode 100644 index 0000000000..b41e14fb94 --- /dev/null +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -0,0 +1,96 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.events.event import Event +from google.adk.memory.firestore_memory_service import FirestoreMemoryService +from google.genai import types + + +@pytest.fixture +def mock_firestore_client(): + client = mock.AsyncMock() + collection_ref = mock.AsyncMock() + client.collection_group.return_value = collection_ref + collection_ref.where.return_value = collection_ref + + # Mock get() for documents + doc_snapshot = mock.AsyncMock() + doc_snapshot.to_dict.return_value = {} + collection_ref.get.return_value = [doc_snapshot] + + return client + + +def test_extract_keywords(): + service = FirestoreMemoryService() + text = "The quick brown fox jumps over the lazy dog." + keywords = service._extract_keywords(text) + + # Check that stopwords like "the", "over" are removed + assert "the" not in keywords + assert "over" not in keywords + assert "quick" in keywords + assert "brown" in keywords + assert "fox" in keywords + assert "jumps" in keywords + assert "lazy" in keywords + assert "dog" in keywords + + +@pytest.mark.asyncio +async def test_search_memory_empty_query(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="" + ) + assert not response.memories + mock_firestore_client.collection_group.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_with_results(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + # Mock document snapshot to return event data + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + event = Event( + invocation_id="test_inv", + author="user", + content=types.Content(parts=[types.Part(text="quick fox jumps")]), + ) + doc_snapshot.to_dict.return_value = { + "event_data": event.model_dump(exclude_none=True, mode="json") + } + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + # Verify Firestore calls + mock_firestore_client.collection_group.assert_called_with("events") + collection_ref = mock_firestore_client.collection_group.return_value + # Verify where calls (order might vary, so we just check it was called or check the chain) + collection_ref.where.assert_called() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py new file mode 100644 index 0000000000..ec2af4ad04 --- /dev/null +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -0,0 +1,159 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.events.event import Event +from google.adk.sessions.firestore_session_service import FirestoreSessionService + + +@pytest.fixture +def mock_firestore_client(): + client = mock.AsyncMock() + # Mock collection and document references + collection_ref = mock.AsyncMock() + doc_ref = mock.AsyncMock() + subcollection_ref = mock.AsyncMock() + subdoc_ref = mock.AsyncMock() + + client.collection.return_value = collection_ref + collection_ref.document.return_value = doc_ref + doc_ref.collection.return_value = subcollection_ref + subcollection_ref.document.return_value = subdoc_ref + + # Mock get() for documents + doc_snapshot = mock.AsyncMock() + doc_snapshot.exists = False + doc_snapshot.to_dict.return_value = {} + doc_ref.get.return_value = doc_snapshot + subdoc_ref.get.return_value = doc_snapshot + + # Mock collection group + client.collection_group.return_value = collection_ref + + return client + + +@pytest.mark.asyncio +async def test_create_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session = await service.create_session(app_name=app_name, user_id=user_id) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id + + # Verify Firestore calls + mock_firestore_client.collection.assert_called_once_with("adk-session") + collection_ref = mock_firestore_client.collection.return_value + collection_ref.document.assert_called_once_with(user_id) + doc_ref = collection_ref.document.return_value + doc_ref.collection.assert_called_once_with("sessions") + sessions_ref = doc_ref.collection.return_value + sessions_ref.document.assert_called_once_with(session.id) + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_session_not_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + +@pytest.mark.asyncio +async def test_get_session_found(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock document snapshot to return data + doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + "state": {"key": "value"}, + "updateTime": 1234567890.0, + } + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.id == session_id + assert session.state == {"key": "value"} + + +@pytest.mark.asyncio +async def test_delete_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Mock events subcollection + events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + event_doc = mock.AsyncMock() + events_ref.get.return_value = [event_doc] + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Verify events deletion + events_ref.get.assert_called_once() + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.delete.assert_called_once_with(event_doc.reference) + batch.commit.assert_called_once() + + # Verify session deletion + session_doc_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + session_doc_ref.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + session = Session(id="test_session", app_name=app_name, user_id=user_id) + event = Event(invocation_id="test_inv", author="user") + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + batch.set.assert_called_once() # For event + batch.update.assert_called_once() # For session updateTime + batch.commit.assert_called_once() From d2d223168e45384488f983b2bec35b4319dc90b4 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:30:59 -0600 Subject: [PATCH 02/36] Formatting and fixing the bucket name handling --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/firestore_database_runner.py | 19 +- .../adk/memory/firestore_memory_service.py | 228 ++++++++++++++++-- .../adk/sessions/firestore_session_service.py | 21 +- .../memory/test_firestore_memory_service.py | 6 +- .../test_firestore_session_service.py | 19 +- 7 files changed, 244 insertions(+), 51 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py index 0ea7aa4f16..b3abbd45b0 100644 --- a/src/google/adk/firestore_database_runner.py +++ b/src/google/adk/firestore_database_runner.py @@ -14,8 +14,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import os from typing import Optional +from typing import TYPE_CHECKING from .artifacts.gcs_artifact_service import GcsArtifactService from .memory.firestore_memory_service import FirestoreMemoryService @@ -41,15 +42,13 @@ def create_firestore_runner( Returns: A Runner instance configured with Firestore services. """ - # GcsArtifactService might require bucket name in constructor or read from env. - # Let's assume it reads from env or takes it. - # If we pass it, we might need to check its signature. - # Let's assume it takes bucket_name if provided, or reads from env. - artifact_service = GcsArtifactService() - if gcs_bucket_name: - # If GcsArtifactService supports setting it, we set it. - # Or we can assume it reads from ADK_GCS_BUCKET_NAME env var. - pass + bucket_name = gcs_bucket_name or os.environ.get("ADK_GCS_BUCKET_NAME") + if not bucket_name: + raise ValueError( + "Required property 'ADK_GCS_BUCKET_NAME' is not set. This" + " is needed for the GcsArtifactService." + ) + artifact_service = GcsArtifactService(bucket_name=bucket_name) session_service = FirestoreSessionService( root_collection=firestore_root_collection diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py index 57c0de645f..97ade10b89 100644 --- a/src/google/adk/memory/firestore_memory_service.py +++ b/src/google/adk/memory/firestore_memory_service.py @@ -24,8 +24,8 @@ from google.cloud import firestore from typing_extensions import override -from ..events.event import Event from . import _utils +from ..events.event import Event from .base_memory_service import BaseMemoryService from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry @@ -39,28 +39,206 @@ # Standard English stop words DEFAULT_STOP_WORDS = { - "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", - "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", - "being", "have", "has", "had", "do", "does", "did", "can", "could", - "will", "would", "should", "shall", "may", "might", "must", "up", "down", - "out", "in", "over", "under", "again", "further", "then", "once", "here", - "there", "when", "where", "why", "how", "all", "any", "both", "each", - "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", - "own", "same", "so", "than", "too", "very", "i", "me", "my", "myself", - "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", - "yourselves", "he", "him", "his", "himself", "she", "her", "hers", - "herself", "it", "its", "itself", "they", "them", "their", "theirs", - "themselves", "what", "which", "who", "whom", "this", "that", "these", - "those", "am", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "having", "do", "does", "did", "doing", - "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", - "while", "of", "at", "by", "for", "with", "about", "against", "between", - "into", "through", "during", "before", "after", "above", "below", "to", - "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", - "further", "then", "once", "here", "there", "when", "where", "why", "how", - "all", "any", "both", "each", "few", "more", "most", "other", "some", - "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", - "very", "s", "t", "can", "will", "just", "don", "should", "now" + "a", + "an", + "the", + "and", + "or", + "but", + "if", + "then", + "else", + "to", + "of", + "in", + "on", + "for", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "can", + "could", + "will", + "would", + "should", + "shall", + "may", + "might", + "must", + "up", + "down", + "out", + "in", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "i", + "me", + "my", + "myself", + "we", + "our", + "ours", + "ourselves", + "you", + "your", + "yours", + "yourself", + "yourselves", + "he", + "him", + "his", + "himself", + "she", + "her", + "hers", + "herself", + "it", + "its", + "itself", + "they", + "them", + "their", + "theirs", + "themselves", + "what", + "which", + "who", + "whom", + "this", + "that", + "these", + "those", + "am", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "having", + "do", + "does", + "did", + "doing", + "a", + "an", + "the", + "and", + "but", + "if", + "or", + "because", + "as", + "until", + "while", + "of", + "at", + "by", + "for", + "with", + "about", + "against", + "between", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "to", + "from", + "up", + "down", + "in", + "out", + "on", + "off", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "s", + "t", + "can", + "will", + "just", + "don", + "should", + "now", } @@ -85,7 +263,9 @@ def __init__( """ self.client = client or firestore.AsyncClient() self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION - self.stop_words = stop_words if stop_words is not None else DEFAULT_STOP_WORDS + self.stop_words = ( + stop_words if stop_words is not None else DEFAULT_STOP_WORDS + ) @override async def add_session_to_memory(self, session: Session) -> None: diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index b6ad01d4e4..1f1000ad11 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -28,7 +28,6 @@ from .base_session_service import ListSessionsResponse from .session import Session - logger = logging.getLogger("google_adk." + __name__) DEFAULT_ROOT_COLLECTION = "adk-session" @@ -65,7 +64,9 @@ def __init__( self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION - def _get_sessions_ref(self, user_id: str) -> firestore.AsyncCollectionReference: + def _get_sessions_ref( + self, user_id: str + ) -> firestore.AsyncCollectionReference: return ( self.client.collection(self.root_collection) .document(user_id) @@ -83,6 +84,7 @@ async def create_session( """Creates a new session in Firestore.""" if not session_id: from google.adk.platform import uuid as platform_uuid + session_id = platform_uuid.new_uuid() initial_state = state or {} @@ -94,6 +96,7 @@ async def create_session( doc = await session_ref.get() if doc.exists: from ..errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") session_data = { @@ -113,6 +116,7 @@ async def create_session( # the object, but the DB will have SERVER_TIMESTAMP. from datetime import datetime from datetime import timezone + local_now = datetime.now(timezone.utc).timestamp() return Session( @@ -323,13 +327,12 @@ async def append_event(self, session: Session, event: Event) -> Event: for key, value in state_delta.items(): if key.startswith("_app_"): - app_updates[key[len("_app_"):]] = value + app_updates[key[len("_app_") :]] = value elif key.startswith("_user_"): - user_updates[key[len("_user_"):]] = value + user_updates[key[len("_user_") :]] = value else: session_updates[key] = value - # Update session doc with new state and updateTime # We'll do it outside the batch or inside if we can. # Let's use batch for everything to be atomic. @@ -367,7 +370,9 @@ async def append_event(self, session: Session, event: Event) -> Event: # Add event event_id = event.id - event_ref = session_ref.collection(self.events_collection).document(event_id) + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) # Store event data as JSON serialized string or dict event_data = event.model_dump(exclude_none=True, mode="json") batch.set( @@ -385,7 +390,9 @@ async def append_event(self, session: Session, event: Event) -> Event: # No state delta, just add event and update session timestamp batch = self.client.batch() event_id = event.id - event_ref = session_ref.collection(self.events_collection).document(event_id) + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) event_data = event.model_dump(exclude_none=True, mode="json") batch.set( event_ref, diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index b41e14fb94..6cd878f336 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -16,10 +16,10 @@ from unittest import mock -import pytest from google.adk.events.event import Event from google.adk.memory.firestore_memory_service import FirestoreMemoryService from google.genai import types +import pytest @pytest.fixture @@ -71,7 +71,9 @@ async def test_search_memory_with_results(mock_firestore_client): query = "quick fox" # Mock document snapshot to return event data - doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] event = Event( invocation_id="test_inv", author="user", diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index ec2af4ad04..485ef668d1 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -16,9 +16,9 @@ from unittest import mock -import pytest from google.adk.events.event import Event from google.adk.sessions.firestore_session_service import FirestoreSessionService +import pytest @pytest.fixture @@ -94,7 +94,9 @@ async def test_get_session_found(mock_firestore_client): session_id = "test_session" # Mock document snapshot to return data - doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) doc_snapshot.exists = True doc_snapshot.to_dict.return_value = { "id": session_id, @@ -121,7 +123,9 @@ async def test_delete_session(mock_firestore_client): session_id = "test_session" # Mock events subcollection - events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) event_doc = mock.AsyncMock() events_ref.get.return_value = [event_doc] @@ -137,7 +141,9 @@ async def test_delete_session(mock_firestore_client): batch.commit.assert_called_once() # Verify session deletion - session_doc_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + session_doc_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + ) session_doc_ref.delete.assert_called_once() @@ -147,6 +153,7 @@ async def test_append_event(mock_firestore_client): app_name = "test_app" user_id = "test_user" from google.adk.sessions.session import Session + session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") @@ -154,6 +161,6 @@ async def test_append_event(mock_firestore_client): mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value - batch.set.assert_called_once() # For event - batch.update.assert_called_once() # For session updateTime + batch.set.assert_called_once() # For event + batch.update.assert_called_once() # For session updateTime batch.commit.assert_called_once() From 8312dc82c47c8931289932c514549f48096ba45d Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 13:50:25 -0600 Subject: [PATCH 03/36] Correct imports for firestore --- src/google/adk/memory/firestore_memory_service.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/memory/firestore_memory_service.py index 97ade10b89..ad9fb57992 100644 --- a/src/google/adk/memory/firestore_memory_service.py +++ b/src/google/adk/memory/firestore_memory_service.py @@ -20,8 +20,8 @@ import re from typing import Any from typing import Optional +from typing import TYPE_CHECKING -from google.cloud import firestore from typing_extensions import override from . import _utils @@ -30,7 +30,9 @@ from .base_memory_service import SearchMemoryResponse from .memory_entry import MemoryEntry -if False: # TYPE_CHECKING +if TYPE_CHECKING: + from google.cloud import firestore + from ..sessions.session import Session logger = logging.getLogger("google_adk." + __name__) @@ -261,7 +263,12 @@ def __init__( stop_words: A set of words to ignore when extracting keywords. Defaults to a standard English stop words list. """ - self.client = client or firestore.AsyncClient() + if client is None: + from google.cloud import firestore + + self.client = firestore.AsyncClient() + else: + self.client = client self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION self.stop_words = ( stop_words if stop_words is not None else DEFAULT_STOP_WORDS From 7760b70e2d340b3882e31443c9337a9977981aac Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:00:10 -0600 Subject: [PATCH 04/36] Add firestore to test dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 426b6d1bbf..931ca994f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ test = [ "a2a-sdk>=0.3.0,<0.4.0", "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ + "google-cloud-firestore>=2.11.0", "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent From b29afb73cfe2a5a5f15210e9ecc1ceadf3c874c7 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:12:25 -0600 Subject: [PATCH 05/36] Fix tests --- .../test_firestore_session_service.py | 47 ++++++++++++++ .../test_firestore_database_runner.py | 62 +++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 tests/unittests/test_firestore_database_runner.py diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 485ef668d1..6048097997 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -164,3 +164,50 @@ async def test_append_event(mock_firestore_client): batch.set.assert_called_once() # For event batch.update.assert_called_once() # For session updateTime batch.commit.assert_called_once() + + +@pytest.mark.asyncio +async def test_append_event_with_state_delta(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + # Using MagicMock for Event to bypass complex pydantic validation for test + event = mock.MagicMock() + event.partial = False + event.id = "test_event_id" + # Mock actions.state_delta + event.actions.state_delta = { + "_app_my_key": "app_val", + "_user_my_key": "user_val", + "session_key": "session_val", + } + # Mock model_dump to return valid event data + event.model_dump.return_value = {"id": "test_event_id", "author": "user"} + + await service.append_event(session, event) + + mock_firestore_client.batch.assert_called_once() + batch = mock_firestore_client.batch.return_value + + # Verify app state set + # In code: batch.set(app_ref, app_updates, merge=True) + # But app_ref is a mock! Which mock? + # It's mock_firestore_client.collection().document() + # In fixture: collection_ref = mock.AsyncMock() + # doc_ref = mock.AsyncMock() + # client.collection.return_value = collection_ref + # collection_ref.document.return_value = doc_ref + # So batch.set is called with app_ref (which is doc_ref) + batch.set.assert_called() + + # Verify session state updated in memory + assert session.state["session_key"] == "session_val" + + # Verify batch update was called for session + batch.update.assert_called_once() + + batch.commit.assert_called_once() diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py new file mode 100644 index 0000000000..a4e8fb1889 --- /dev/null +++ b/tests/unittests/test_firestore_database_runner.py @@ -0,0 +1,62 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest +from google.adk.agents.base_agent import BaseAgent +from google.adk.firestore_database_runner import create_firestore_runner + + +@pytest.fixture +def mock_agent(): + agent = mock.MagicMock(spec=BaseAgent) + agent.name = "test_agent" + return agent + + +def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + # Mock GcsArtifactService to avoid real client init + with mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs: + runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="test_bucket") + + +def test_create_firestore_runner_with_env(mock_agent, monkeypatch): + monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") + + with mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs: + runner = create_firestore_runner(mock_agent) + + assert runner is not None + mock_gcs.assert_called_once_with(bucket_name="env_bucket") + + +def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch): + monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) + + with pytest.raises( + ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set" + ): + create_firestore_runner(mock_agent) From 31ffb864422333f44a7db2a3918d07f1ad596569 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:37:07 -0600 Subject: [PATCH 06/36] Fix mypy errors --- src/google/adk/errors/already_exists_error.py | 2 +- .../adk/sessions/firestore_session_service.py | 59 +------------------ 2 files changed, 3 insertions(+), 58 deletions(-) diff --git a/src/google/adk/errors/already_exists_error.py b/src/google/adk/errors/already_exists_error.py index 8bd14f9ad6..bf8d357a81 100644 --- a/src/google/adk/errors/already_exists_error.py +++ b/src/google/adk/errors/already_exists_error.py @@ -18,7 +18,7 @@ class AlreadyExistsError(Exception): """Represents an error that occurs when an entity already exists.""" - def __init__(self, message="The resource already exists."): + def __init__(self, message: str = "The resource already exists."): """Initializes the AlreadyExistsError exception. Args: diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index 1f1000ad11..5ace2d94a9 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -16,6 +16,8 @@ import logging import os +from datetime import datetime +from datetime import timezone from typing import Any from typing import Optional @@ -114,9 +116,6 @@ async def create_session( # evaluated on the server, we might want to use local time for the object # or read it back. Reading it back is expensive. We'll use local time for # the object, but the DB will have SERVER_TIMESTAMP. - from datetime import datetime - from datetime import timezone - local_now = datetime.now(timezone.utc).timestamp() return Session( @@ -170,60 +169,6 @@ async def get_session( # Restore timestamp if needed, or assume it's in event_data events.append(Event.model_validate(ed)) - # Fetch states (app and user) if we want to merge them, similar to - # DatabaseSessionService. The Java code seems to merge them in listSessions - # but let's see if getSession does it. - # In Java, getSession fetches app/user state if needed? The Java code I read: - # It didn't seem to fetch app/user state in getSession, only in appendEvent - # where it updates them, and listSessions where it mergers. - # Wait, let's re-read Java getSession. - # It doesn't seem to fetch app/user state in getSession either? - # Actually, in Java `FirestoreSessionService.java` `getSession`: - # It reads the session doc, then reads events. It doesn't seem to read - # app/user state docs. - # But `DatabaseSessionService` in Python DOES read them in `get_session`. - # Let's align with Python `DatabaseSessionService` if possible, as it's the - # standard in Python ADK. - # Python `DatabaseSessionService` reads `StorageAppState` and `StorageUserState` - # and merges them. - # If I want to be consistent with Python ADK, I should probably do it. - # But if I want to be consistent with Java ADK port, I should follow Java. - # The user asked to "Port this firestore support over to ADK Python". - # I should follow the Java logic but make it Pythonic. - # The Java logic doesn't seem to merge app/user state in `getSession`, it - # just returns session state. - # Wait, let's check Java `listSessions`. It read `StorageAppState`? No, it - # just read sessions. - # Let's stick to the Java logic if it works, or adapt to Python if it's better. - # Since `DatabaseSessionService` in Python merges them, maybe it's a newer - # feature in Python ADK that Java doesn't have or does differently. - # Let's check `FirestoreSessionService.java` again. - # In Java `listSessions`, it doesn't seem to fetch app/user state. - # In Java `appendEvent`, it updates app/user state if `state_delta` has - # `_app_` or `_user_` prefixes. - # Let's stick to the Java behavior unless it conflicts with Python interfaces. - # The Python `BaseSessionService` doesn't enforce merging, it just defines - # the interface. `DatabaseSessionService` implements merging. - # I'll stick to the Java behavior (no merging in get/list, only update in append) - # for now, as it's a port of Java. Or I can implement merging if it's easy. - # Let's look at Java `appendEvent`: - # It checks `_app_` and `_user_` prefixes in `state_delta` and updates - # separate collections! - # ```java - # firestore.collection(APP_STATE_COLLECTION).document(appName).set(...) - # ``` - # So it DOES use separate collections for app/user state. - # If it uses them, it should probably read them somewhere. In Java, it seems - # it might not read them in `getSession`? Wait, let's check `FirestoreSessionService.java` - # again. I see `listSessions` doesn't read them. `getSession` doesn't read them. - # That might be a bug or partial implementation in Java? Or maybe they are - # read elsewhere? - # In Python `DatabaseSessionService` reads them in `get_session` and `list_sessions`. - # Let's implement reading them in Python `FirestoreSessionService` to be - # consistent with Python ADK standards if possible, or at least support it. - # I'll implement it without merging first to match Java, then see if I should - # add it. The Java code didn't do it. - # Let's continue getting session. session_state = data.get("state", {}) From 565cc616d37d1b3bc9aa3ef878227c001fbbc1f6 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:40:37 -0600 Subject: [PATCH 07/36] Undo unintended changes --- contributing/samples/gepa/experiment.py | 1 + contributing/samples/gepa/run_experiment.py | 1 + 2 files changed, 2 insertions(+) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2710c3894c..f3751206a8 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,6 +43,7 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib + import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index e31db15788..d857da9635 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,6 +25,7 @@ from absl import flags import experiment from google.genai import types + import utils _OUTPUT_DIR = flags.DEFINE_string( From 9387cb3e948875a885a4d90e54c80ff5f30ea90a Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 14:42:04 -0600 Subject: [PATCH 08/36] Sorting imports --- src/google/adk/sessions/firestore_session_service.py | 4 ++-- tests/unittests/test_firestore_database_runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/sessions/firestore_session_service.py index 5ace2d94a9..306bc4c87e 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/sessions/firestore_session_service.py @@ -14,10 +14,10 @@ from __future__ import annotations -import logging -import os from datetime import datetime from datetime import timezone +import logging +import os from typing import Any from typing import Optional diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py index a4e8fb1889..4b51ffb99a 100644 --- a/tests/unittests/test_firestore_database_runner.py +++ b/tests/unittests/test_firestore_database_runner.py @@ -16,9 +16,9 @@ from unittest import mock -import pytest from google.adk.agents.base_agent import BaseAgent from google.adk.firestore_database_runner import create_firestore_runner +import pytest @pytest.fixture From 03910e99e24de7dc376177f9448928c6ce0f48fa Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 15:29:27 -0600 Subject: [PATCH 09/36] Fix async mocks --- .../memory/test_firestore_memory_service.py | 11 +++++--- .../test_firestore_session_service.py | 26 ++++++++++++------- .../test_firestore_database_runner.py | 25 +++++++++++++----- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index 6cd878f336..d7497735aa 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -24,15 +24,18 @@ @pytest.fixture def mock_firestore_client(): - client = mock.AsyncMock() - collection_ref = mock.AsyncMock() + client = mock.MagicMock() + collection_ref = mock.MagicMock() client.collection_group.return_value = collection_ref + + # where() should return self (collection_ref) to allow chaining collection_ref.where.return_value = collection_ref # Mock get() for documents - doc_snapshot = mock.AsyncMock() + doc_snapshot = mock.MagicMock() doc_snapshot.to_dict.return_value = {} - collection_ref.get.return_value = [doc_snapshot] + + collection_ref.get = mock.AsyncMock(return_value=[doc_snapshot]) return client diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 6048097997..294b8114fe 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -23,12 +23,11 @@ @pytest.fixture def mock_firestore_client(): - client = mock.AsyncMock() - # Mock collection and document references - collection_ref = mock.AsyncMock() - doc_ref = mock.AsyncMock() - subcollection_ref = mock.AsyncMock() - subdoc_ref = mock.AsyncMock() + client = mock.MagicMock() + collection_ref = mock.MagicMock() + doc_ref = mock.MagicMock() + subcollection_ref = mock.MagicMock() + subdoc_ref = mock.MagicMock() client.collection.return_value = collection_ref collection_ref.document.return_value = doc_ref @@ -36,15 +35,24 @@ def mock_firestore_client(): subcollection_ref.document.return_value = subdoc_ref # Mock get() for documents - doc_snapshot = mock.AsyncMock() + doc_snapshot = mock.MagicMock() doc_snapshot.exists = False doc_snapshot.to_dict.return_value = {} - doc_ref.get.return_value = doc_snapshot - subdoc_ref.get.return_value = doc_snapshot + + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + # Mock subcollection get() (for events list in delete_session) + subcollection_ref.get = mock.AsyncMock(return_value=[]) # Mock collection group client.collection_group.return_value = collection_ref + # Mock batch + batch = mock.MagicMock() + client.batch.return_value = batch + batch.commit = mock.AsyncMock() + return client diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/test_firestore_database_runner.py index 4b51ffb99a..89c1a1f677 100644 --- a/tests/unittests/test_firestore_database_runner.py +++ b/tests/unittests/test_firestore_database_runner.py @@ -31,10 +31,15 @@ def mock_agent(): def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) - # Mock GcsArtifactService to avoid real client init - with mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" - ) as mock_gcs: + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") assert runner is not None @@ -44,9 +49,15 @@ def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): def test_create_firestore_runner_with_env(mock_agent, monkeypatch): monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") - with mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" - ) as mock_gcs: + with ( + mock.patch( + "google.adk.firestore_database_runner.FirestoreSessionService" + ), + mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.firestore_database_runner.GcsArtifactService" + ) as mock_gcs, + ): runner = create_firestore_runner(mock_agent) assert runner is not None From 49c7bf5a02ad6ed3e3d74ba8b21bbf31f02d1dc5 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Tue, 31 Mar 2026 16:50:08 -0600 Subject: [PATCH 10/36] Fixing tests again again again --- src/google/adk/firestore_database_runner.py | 1 + .../sessions/test_firestore_session_service.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/firestore_database_runner.py index b3abbd45b0..aa86dae022 100644 --- a/src/google/adk/firestore_database_runner.py +++ b/src/google/adk/firestore_database_runner.py @@ -56,6 +56,7 @@ def create_firestore_runner( memory_service = FirestoreMemoryService() return Runner( + app_name=agent.name, agent=agent, session_service=session_service, artifact_service=artifact_service, diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/sessions/test_firestore_session_service.py index 294b8114fe..21511a0dab 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/sessions/test_firestore_session_service.py @@ -42,8 +42,21 @@ def mock_firestore_client(): doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - # Mock subcollection get() (for events list in delete_session) + # Set methods used in create_session and delete_session to AsyncMock + subdoc_ref.set = mock.AsyncMock() + subdoc_ref.delete = mock.AsyncMock() + + # Mock events subcollection + events_collection_ref = mock.MagicMock() + subdoc_ref.collection.return_value = events_collection_ref + events_collection_ref.order_by.return_value = events_collection_ref + events_collection_ref.where.return_value = events_collection_ref + events_collection_ref.limit_to_last.return_value = events_collection_ref + events_collection_ref.get = mock.AsyncMock(return_value=[]) + + # Mock subcollection get() (for sessions listing) subcollection_ref.get = mock.AsyncMock(return_value=[]) + subcollection_ref.where.return_value = subcollection_ref # Mock collection group client.collection_group.return_value = collection_ref @@ -135,7 +148,7 @@ async def test_delete_session(mock_firestore_client): mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) event_doc = mock.AsyncMock() - events_ref.get.return_value = [event_doc] + events_ref.get = mock.AsyncMock(return_value=[event_doc]) await service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id From fbd16eba4d2e4f5d05c042e14c30e10bb433b239 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Wed, 1 Apr 2026 00:17:01 -0600 Subject: [PATCH 11/36] Empty commit From 5645fe8f4e29a78aa9ba10a911ba2c664e111515 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Wed, 1 Apr 2026 00:28:18 -0600 Subject: [PATCH 12/36] Fixing one more test to use the firestore mock client --- tests/unittests/memory/test_firestore_memory_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/memory/test_firestore_memory_service.py index d7497735aa..00b0099782 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/memory/test_firestore_memory_service.py @@ -40,8 +40,8 @@ def mock_firestore_client(): return client -def test_extract_keywords(): - service = FirestoreMemoryService() +def test_extract_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) text = "The quick brown fox jumps over the lazy dog." keywords = service._extract_keywords(text) From 9d8bb5df1d7fab8c75d02e96e8ab163b2967642f Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 09:18:53 -0700 Subject: [PATCH 13/36] Move firestore integration into integrations package --- .../adk/integrations/firestore/__init__.py | 17 ++ .../firestore}/firestore_database_runner.py | 10 +- .../firestore}/firestore_memory_service.py | 247 ++---------------- .../firestore}/firestore_session_service.py | 48 +--- .../test_firestore_database_runner.py | 14 +- .../test_firestore_memory_service.py | 8 +- .../test_firestore_session_service.py | 31 +-- 7 files changed, 70 insertions(+), 305 deletions(-) create mode 100644 src/google/adk/integrations/firestore/__init__.py rename src/google/adk/{ => integrations/firestore}/firestore_database_runner.py (86%) rename src/google/adk/{memory => integrations/firestore}/firestore_memory_service.py (55%) rename src/google/adk/{sessions => integrations/firestore}/firestore_session_service.py (81%) rename tests/unittests/{ => integrations/firestore}/test_firestore_database_runner.py (73%) rename tests/unittests/{memory => integrations/firestore}/test_firestore_memory_service.py (88%) rename tests/unittests/{sessions => integrations/firestore}/test_firestore_session_service.py (84%) diff --git a/src/google/adk/integrations/firestore/__init__.py b/src/google/adk/integrations/firestore/__init__.py new file mode 100644 index 0000000000..7c76d28c93 --- /dev/null +++ b/src/google/adk/integrations/firestore/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +"""Firestore integrations for ADK.""" diff --git a/src/google/adk/firestore_database_runner.py b/src/google/adk/integrations/firestore/firestore_database_runner.py similarity index 86% rename from src/google/adk/firestore_database_runner.py rename to src/google/adk/integrations/firestore/firestore_database_runner.py index aa86dae022..6d40c7feb5 100644 --- a/src/google/adk/firestore_database_runner.py +++ b/src/google/adk/integrations/firestore/firestore_database_runner.py @@ -18,13 +18,13 @@ from typing import Optional from typing import TYPE_CHECKING -from .artifacts.gcs_artifact_service import GcsArtifactService -from .memory.firestore_memory_service import FirestoreMemoryService -from .runners import Runner -from .sessions.firestore_session_service import FirestoreSessionService +from ...artifacts.gcs_artifact_service import GcsArtifactService +from ...runners import Runner +from .firestore_memory_service import FirestoreMemoryService +from .firestore_session_service import FirestoreSessionService if TYPE_CHECKING: - from .agents.base_agent import BaseAgent + from ...agents.base_agent import BaseAgent def create_firestore_runner( diff --git a/src/google/adk/memory/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py similarity index 55% rename from src/google/adk/memory/firestore_memory_service.py rename to src/google/adk/integrations/firestore/firestore_memory_service.py index ad9fb57992..71365a2864 100644 --- a/src/google/adk/memory/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -24,223 +24,44 @@ from typing_extensions import override -from . import _utils -from ..events.event import Event -from .base_memory_service import BaseMemoryService -from .base_memory_service import SearchMemoryResponse -from .memory_entry import MemoryEntry +from ...memory import _utils +from ...events.event import Event +from ...memory.base_memory_service import BaseMemoryService +from ...memory.base_memory_service import SearchMemoryResponse +from ...memory.memory_entry import MemoryEntry if TYPE_CHECKING: from google.cloud import firestore - from ..sessions.session import Session + from ...sessions.session import Session logger = logging.getLogger("google_adk." + __name__) DEFAULT_EVENTS_COLLECTION = "events" -# Standard English stop words DEFAULT_STOP_WORDS = { - "a", - "an", - "the", - "and", - "or", - "but", - "if", - "then", - "else", - "to", - "of", - "in", - "on", - "for", - "with", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "do", - "does", - "did", - "can", - "could", - "will", - "would", - "should", - "shall", - "may", - "might", - "must", - "up", - "down", - "out", - "in", - "over", - "under", - "again", - "further", - "then", - "once", - "here", - "there", - "when", - "where", - "why", - "how", - "all", - "any", - "both", - "each", - "few", - "more", - "most", - "other", - "some", - "such", - "no", - "nor", - "not", - "only", - "own", - "same", - "so", - "than", - "too", - "very", - "i", - "me", - "my", - "myself", - "we", - "our", - "ours", - "ourselves", - "you", - "your", - "yours", - "yourself", - "yourselves", - "he", - "him", - "his", - "himself", - "she", - "her", - "hers", - "herself", - "it", - "its", - "itself", - "they", - "them", - "their", - "theirs", - "themselves", - "what", - "which", - "who", - "whom", - "this", - "that", - "these", - "those", - "am", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "having", - "do", - "does", - "did", - "doing", - "a", - "an", - "the", - "and", - "but", - "if", - "or", - "because", - "as", - "until", - "while", - "of", - "at", - "by", - "for", - "with", - "about", - "against", - "between", - "into", - "through", - "during", - "before", - "after", - "above", - "below", - "to", - "from", - "up", - "down", - "in", - "out", - "on", - "off", - "over", - "under", - "again", - "further", - "then", - "once", - "here", - "there", - "when", - "where", - "why", - "how", - "all", - "any", - "both", - "each", - "few", - "more", - "most", - "other", - "some", - "such", - "no", - "nor", - "not", - "only", - "own", - "same", - "so", - "than", - "too", - "very", - "s", - "t", - "can", - "will", - "just", - "don", - "should", - "now", + "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", + "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", + "being", "have", "has", "had", "do", "does", "did", "can", "could", + "will", "would", "should", "shall", "may", "might", "must", "up", "down", + "out", "in", "over", "under", "again", "further", "then", "once", "here", + "there", "when", "where", "why", "how", "all", "any", "both", "each", + "few", "more", "most", "other", "some", "such", "no", "nor", "not", + "only", "own", "same", "so", "than", "too", "very", "i", "me", "my", + "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", + "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", + "hers", "herself", "it", "its", "itself", "they", "them", "their", + "theirs", "themselves", "what", "which", "who", "whom", "this", "that", + "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", + "the", "and", "but", "if", "or", "because", "as", "until", "while", "of", + "at", "by", "for", "with", "about", "against", "between", "into", + "through", "during", "before", "after", "above", "below", "to", "from", + "up", "down", "in", "out", "on", "off", "over", "under", "again", + "further", "then", "once", "here", "there", "when", "where", "why", "how", + "all", "any", "both", "each", "few", "more", "most", "other", "some", + "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", + "very", "s", "t", "can", "will", "just", "don", "should", "now", } @@ -288,8 +109,6 @@ async def _search_by_keyword( self, app_name: str, user_id: str, keyword: str ) -> list[MemoryEntry]: """Searches for events matching a single keyword.""" - # This requires a collection group index in Firestore for 'events' with - # appName == X, userId == Y, and keywords array-contains Z. query = ( self.client.collection_group(self.events_collection) .where("appName", "==", app_name) @@ -326,26 +145,16 @@ async def search_memory( if not keywords: return SearchMemoryResponse() - # Search for each keyword concurrently tasks = [ self._search_by_keyword(app_name, user_id, keyword) for keyword in keywords ] results = await asyncio.gather(*tasks) - # Merge results and deduplicate by MemoryEntry content/author/timestamp - # (MemoryEntry is not hashable by default if it contains complex objects, - # so we might need to deduplicate by id if available, or by content string). - # Since we convert Event to MemoryEntry, we don't have event.id in MemoryEntry - # unless we add it. The Java code use custom hash/equals for MemoryEntry. - # In Python, MemoryEntry is a Pydantic model. We can deduplicate by model_dump_json() - # or by a custom key. seen = set() memories = [] for result_list in results: for entry in result_list: - # Deduplicate by a key of (author, content_text) - # Content might be complex, so let's use its json representation or text content_text = "" if entry.content and entry.content.parts: content_text = " ".join( diff --git a/src/google/adk/sessions/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py similarity index 81% rename from src/google/adk/sessions/firestore_session_service.py rename to src/google/adk/integrations/firestore/firestore_session_service.py index 306bc4c87e..4a13146d4f 100644 --- a/src/google/adk/sessions/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -24,11 +24,11 @@ from google.cloud import firestore from pydantic import BaseModel -from ..events.event import Event -from .base_session_service import BaseSessionService -from .base_session_service import GetSessionConfig -from .base_session_service import ListSessionsResponse -from .session import Session +from ...events.event import Event +from ...sessions.base_session_service import BaseSessionService +from ...sessions.base_session_service import GetSessionConfig +from ...sessions.base_session_service import ListSessionsResponse +from ...sessions.session import Session logger = logging.getLogger("google_adk." + __name__) @@ -57,7 +57,7 @@ def __init__( """ self.client = client or firestore.AsyncClient() self.root_collection = ( - root_collection + root_collection or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") or DEFAULT_ROOT_COLLECTION ) @@ -85,7 +85,7 @@ async def create_session( ) -> Session: """Creates a new session in Firestore.""" if not session_id: - from google.adk.platform import uuid as platform_uuid + from ...platform import uuid as platform_uuid session_id = platform_uuid.new_uuid() @@ -97,7 +97,7 @@ async def create_session( # Check if session already exists doc = await session_ref.get() if doc.exists: - from ..errors.already_exists_error import AlreadyExistsError + from ...errors.already_exists_error import AlreadyExistsError raise AlreadyExistsError(f"Session {session_id} already exists.") @@ -162,11 +162,7 @@ async def get_session( for event_doc in events_docs: event_data = event_doc.to_dict() if event_data and "event_data" in event_data: - # The Java code serializes individual fields, but Python schema/v1 uses - # JSON serialization of the whole event. We'll stick to Pythonic JSON - # serialization (event.model_dump) for consistency with Python ADK. ed = event_data["event_data"] - # Restore timestamp if needed, or assume it's in event_data events.append(Event.model_validate(ed)) # Let's continue getting session. @@ -176,11 +172,9 @@ async def get_session( update_time = data.get("updateTime") last_update_time = 0.0 if update_time: - # If it's a datetime object (Firestore might return it) if isinstance(update_time, datetime): last_update_time = update_time.timestamp() else: - # Assuming it's a string or float try: last_update_time = float(update_time) except (ValueError, TypeError): @@ -199,17 +193,10 @@ async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None ) -> ListSessionsResponse: """Lists sessions from Firestore.""" - # If user_id is provided, we can list directly. - # If not, we might need a collection group query or list all users first. - # Java listSessions takes appName and userId. It always scopes to user. - # Python list_sessions has user_id optional. - # If user_id is None, we should list all sessions for the app across all users. - # This requires a collection group query on 'sessions'. if user_id: query = self._get_sessions_ref(user_id).where("appName", "==", app_name) docs = await query.get() else: - # Collection group query query = self.client.collection_group(self.sessions_collection).where( "appName", "==", app_name ) @@ -219,7 +206,6 @@ async def list_sessions( for doc in docs: data = doc.to_dict() if data: - # Session state is empty for listing as per in_memory sessions.append( Session( id=data["id"], @@ -227,7 +213,7 @@ async def list_sessions( user_id=data["userId"], state={}, # Empty state for listing events=[], # Empty events for listing - last_update_time=0.0, # Or parse from updateTime + last_update_time=0.0, ) ) @@ -239,17 +225,14 @@ async def delete_session( """Deletes a session and its events from Firestore.""" session_ref = self._get_sessions_ref(user_id).document(session_id) - # Delete events subcollection first (Firestore requires manual subcollection deletion) events_ref = session_ref.collection(self.events_collection) events_docs = await events_ref.get() - # Batch delete batch = self.client.batch() for event_doc in events_docs: batch.delete(event_doc.reference) await batch.commit() - # Delete session doc await session_ref.delete() async def append_event(self, session: Session, event: Event) -> Event: @@ -257,13 +240,11 @@ async def append_event(self, session: Session, event: Event) -> Event: if event.partial: return event - # Apply temp state to in-memory session (from base class) self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) session_ref = self._get_sessions_ref(session.user_id).document(session.id) - # Handle state deltas (app and user state) if event.actions and event.actions.state_delta: state_delta = event.actions.state_delta app_updates = {} @@ -278,11 +259,6 @@ async def append_event(self, session: Session, event: Event) -> Event: else: session_updates[key] = value - # Update session doc with new state and updateTime - # We'll do it outside the batch or inside if we can. - # Let's use batch for everything to be atomic. - # Wait, I didn't add session_ref to batch yet. - # Let's create a batch. batch = self.client.batch() if app_updates: @@ -300,11 +276,9 @@ async def append_event(self, session: Session, event: Event) -> Event: ) batch.set(user_ref, user_updates, merge=True) - # Update session state in-memory first for k, v in session_updates.items(): session.state[k] = v - # Update session doc batch.update( session_ref, { @@ -313,12 +287,10 @@ async def append_event(self, session: Session, event: Event) -> Event: }, ) - # Add event event_id = event.id event_ref = session_ref.collection(self.events_collection).document( event_id ) - # Store event data as JSON serialized string or dict event_data = event.model_dump(exclude_none=True, mode="json") batch.set( event_ref, @@ -332,7 +304,6 @@ async def append_event(self, session: Session, event: Event) -> Event: await batch.commit() else: - # No state delta, just add event and update session timestamp batch = self.client.batch() event_id = event.id event_ref = session_ref.collection(self.events_collection).document( @@ -351,6 +322,5 @@ async def append_event(self, session: Session, event: Event) -> Event: batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP}) await batch.commit() - # Also update the in-memory session (adds event to list) await super().append_event(session, event) return event diff --git a/tests/unittests/test_firestore_database_runner.py b/tests/unittests/integrations/firestore/test_firestore_database_runner.py similarity index 73% rename from tests/unittests/test_firestore_database_runner.py rename to tests/unittests/integrations/firestore/test_firestore_database_runner.py index 89c1a1f677..6f39fed287 100644 --- a/tests/unittests/test_firestore_database_runner.py +++ b/tests/unittests/integrations/firestore/test_firestore_database_runner.py @@ -17,7 +17,7 @@ from unittest import mock from google.adk.agents.base_agent import BaseAgent -from google.adk.firestore_database_runner import create_firestore_runner +from google.adk.integrations.firestore.firestore_database_runner import create_firestore_runner import pytest @@ -33,11 +33,11 @@ def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): with ( mock.patch( - "google.adk.firestore_database_runner.FirestoreSessionService" + "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" ), - mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" + "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" ) as mock_gcs, ): runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") @@ -51,11 +51,11 @@ def test_create_firestore_runner_with_env(mock_agent, monkeypatch): with ( mock.patch( - "google.adk.firestore_database_runner.FirestoreSessionService" + "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" ), - mock.patch("google.adk.firestore_database_runner.FirestoreMemoryService"), + mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), mock.patch( - "google.adk.firestore_database_runner.GcsArtifactService" + "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" ) as mock_gcs, ): runner = create_firestore_runner(mock_agent) diff --git a/tests/unittests/memory/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py similarity index 88% rename from tests/unittests/memory/test_firestore_memory_service.py rename to tests/unittests/integrations/firestore/test_firestore_memory_service.py index 00b0099782..1ec68c5247 100644 --- a/tests/unittests/memory/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -17,7 +17,7 @@ from unittest import mock from google.adk.events.event import Event -from google.adk.memory.firestore_memory_service import FirestoreMemoryService +from google.adk.integrations.firestore.firestore_memory_service import FirestoreMemoryService from google.genai import types import pytest @@ -28,10 +28,8 @@ def mock_firestore_client(): collection_ref = mock.MagicMock() client.collection_group.return_value = collection_ref - # where() should return self (collection_ref) to allow chaining collection_ref.where.return_value = collection_ref - # Mock get() for documents doc_snapshot = mock.MagicMock() doc_snapshot.to_dict.return_value = {} @@ -45,7 +43,6 @@ def test_extract_keywords(mock_firestore_client): text = "The quick brown fox jumps over the lazy dog." keywords = service._extract_keywords(text) - # Check that stopwords like "the", "over" are removed assert "the" not in keywords assert "over" not in keywords assert "quick" in keywords @@ -73,7 +70,6 @@ async def test_search_memory_with_results(mock_firestore_client): user_id = "test_user" query = "quick fox" - # Mock document snapshot to return event data doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ 0 ] @@ -94,8 +90,6 @@ async def test_search_memory_with_results(mock_firestore_client): assert len(response.memories) == 1 assert response.memories[0].author == "user" - # Verify Firestore calls mock_firestore_client.collection_group.assert_called_with("events") collection_ref = mock_firestore_client.collection_group.return_value - # Verify where calls (order might vary, so we just check it was called or check the chain) collection_ref.where.assert_called() diff --git a/tests/unittests/sessions/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py similarity index 84% rename from tests/unittests/sessions/test_firestore_session_service.py rename to tests/unittests/integrations/firestore/test_firestore_session_service.py index 21511a0dab..088c1fddfa 100644 --- a/tests/unittests/sessions/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -17,7 +17,7 @@ from unittest import mock from google.adk.events.event import Event -from google.adk.sessions.firestore_session_service import FirestoreSessionService +from google.adk.integrations.firestore.firestore_session_service import FirestoreSessionService import pytest @@ -34,7 +34,6 @@ def mock_firestore_client(): doc_ref.collection.return_value = subcollection_ref subcollection_ref.document.return_value = subdoc_ref - # Mock get() for documents doc_snapshot = mock.MagicMock() doc_snapshot.exists = False doc_snapshot.to_dict.return_value = {} @@ -42,11 +41,9 @@ def mock_firestore_client(): doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - # Set methods used in create_session and delete_session to AsyncMock subdoc_ref.set = mock.AsyncMock() subdoc_ref.delete = mock.AsyncMock() - # Mock events subcollection events_collection_ref = mock.MagicMock() subdoc_ref.collection.return_value = events_collection_ref events_collection_ref.order_by.return_value = events_collection_ref @@ -54,14 +51,11 @@ def mock_firestore_client(): events_collection_ref.limit_to_last.return_value = events_collection_ref events_collection_ref.get = mock.AsyncMock(return_value=[]) - # Mock subcollection get() (for sessions listing) subcollection_ref.get = mock.AsyncMock(return_value=[]) subcollection_ref.where.return_value = subcollection_ref - # Mock collection group client.collection_group.return_value = collection_ref - # Mock batch batch = mock.MagicMock() client.batch.return_value = batch batch.commit = mock.AsyncMock() @@ -81,7 +75,6 @@ async def test_create_session(mock_firestore_client): assert session.user_id == user_id assert session.id - # Verify Firestore calls mock_firestore_client.collection.assert_called_once_with("adk-session") collection_ref = mock_firestore_client.collection.return_value collection_ref.document.assert_called_once_with(user_id) @@ -114,7 +107,6 @@ async def test_get_session_found(mock_firestore_client): user_id = "test_user" session_id = "test_session" - # Mock document snapshot to return data doc_snapshot = ( mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value ) @@ -143,7 +135,6 @@ async def test_delete_session(mock_firestore_client): user_id = "test_user" session_id = "test_session" - # Mock events subcollection events_ref = ( mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) @@ -154,14 +145,12 @@ async def test_delete_session(mock_firestore_client): app_name=app_name, user_id=user_id, session_id=session_id ) - # Verify events deletion events_ref.get.assert_called_once() mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value batch.delete.assert_called_once_with(event_doc.reference) batch.commit.assert_called_once() - # Verify session deletion session_doc_ref = ( mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value ) @@ -182,8 +171,8 @@ async def test_append_event(mock_firestore_client): mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value - batch.set.assert_called_once() # For event - batch.update.assert_called_once() # For session updateTime + batch.set.assert_called_once() + batch.update.assert_called_once() batch.commit.assert_called_once() @@ -196,17 +185,14 @@ async def test_append_event_with_state_delta(mock_firestore_client): session = Session(id="test_session", app_name=app_name, user_id=user_id) - # Using MagicMock for Event to bypass complex pydantic validation for test event = mock.MagicMock() event.partial = False event.id = "test_event_id" - # Mock actions.state_delta event.actions.state_delta = { "_app_my_key": "app_val", "_user_my_key": "user_val", "session_key": "session_val", } - # Mock model_dump to return valid event data event.model_dump.return_value = {"id": "test_event_id", "author": "user"} await service.append_event(session, event) @@ -214,21 +200,10 @@ async def test_append_event_with_state_delta(mock_firestore_client): mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value - # Verify app state set - # In code: batch.set(app_ref, app_updates, merge=True) - # But app_ref is a mock! Which mock? - # It's mock_firestore_client.collection().document() - # In fixture: collection_ref = mock.AsyncMock() - # doc_ref = mock.AsyncMock() - # client.collection.return_value = collection_ref - # collection_ref.document.return_value = doc_ref - # So batch.set is called with app_ref (which is doc_ref) batch.set.assert_called() - # Verify session state updated in memory assert session.state["session_key"] == "session_val" - # Verify batch update was called for session batch.update.assert_called_once() batch.commit.assert_called_once() From bf47b5c96572964a44d2efcd7f1617e7e220a6e5 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 10:31:13 -0700 Subject: [PATCH 14/36] Updating session service to address various concerns These updates are derived from @anmolg1997 PR on the community repo: https://github.com/google/adk-python-community/pull/104 --- .../firestore/firestore_session_service.py | 135 +++++++++-- .../test_firestore_database_runner.py | 19 ++ .../test_firestore_memory_service.py | 72 ++++++ .../test_firestore_session_service.py | 226 +++++++++++++++++- 4 files changed, 430 insertions(+), 22 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 4a13146d4f..731d257d85 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -21,7 +21,6 @@ from typing import Any from typing import Optional -from google.cloud import firestore from pydantic import BaseModel from ...events.event import Event @@ -55,6 +54,14 @@ def __init__( root_collection: The root collection name. Defaults to 'adk-session' or the value of ADK_FIRESTORE_ROOT_COLLECTION env var. """ + try: + from google.cloud import firestore + except ImportError as e: + raise ImportError( + "FirestoreSessionService requires google-cloud-firestore. " + "Install it with: pip install google-cloud-firestore" + ) from e + self.client = client or firestore.AsyncClient() self.root_collection = ( root_collection @@ -66,6 +73,22 @@ def __init__( self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + @staticmethod + def _merge_state( + app_state: dict[str, Any], + user_state: dict[str, Any], + session_state: dict[str, Any], + ) -> dict[str, Any]: + """Merge app, user, and session states into a single state dictionary.""" + import copy + + merged_state = copy.deepcopy(session_state) + for key, value in app_state.items(): + merged_state["_app_" + key] = value + for key, value in user_state.items(): + merged_state["_user_" + key] = value + return merged_state + def _get_sessions_ref( self, user_id: str ) -> firestore.AsyncCollectionReference: @@ -84,6 +107,7 @@ async def create_session( session_id: Optional[str] = None, ) -> Session: """Creates a new session in Firestore.""" + from google.cloud import firestore if not session_id: from ...platform import uuid as platform_uuid @@ -202,17 +226,50 @@ async def list_sessions( ) docs = await query.get() + # Fetch shared state once + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + app_doc = await app_ref.get() + app_state = app_doc.to_dict() if app_doc.exists else {} + + user_states_map = {} + if user_id: + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + user_doc = await user_ref.get() + if user_doc.exists: + user_states_map[user_id] = user_doc.to_dict() + else: + users_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + ) + users_docs = await users_ref.get() + for u_doc in users_docs: + user_states_map[u_doc.id] = u_doc.to_dict() + sessions = [] for doc in docs: data = doc.to_dict() if data: + u_id = data["userId"] + s_state = data.get("state", {}) + u_state = user_states_map.get(u_id, {}) + merged = self._merge_state(app_state, u_state, s_state) + sessions.append( Session( id=data["id"], app_name=data["appName"], user_id=data["userId"], - state={}, # Empty state for listing - events=[], # Empty events for listing + state=merged, + events=[], last_update_time=0.0, ) ) @@ -226,17 +283,65 @@ async def delete_session( session_ref = self._get_sessions_ref(user_id).document(session_id) events_ref = session_ref.collection(self.events_collection) - events_docs = await events_ref.get() - + batch = self.client.batch() - for event_doc in events_docs: + count = 0 + async for event_doc in events_ref.stream(): batch.delete(event_doc.reference) - await batch.commit() + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 + if count > 0: + await batch.commit() await session_ref.delete() + async def _update_app_state_transactional( + self, app_name: str, delta: dict[str, Any] + ) -> dict[str, Any]: + """Atomically applies delta to app state inside a transaction.""" + from google.cloud import firestore + doc_ref = self.client.collection(self.app_state_collection).document(app_name) + + @firestore.async_transactional + async def _txn(transaction): + snap = await doc_ref.get(transaction=transaction) + current = snap.to_dict() if snap.exists else {} + current.update(delta) + transaction.set(doc_ref, current, merge=True) + return current + + transaction = self.client.transaction() + return await _txn(transaction) + + async def _update_user_state_transactional( + self, app_name: str, user_id: str, delta: dict[str, Any] + ) -> dict[str, Any]: + """Atomically applies delta to user state inside a transaction.""" + from google.cloud import firestore + doc_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + + @firestore.async_transactional + async def _txn(transaction): + snap = await doc_ref.get(transaction=transaction) + current = snap.to_dict() if snap.exists else {} + current.update(delta) + transaction.set(doc_ref, current, merge=True) + return current + + transaction = self.client.transaction() + return await _txn(transaction) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" + from google.cloud import firestore if event.partial: return event @@ -259,26 +364,16 @@ async def append_event(self, session: Session, event: Event) -> Event: else: session_updates[key] = value - batch = self.client.batch() - if app_updates: - app_ref = self.client.collection(self.app_state_collection).document( - session.app_name - ) - batch.set(app_ref, app_updates, merge=True) + await self._update_app_state_transactional(session.app_name, app_updates) if user_updates: - user_ref = ( - self.client.collection(self.user_state_collection) - .document(session.app_name) - .collection("users") - .document(session.user_id) - ) - batch.set(user_ref, user_updates, merge=True) + await self._update_user_state_transactional(session.app_name, session.user_id, user_updates) for k, v in session_updates.items(): session.state[k] = v + batch = self.client.batch() batch.update( session_ref, { diff --git a/tests/unittests/integrations/firestore/test_firestore_database_runner.py b/tests/unittests/integrations/firestore/test_firestore_database_runner.py index 6f39fed287..0d68c928de 100644 --- a/tests/unittests/integrations/firestore/test_firestore_database_runner.py +++ b/tests/unittests/integrations/firestore/test_firestore_database_runner.py @@ -71,3 +71,22 @@ def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch): ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set" ): create_firestore_runner(mock_agent) + + +def test_create_firestore_runner_with_root_collection(mock_agent, monkeypatch): + monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "test_bucket") + + with ( + mock.patch( + "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" + ) as mock_session, + mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), + mock.patch("google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService"), + ): + runner = create_firestore_runner( + mock_agent, firestore_root_collection="custom_collection" + ) + + assert runner is not None + mock_session.assert_called_once_with(root_collection="custom_collection") + diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index 1ec68c5247..8e28dcddbc 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -93,3 +93,75 @@ async def test_search_memory_with_results(mock_firestore_client): mock_firestore_client.collection_group.assert_called_with("events") collection_ref = mock_firestore_client.collection_group.return_value collection_ref.where.assert_called() + + +@pytest.mark.asyncio +async def test_search_memory_deduplication(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick fox" + + event1 = Event( + invocation_id="test_inv1", + author="user", + content=types.Content(parts=[types.Part(text="quick fox jumps")]), + timestamp=1234567890.0, + ) + event2 = Event( + invocation_id="test_inv2", + author="user", + content=types.Content(parts=[types.Part(text="quick fox jumps")]), + timestamp=1234567890.0, + ) + + doc_snapshot1 = mock.MagicMock() + doc_snapshot1.to_dict.return_value = { + "event_data": event1.model_dump(exclude_none=True, mode="json") + } + + doc_snapshot2 = mock.MagicMock() + doc_snapshot2.to_dict.return_value = { + "event_data": event2.model_dump(exclude_none=True, mode="json") + } + + get_mock = mock.AsyncMock(side_effect=[[doc_snapshot1], [doc_snapshot2]]) + + mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get = get_mock + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert response.memories + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + + +@pytest.mark.asyncio +async def test_search_memory_parsing_error(mock_firestore_client, caplog): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "quick" + + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + doc_snapshot.to_dict.return_value = {"event_data": "invalid_data"} + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert not response.memories + assert "Failed to parse event from Firestore" in caplog.text + + +@pytest.mark.asyncio +async def test_search_memory_only_stop_words(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + response = await service.search_memory( + app_name="test_app", user_id="test_user", query="the and or" + ) + assert not response.memories + mock_firestore_client.collection_group.assert_not_called() + diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 088c1fddfa..4c22feb307 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -139,13 +139,18 @@ async def test_delete_session(mock_firestore_client): mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) event_doc = mock.AsyncMock() - events_ref.get = mock.AsyncMock(return_value=[event_doc]) + + async def to_async_iter(iterable): + for item in iterable: + yield item + + events_ref.stream.return_value = to_async_iter([event_doc]) await service.delete_session( app_name=app_name, user_id=user_id, session_id=session_id ) - events_ref.get.assert_called_once() + events_ref.stream.assert_called_once() mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value batch.delete.assert_called_once_with(event_doc.reference) @@ -195,9 +200,19 @@ async def test_append_event_with_state_delta(mock_firestore_client): } event.model_dump.return_value = {"id": "test_event_id", "author": "user"} + service._update_app_state_transactional = mock.AsyncMock() + service._update_user_state_transactional = mock.AsyncMock() + await service.append_event(session, event) mock_firestore_client.batch.assert_called_once() + service._update_app_state_transactional.assert_called_once_with( + "test_app", {"my_key": "app_val"} + ) + service._update_user_state_transactional.assert_called_once_with( + "test_app", "test_user", {"my_key": "user_val"} + ) + batch = mock_firestore_client.batch.return_value batch.set.assert_called() @@ -207,3 +222,210 @@ async def test_append_event_with_state_delta(mock_firestore_client): batch.update.assert_called_once() batch.commit.assert_called_once() + + +@pytest.mark.asyncio +async def test_list_sessions_with_user_id(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": user_id, + "state": {"session_key": "session_val"}, + } + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + sessions_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + elif name == service.root_collection: + return sessions_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.exists = True + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + user_doc_in_sessions = mock.MagicMock() + sessions_coll.document.return_value = user_doc_in_sessions + sessions_subcoll = mock.MagicMock() + user_doc_in_sessions.collection.return_value = sessions_subcoll + sessions_query = mock.MagicMock() + sessions_subcoll.where.return_value = sessions_query + sessions_query.get = mock.AsyncMock(return_value=[session_doc]) + + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["session_key"] == "session_val" + assert session.state["_app_app_key"] == "app_val" + assert session.state["_user_user_key"] == "user_val" + + +@pytest.mark.asyncio +async def test_list_sessions_without_user_id(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": "user1", + "state": {"session_key": "session_val"}, + } + + mock_firestore_client.collection_group.return_value.where.return_value.get = mock.AsyncMock( + return_value=[session_doc] + ) + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.id = "user1" + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + users_coll.get = mock.AsyncMock(return_value=[user_doc]) + + response = await service.list_sessions(app_name=app_name) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["_app_app_key"] == "app_val" + assert session.state["_user_user_key"] == "user_val" + + +@pytest.mark.asyncio +async def test_create_session_already_exists(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot.exists = True + + from google.adk.errors.already_exists_error import AlreadyExistsError + + with pytest.raises(AlreadyExistsError): + await service.create_session( + app_name=app_name, user_id=user_id, session_id="existing_id" + ) + + +@pytest.mark.asyncio +async def test_get_session_with_config(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + } + + events_collection_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + + from google.adk.sessions.base_session_service import GetSessionConfig + + config = GetSessionConfig(after_timestamp=1234567890.0, num_recent_events=5) + + await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id, config=config + ) + + events_collection_ref.where.assert_called_once() + events_collection_ref.limit_to_last.assert_called_once_with(5) + + +@pytest.mark.asyncio +async def test_delete_session_batching(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + + dummy_docs = [mock.MagicMock() for _ in range(501)] + + async def to_async_iter(iterable): + for item in iterable: + yield item + + events_ref.stream.return_value = to_async_iter(dummy_docs) + + batch = mock_firestore_client.batch.return_value + + await service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert batch.commit.call_count == 2 + + +@pytest.mark.asyncio +async def test_append_event_partial(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + event = Event(invocation_id="test_inv", author="user", partial=True) + + result = await service.append_event(session, event) + + assert result == event + mock_firestore_client.batch.assert_not_called() + From 28a571efcb92fba8e2eb166eb2362913e9378ff5 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 10:36:08 -0700 Subject: [PATCH 15/36] pyink and isort --- .../firestore/firestore_memory_service.py | 224 ++++++++++++++++-- .../firestore/firestore_session_service.py | 20 +- .../test_firestore_database_runner.py | 17 +- .../test_firestore_memory_service.py | 9 +- .../test_firestore_session_service.py | 21 +- 5 files changed, 248 insertions(+), 43 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 71365a2864..0c32bfbac2 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -24,8 +24,8 @@ from typing_extensions import override -from ...memory import _utils from ...events.event import Event +from ...memory import _utils from ...memory.base_memory_service import BaseMemoryService from ...memory.base_memory_service import SearchMemoryResponse from ...memory.memory_entry import MemoryEntry @@ -40,28 +40,206 @@ DEFAULT_EVENTS_COLLECTION = "events" DEFAULT_STOP_WORDS = { - "a", "an", "the", "and", "or", "but", "if", "then", "else", "to", "of", - "in", "on", "for", "with", "is", "are", "was", "were", "be", "been", - "being", "have", "has", "had", "do", "does", "did", "can", "could", - "will", "would", "should", "shall", "may", "might", "must", "up", "down", - "out", "in", "over", "under", "again", "further", "then", "once", "here", - "there", "when", "where", "why", "how", "all", "any", "both", "each", - "few", "more", "most", "other", "some", "such", "no", "nor", "not", - "only", "own", "same", "so", "than", "too", "very", "i", "me", "my", - "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", - "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", - "hers", "herself", "it", "its", "itself", "they", "them", "their", - "theirs", "themselves", "what", "which", "who", "whom", "this", "that", - "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", - "the", "and", "but", "if", "or", "because", "as", "until", "while", "of", - "at", "by", "for", "with", "about", "against", "between", "into", - "through", "during", "before", "after", "above", "below", "to", "from", - "up", "down", "in", "out", "on", "off", "over", "under", "again", - "further", "then", "once", "here", "there", "when", "where", "why", "how", - "all", "any", "both", "each", "few", "more", "most", "other", "some", - "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", - "very", "s", "t", "can", "will", "just", "don", "should", "now", + "a", + "an", + "the", + "and", + "or", + "but", + "if", + "then", + "else", + "to", + "of", + "in", + "on", + "for", + "with", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "can", + "could", + "will", + "would", + "should", + "shall", + "may", + "might", + "must", + "up", + "down", + "out", + "in", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "i", + "me", + "my", + "myself", + "we", + "our", + "ours", + "ourselves", + "you", + "your", + "yours", + "yourself", + "yourselves", + "he", + "him", + "his", + "himself", + "she", + "her", + "hers", + "herself", + "it", + "its", + "itself", + "they", + "them", + "their", + "theirs", + "themselves", + "what", + "which", + "who", + "whom", + "this", + "that", + "these", + "those", + "am", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "having", + "do", + "does", + "did", + "doing", + "a", + "an", + "the", + "and", + "but", + "if", + "or", + "because", + "as", + "until", + "while", + "of", + "at", + "by", + "for", + "with", + "about", + "against", + "between", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "to", + "from", + "up", + "down", + "in", + "out", + "on", + "off", + "over", + "under", + "again", + "further", + "then", + "once", + "here", + "there", + "when", + "where", + "why", + "how", + "all", + "any", + "both", + "each", + "few", + "more", + "most", + "other", + "some", + "such", + "no", + "nor", + "not", + "only", + "own", + "same", + "so", + "than", + "too", + "very", + "s", + "t", + "can", + "will", + "just", + "don", + "should", + "now", } diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 731d257d85..079ee1147d 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -64,7 +64,7 @@ def __init__( self.client = client or firestore.AsyncClient() self.root_collection = ( - root_collection + root_collection or os.environ.get("ADK_FIRESTORE_ROOT_COLLECTION") or DEFAULT_ROOT_COLLECTION ) @@ -108,6 +108,7 @@ async def create_session( ) -> Session: """Creates a new session in Firestore.""" from google.cloud import firestore + if not session_id: from ...platform import uuid as platform_uuid @@ -283,7 +284,7 @@ async def delete_session( session_ref = self._get_sessions_ref(user_id).document(session_id) events_ref = session_ref.collection(self.events_collection) - + batch = self.client.batch() count = 0 async for event_doc in events_ref.stream(): @@ -303,7 +304,10 @@ async def _update_app_state_transactional( ) -> dict[str, Any]: """Atomically applies delta to app state inside a transaction.""" from google.cloud import firestore - doc_ref = self.client.collection(self.app_state_collection).document(app_name) + + doc_ref = self.client.collection(self.app_state_collection).document( + app_name + ) @firestore.async_transactional async def _txn(transaction): @@ -321,6 +325,7 @@ async def _update_user_state_transactional( ) -> dict[str, Any]: """Atomically applies delta to user state inside a transaction.""" from google.cloud import firestore + doc_ref = ( self.client.collection(self.user_state_collection) .document(app_name) @@ -342,6 +347,7 @@ async def _txn(transaction): async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" from google.cloud import firestore + if event.partial: return event @@ -365,10 +371,14 @@ async def append_event(self, session: Session, event: Event) -> Event: session_updates[key] = value if app_updates: - await self._update_app_state_transactional(session.app_name, app_updates) + await self._update_app_state_transactional( + session.app_name, app_updates + ) if user_updates: - await self._update_user_state_transactional(session.app_name, session.user_id, user_updates) + await self._update_user_state_transactional( + session.app_name, session.user_id, user_updates + ) for k, v in session_updates.items(): session.state[k] = v diff --git a/tests/unittests/integrations/firestore/test_firestore_database_runner.py b/tests/unittests/integrations/firestore/test_firestore_database_runner.py index 0d68c928de..927a366965 100644 --- a/tests/unittests/integrations/firestore/test_firestore_database_runner.py +++ b/tests/unittests/integrations/firestore/test_firestore_database_runner.py @@ -35,7 +35,9 @@ def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): mock.patch( "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" ), - mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" + ), mock.patch( "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" ) as mock_gcs, @@ -53,7 +55,9 @@ def test_create_firestore_runner_with_env(mock_agent, monkeypatch): mock.patch( "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" ), - mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), + mock.patch( + "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" + ), mock.patch( "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" ) as mock_gcs, @@ -80,8 +84,12 @@ def test_create_firestore_runner_with_root_collection(mock_agent, monkeypatch): mock.patch( "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" ) as mock_session, - mock.patch("google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService"), - mock.patch("google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService"), + mock.patch( + "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" + ), + mock.patch( + "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" + ), ): runner = create_firestore_runner( mock_agent, firestore_root_collection="custom_collection" @@ -89,4 +97,3 @@ def test_create_firestore_runner_with_root_collection(mock_agent, monkeypatch): assert runner is not None mock_session.assert_called_once_with(root_collection="custom_collection") - diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index 8e28dcddbc..cf199b6a68 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -127,7 +127,9 @@ async def test_search_memory_deduplication(mock_firestore_client): get_mock = mock.AsyncMock(side_effect=[[doc_snapshot1], [doc_snapshot2]]) - mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get = get_mock + mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get = ( + get_mock + ) response = await service.search_memory( app_name=app_name, user_id=user_id, query=query @@ -145,7 +147,9 @@ async def test_search_memory_parsing_error(mock_firestore_client, caplog): user_id = "test_user" query = "quick" - doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[0] + doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + 0 + ] doc_snapshot.to_dict.return_value = {"event_data": "invalid_data"} response = await service.search_memory( @@ -164,4 +168,3 @@ async def test_search_memory_only_stop_words(mock_firestore_client): ) assert not response.memories mock_firestore_client.collection_group.assert_not_called() - diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 4c22feb307..61aa468716 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -302,8 +302,8 @@ async def test_list_sessions_without_user_id(mock_firestore_client): "state": {"session_key": "session_val"}, } - mock_firestore_client.collection_group.return_value.where.return_value.get = mock.AsyncMock( - return_value=[session_doc] + mock_firestore_client.collection_group.return_value.where.return_value.get = ( + mock.AsyncMock(return_value=[session_doc]) ) app_state_coll = mock.MagicMock() @@ -349,7 +349,9 @@ async def test_create_session_already_exists(mock_firestore_client): app_name = "test_app" user_id = "test_user" - doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) doc_snapshot.exists = True from google.adk.errors.already_exists_error import AlreadyExistsError @@ -367,7 +369,9 @@ async def test_get_session_with_config(mock_firestore_client): user_id = "test_user" session_id = "test_session" - doc_snapshot = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) doc_snapshot.exists = True doc_snapshot.to_dict.return_value = { "id": session_id, @@ -375,7 +379,9 @@ async def test_get_session_with_config(mock_firestore_client): "userId": user_id, } - events_collection_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + events_collection_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) from google.adk.sessions.base_session_service import GetSessionConfig @@ -396,7 +402,9 @@ async def test_delete_session_batching(mock_firestore_client): user_id = "test_user" session_id = "test_session" - events_ref = mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + events_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) dummy_docs = [mock.MagicMock() for _ in range(501)] @@ -428,4 +436,3 @@ async def test_append_event_partial(mock_firestore_client): assert result == event mock_firestore_client.batch.assert_not_called() - From 7d0b4cc2e2fa8660d1f92c016eed29b025b37085 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 10:44:23 -0700 Subject: [PATCH 16/36] Addressing mypy errors --- .../firestore/firestore_session_service.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 079ee1147d..538180d649 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -19,7 +19,12 @@ import logging import os from typing import Any +from typing import cast from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from google.cloud import firestore from pydantic import BaseModel @@ -38,7 +43,7 @@ DEFAULT_USER_STATE_COLLECTION = "user_states" -class FirestoreSessionService(BaseSessionService): +class FirestoreSessionService(BaseSessionService): # type: ignore[misc] """Session service that uses Google Cloud Firestore as the backend.""" def __init__( @@ -309,8 +314,8 @@ async def _update_app_state_transactional( app_name ) - @firestore.async_transactional - async def _txn(transaction): + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]: snap = await doc_ref.get(transaction=transaction) current = snap.to_dict() if snap.exists else {} current.update(delta) @@ -318,7 +323,7 @@ async def _txn(transaction): return current transaction = self.client.transaction() - return await _txn(transaction) + return cast(dict[str, Any], await _txn(transaction)) async def _update_user_state_transactional( self, app_name: str, user_id: str, delta: dict[str, Any] @@ -333,8 +338,8 @@ async def _update_user_state_transactional( .document(user_id) ) - @firestore.async_transactional - async def _txn(transaction): + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]: snap = await doc_ref.get(transaction=transaction) current = snap.to_dict() if snap.exists else {} current.update(delta) @@ -342,7 +347,7 @@ async def _txn(transaction): return current transaction = self.client.transaction() - return await _txn(transaction) + return cast(dict[str, Any], await _txn(transaction)) async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" From 9a27b7e3d056b66051838152b9065af68abf4e64 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 15:53:29 -0700 Subject: [PATCH 17/36] Addressing deprecation warnings for firestore query syntax --- .../adk/integrations/firestore/firestore_memory_service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 0c32bfbac2..91715ccdfe 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -22,6 +22,7 @@ from typing import Optional from typing import TYPE_CHECKING +from google.cloud.firestore_v1.base_query import FieldFilter from typing_extensions import override from ...events.event import Event @@ -289,9 +290,9 @@ async def _search_by_keyword( """Searches for events matching a single keyword.""" query = ( self.client.collection_group(self.events_collection) - .where("appName", "==", app_name) - .where("userId", "==", user_id) - .where("keywords", "array_contains", keyword) + .where(filter=FieldFilter("appName", "==", app_name)) + .where(filter=FieldFilter("userId", "==", user_id)) + .where(filter=FieldFilter("keywords", "array_contains", keyword)) ) docs = await query.get() From cb9e4d1e4b2ce3f700e5823229f3c2c6add5b058 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:19:52 -0700 Subject: [PATCH 18/36] Make memory actually work by implementing add_session_to_memory. This creates a new memories collection to hold indexed events. --- .../adk/integrations/firestore/_stop_words.py | 151 ++++++++++ .../firestore/firestore_memory_service.py | 267 ++++-------------- .../test_firestore_memory_service.py | 59 ++-- 3 files changed, 230 insertions(+), 247 deletions(-) create mode 100644 src/google/adk/integrations/firestore/_stop_words.py diff --git a/src/google/adk/integrations/firestore/_stop_words.py b/src/google/adk/integrations/firestore/_stop_words.py new file mode 100644 index 0000000000..b72cc5b6cc --- /dev/null +++ b/src/google/adk/integrations/firestore/_stop_words.py @@ -0,0 +1,151 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +DEFAULT_STOP_WORDS = { + "a", + "about", + "above", + "after", + "again", + "against", + "all", + "am", + "an", + "and", + "any", + "are", + "as", + "at", + "be", + "because", + "been", + "before", + "being", + "below", + "between", + "both", + "but", + "by", + "can", + "could", + "did", + "do", + "does", + "doing", + "don", + "down", + "during", + "each", + "else", + "few", + "for", + "from", + "further", + "had", + "has", + "have", + "having", + "he", + "her", + "here", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "i", + "if", + "in", + "into", + "is", + "it", + "its", + "itself", + "just", + "may", + "me", + "might", + "more", + "most", + "must", + "my", + "myself", + "no", + "nor", + "not", + "now", + "of", + "off", + "on", + "once", + "only", + "or", + "other", + "our", + "ours", + "ourselves", + "out", + "over", + "own", + "s", + "same", + "shall", + "she", + "should", + "so", + "some", + "such", + "t", + "than", + "that", + "the", + "their", + "theirs", + "them", + "themselves", + "then", + "there", + "these", + "they", + "this", + "those", + "through", + "to", + "too", + "under", + "until", + "up", + "very", + "was", + "we", + "were", + "what", + "when", + "where", + "which", + "who", + "whom", + "why", + "will", + "with", + "would", + "you", + "your", + "yours", + "yourself", + "yourselves", +} diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 91715ccdfe..a87e69e934 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -30,6 +30,7 @@ from ...memory.base_memory_service import BaseMemoryService from ...memory.base_memory_service import SearchMemoryResponse from ...memory.memory_entry import MemoryEntry +from ._stop_words import DEFAULT_STOP_WORDS if TYPE_CHECKING: from google.cloud import firestore @@ -39,209 +40,7 @@ logger = logging.getLogger("google_adk." + __name__) DEFAULT_EVENTS_COLLECTION = "events" - -DEFAULT_STOP_WORDS = { - "a", - "an", - "the", - "and", - "or", - "but", - "if", - "then", - "else", - "to", - "of", - "in", - "on", - "for", - "with", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "do", - "does", - "did", - "can", - "could", - "will", - "would", - "should", - "shall", - "may", - "might", - "must", - "up", - "down", - "out", - "in", - "over", - "under", - "again", - "further", - "then", - "once", - "here", - "there", - "when", - "where", - "why", - "how", - "all", - "any", - "both", - "each", - "few", - "more", - "most", - "other", - "some", - "such", - "no", - "nor", - "not", - "only", - "own", - "same", - "so", - "than", - "too", - "very", - "i", - "me", - "my", - "myself", - "we", - "our", - "ours", - "ourselves", - "you", - "your", - "yours", - "yourself", - "yourselves", - "he", - "him", - "his", - "himself", - "she", - "her", - "hers", - "herself", - "it", - "its", - "itself", - "they", - "them", - "their", - "theirs", - "themselves", - "what", - "which", - "who", - "whom", - "this", - "that", - "these", - "those", - "am", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "having", - "do", - "does", - "did", - "doing", - "a", - "an", - "the", - "and", - "but", - "if", - "or", - "because", - "as", - "until", - "while", - "of", - "at", - "by", - "for", - "with", - "about", - "against", - "between", - "into", - "through", - "during", - "before", - "after", - "above", - "below", - "to", - "from", - "up", - "down", - "in", - "out", - "on", - "off", - "over", - "under", - "again", - "further", - "then", - "once", - "here", - "there", - "when", - "where", - "why", - "how", - "all", - "any", - "both", - "each", - "few", - "more", - "most", - "other", - "some", - "such", - "no", - "nor", - "not", - "only", - "own", - "same", - "so", - "than", - "too", - "very", - "s", - "t", - "can", - "will", - "just", - "don", - "should", - "now", -} +DEFAULT_MEMORIES_COLLECTION = "memories" class FirestoreMemoryService(BaseMemoryService): @@ -252,6 +51,7 @@ def __init__( client: Optional[firestore.AsyncClient] = None, events_collection: Optional[str] = None, stop_words: Optional[set[str]] = None, + memories_collection: Optional[str] = None, ): """Initializes the Firestore memory service. @@ -262,6 +62,8 @@ def __init__( Defaults to 'events'. stop_words: A set of words to ignore when extracting keywords. Defaults to a standard English stop words list. + memories_collection: The name of the memories collection. Defaults to + 'memories'. """ if client is None: from google.cloud import firestore @@ -270,14 +72,45 @@ def __init__( else: self.client = client self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION + self.memories_collection = memories_collection or DEFAULT_MEMORIES_COLLECTION self.stop_words = ( stop_words if stop_words is not None else DEFAULT_STOP_WORDS ) @override async def add_session_to_memory(self, session: Session) -> None: - """No-op. Assumes events are written to Firestore by FirestoreSessionService.""" - pass + """Extracts keywords from session events and stores them in the memories collection.""" + batch = self.client.batch() + has_updates = False + + for event in session.events: + if not event.content or not event.content.parts: + continue + + text = " ".join([part.text for part in event.content.parts if part.text]) + if not text: + continue + + keywords = self._extract_keywords(text) + if not keywords: + continue + + doc_ref = self.client.collection(self.memories_collection).document() + batch.set( + doc_ref, + { + "appName": session.app_name, + "userId": session.user_id, + "keywords": list(keywords), + "author": event.author, + "content": event.content.model_dump(exclude_none=True, mode="json"), + "timestamp": event.timestamp, + }, + ) + has_updates = True + + if has_updates: + await batch.commit() def _extract_keywords(self, text: str) -> set[str]: """Extracts keywords from text, ignoring stop words.""" @@ -289,7 +122,7 @@ async def _search_by_keyword( ) -> list[MemoryEntry]: """Searches for events matching a single keyword.""" query = ( - self.client.collection_group(self.events_collection) + self.client.collection(self.memories_collection) .where(filter=FieldFilter("appName", "==", app_name)) .where(filter=FieldFilter("userId", "==", user_id)) .where(filter=FieldFilter("keywords", "array_contains", keyword)) @@ -299,19 +132,19 @@ async def _search_by_keyword( entries = [] for doc in docs: data = doc.to_dict() - if data and "event_data" in data: + if data and "content" in data: try: - event = Event.model_validate(data["event_data"]) - if event.content: - entries.append( - MemoryEntry( - content=event.content, - author=event.author, - timestamp=_utils.format_timestamp(event.timestamp), - ) - ) + from google.genai import types + content = types.Content.model_validate(data["content"]) + entries.append( + MemoryEntry( + content=content, + author=data.get("author", ""), + timestamp=_utils.format_timestamp(data.get("timestamp", 0.0)), + ) + ) except Exception as e: - logger.warning("Failed to parse event from Firestore: %s", e) + logger.warning(f"Failed to parse memory entry: {e}") return entries diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index cf199b6a68..908413fcc2 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -26,7 +26,7 @@ def mock_firestore_client(): client = mock.MagicMock() collection_ref = mock.MagicMock() - client.collection_group.return_value = collection_ref + client.collection.return_value = collection_ref collection_ref.where.return_value = collection_ref @@ -60,7 +60,7 @@ async def test_search_memory_empty_query(mock_firestore_client): app_name="test_app", user_id="test_user", query="" ) assert not response.memories - mock_firestore_client.collection_group.assert_not_called() + mock_firestore_client.collection.assert_not_called() @pytest.mark.asyncio @@ -70,16 +70,18 @@ async def test_search_memory_with_results(mock_firestore_client): user_id = "test_user" query = "quick fox" - doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + doc_snapshot = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ 0 ] - event = Event( - invocation_id="test_inv", - author="user", - content=types.Content(parts=[types.Part(text="quick fox jumps")]), - ) + + content = types.Content(parts=[types.Part.from_text(text="quick fox jumps")]) + doc_snapshot.to_dict.return_value = { - "event_data": event.model_dump(exclude_none=True, mode="json") + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, } response = await service.search_memory( @@ -90,8 +92,8 @@ async def test_search_memory_with_results(mock_firestore_client): assert len(response.memories) == 1 assert response.memories[0].author == "user" - mock_firestore_client.collection_group.assert_called_with("events") - collection_ref = mock_firestore_client.collection_group.return_value + mock_firestore_client.collection.assert_called_with("memories") + collection_ref = mock_firestore_client.collection.return_value collection_ref.where.assert_called() @@ -102,32 +104,29 @@ async def test_search_memory_deduplication(mock_firestore_client): user_id = "test_user" query = "quick fox" - event1 = Event( - invocation_id="test_inv1", - author="user", - content=types.Content(parts=[types.Part(text="quick fox jumps")]), - timestamp=1234567890.0, - ) - event2 = Event( - invocation_id="test_inv2", - author="user", - content=types.Content(parts=[types.Part(text="quick fox jumps")]), - timestamp=1234567890.0, - ) + content = types.Content(parts=[types.Part.from_text(text="quick fox jumps")]) doc_snapshot1 = mock.MagicMock() doc_snapshot1.to_dict.return_value = { - "event_data": event1.model_dump(exclude_none=True, mode="json") + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, } doc_snapshot2 = mock.MagicMock() doc_snapshot2.to_dict.return_value = { - "event_data": event2.model_dump(exclude_none=True, mode="json") + "appName": app_name, + "userId": user_id, + "author": "user", + "content": content.model_dump(exclude_none=True, mode="json"), + "timestamp": 1234567890.0, } get_mock = mock.AsyncMock(side_effect=[[doc_snapshot1], [doc_snapshot2]]) - mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get = ( + mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get = ( get_mock ) @@ -147,17 +146,17 @@ async def test_search_memory_parsing_error(mock_firestore_client, caplog): user_id = "test_user" query = "quick" - doc_snapshot = mock_firestore_client.collection_group.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ + doc_snapshot = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ 0 ] - doc_snapshot.to_dict.return_value = {"event_data": "invalid_data"} + doc_snapshot.to_dict.return_value = {"content": "invalid_data"} response = await service.search_memory( app_name=app_name, user_id=user_id, query=query ) assert not response.memories - assert "Failed to parse event from Firestore" in caplog.text + assert "Failed to parse memory entry" in caplog.text @pytest.mark.asyncio @@ -167,4 +166,4 @@ async def test_search_memory_only_stop_words(mock_firestore_client): app_name="test_app", user_id="test_user", query="the and or" ) assert not response.memories - mock_firestore_client.collection_group.assert_not_called() + mock_firestore_client.collection.assert_not_called() From 1f011749bcd05facf6afe4d53dfeeb40159669f9 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:21:38 -0700 Subject: [PATCH 19/36] Formatting --- .../integrations/firestore/firestore_memory_service.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index a87e69e934..f787365063 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -72,7 +72,9 @@ def __init__( else: self.client = client self.events_collection = events_collection or DEFAULT_EVENTS_COLLECTION - self.memories_collection = memories_collection or DEFAULT_MEMORIES_COLLECTION + self.memories_collection = ( + memories_collection or DEFAULT_MEMORIES_COLLECTION + ) self.stop_words = ( stop_words if stop_words is not None else DEFAULT_STOP_WORDS ) @@ -103,7 +105,9 @@ async def add_session_to_memory(self, session: Session) -> None: "userId": session.user_id, "keywords": list(keywords), "author": event.author, - "content": event.content.model_dump(exclude_none=True, mode="json"), + "content": event.content.model_dump( + exclude_none=True, mode="json" + ), "timestamp": event.timestamp, }, ) @@ -135,6 +139,7 @@ async def _search_by_keyword( if data and "content" in data: try: from google.genai import types + content = types.Content.model_validate(data["content"]) entries.append( MemoryEntry( From 72cc9d22c6d14f4784f158e9c6a3cc49257f0720 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:22:58 -0700 Subject: [PATCH 20/36] Remove unnecessary firestore runner --- .../firestore/firestore_database_runner.py | 64 ------------ .../test_firestore_database_runner.py | 99 ------------------- 2 files changed, 163 deletions(-) delete mode 100644 src/google/adk/integrations/firestore/firestore_database_runner.py delete mode 100644 tests/unittests/integrations/firestore/test_firestore_database_runner.py diff --git a/src/google/adk/integrations/firestore/firestore_database_runner.py b/src/google/adk/integrations/firestore/firestore_database_runner.py deleted file mode 100644 index 6d40c7feb5..0000000000 --- a/src/google/adk/integrations/firestore/firestore_database_runner.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -from typing import Optional -from typing import TYPE_CHECKING - -from ...artifacts.gcs_artifact_service import GcsArtifactService -from ...runners import Runner -from .firestore_memory_service import FirestoreMemoryService -from .firestore_session_service import FirestoreSessionService - -if TYPE_CHECKING: - from ...agents.base_agent import BaseAgent - - -def create_firestore_runner( - agent: BaseAgent, - gcs_bucket_name: Optional[str] = None, - firestore_root_collection: Optional[str] = None, -) -> Runner: - """Creates a Runner configured with Firestore and GCS services. - - Args: - agent: The root agent to run. - gcs_bucket_name: The GCS bucket name for artifacts. - firestore_root_collection: The root collection name for Firestore. - - Returns: - A Runner instance configured with Firestore services. - """ - bucket_name = gcs_bucket_name or os.environ.get("ADK_GCS_BUCKET_NAME") - if not bucket_name: - raise ValueError( - "Required property 'ADK_GCS_BUCKET_NAME' is not set. This" - " is needed for the GcsArtifactService." - ) - artifact_service = GcsArtifactService(bucket_name=bucket_name) - - session_service = FirestoreSessionService( - root_collection=firestore_root_collection - ) - memory_service = FirestoreMemoryService() - - return Runner( - app_name=agent.name, - agent=agent, - session_service=session_service, - artifact_service=artifact_service, - memory_service=memory_service, - ) diff --git a/tests/unittests/integrations/firestore/test_firestore_database_runner.py b/tests/unittests/integrations/firestore/test_firestore_database_runner.py deleted file mode 100644 index 927a366965..0000000000 --- a/tests/unittests/integrations/firestore/test_firestore_database_runner.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from unittest import mock - -from google.adk.agents.base_agent import BaseAgent -from google.adk.integrations.firestore.firestore_database_runner import create_firestore_runner -import pytest - - -@pytest.fixture -def mock_agent(): - agent = mock.MagicMock(spec=BaseAgent) - agent.name = "test_agent" - return agent - - -def test_create_firestore_runner_with_arg(mock_agent, monkeypatch): - monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) - - with ( - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" - ), - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" - ), - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" - ) as mock_gcs, - ): - runner = create_firestore_runner(mock_agent, gcs_bucket_name="test_bucket") - - assert runner is not None - mock_gcs.assert_called_once_with(bucket_name="test_bucket") - - -def test_create_firestore_runner_with_env(mock_agent, monkeypatch): - monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "env_bucket") - - with ( - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" - ), - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" - ), - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" - ) as mock_gcs, - ): - runner = create_firestore_runner(mock_agent) - - assert runner is not None - mock_gcs.assert_called_once_with(bucket_name="env_bucket") - - -def test_create_firestore_runner_missing_bucket(mock_agent, monkeypatch): - monkeypatch.delenv("ADK_GCS_BUCKET_NAME", raising=False) - - with pytest.raises( - ValueError, match="Required property 'ADK_GCS_BUCKET_NAME' is not set" - ): - create_firestore_runner(mock_agent) - - -def test_create_firestore_runner_with_root_collection(mock_agent, monkeypatch): - monkeypatch.setenv("ADK_GCS_BUCKET_NAME", "test_bucket") - - with ( - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreSessionService" - ) as mock_session, - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.FirestoreMemoryService" - ), - mock.patch( - "google.adk.integrations.firestore.firestore_database_runner.GcsArtifactService" - ), - ): - runner = create_firestore_runner( - mock_agent, firestore_root_collection="custom_collection" - ) - - assert runner is not None - mock_session.assert_called_once_with(root_collection="custom_collection") From bef661c1247d89ca0e54582ee1b57d633ead0e80 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:24:33 -0700 Subject: [PATCH 21/36] formatting --- .../integrations/firestore/test_firestore_memory_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index 908413fcc2..e09367e1c3 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -73,9 +73,9 @@ async def test_search_memory_with_results(mock_firestore_client): doc_snapshot = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value.get.return_value[ 0 ] - + content = types.Content(parts=[types.Part.from_text(text="quick fox jumps")]) - + doc_snapshot.to_dict.return_value = { "appName": app_name, "userId": user_id, From e18a1f844aab36b56dbe70cff14e6134ce1e5854 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:35:49 -0700 Subject: [PATCH 22/36] Adding more detailed comments for both session and memory services --- .../firestore/firestore_memory_service.py | 5 ++++- .../firestore/firestore_session_service.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index f787365063..bc774b483e 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -44,7 +44,10 @@ class FirestoreMemoryService(BaseMemoryService): - """Memory service that uses Google Cloud Firestore as the backend.""" + """Memory service that uses Google Cloud Firestore as the backend. + + It uses the existing session data to create memories in a top-level memory collection. + """ def __init__( self, diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 538180d649..65fd4ab115 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -44,7 +44,17 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] - """Session service that uses Google Cloud Firestore as the backend.""" + """Session service that uses Google Cloud Firestore as the backend. + + It creates a hierarchy in Firestore to hold events by user and session: + adk-session + ↳ + ↳ sessions + ↳ + ↳ events + ↳ + ↳ Event document + """ def __init__( self, From b6526690fb58d1cbc389f2c1a7d36985ea362138 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Thu, 9 Apr 2026 16:53:34 -0700 Subject: [PATCH 23/36] Much improved unit tests --- .../test_firestore_memory_service.py | 148 +++++++++++- .../test_firestore_session_service.py | 215 +++++++++++++++++- 2 files changed, 360 insertions(+), 3 deletions(-) diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index e09367e1c3..0fc9cc2ef6 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -18,6 +18,7 @@ from google.adk.events.event import Event from google.adk.integrations.firestore.firestore_memory_service import FirestoreMemoryService +from google.cloud.firestore_v1.base_query import FieldFilter from google.genai import types import pytest @@ -94,7 +95,38 @@ async def test_search_memory_with_results(mock_firestore_client): mock_firestore_client.collection.assert_called_with("memories") collection_ref = mock_firestore_client.collection.return_value - collection_ref.where.assert_called() + + assert collection_ref.where.call_count == 6 + calls = collection_ref.where.call_args_list + + app_name_calls = 0 + user_id_calls = 0 + keyword_calls = 0 + + for call in calls: + kwargs = call.kwargs + filt = kwargs.get("filter") + if filt: + if ( + filt.field_path == "appName" + and filt.op_string == "==" + and filt.value == app_name + ): + app_name_calls += 1 + elif ( + filt.field_path == "userId" + and filt.op_string == "==" + and filt.value == user_id + ): + user_id_calls += 1 + elif filt.field_path == "keywords" and filt.op_string == "array_contains": + + if filt.value in ["quick", "fox"]: + keyword_calls += 1 + + assert app_name_calls == 2 + assert user_id_calls == 2 + assert keyword_calls == 2 @pytest.mark.asyncio @@ -167,3 +199,117 @@ async def test_search_memory_only_stop_words(mock_firestore_client): ) assert not response.memories mock_firestore_client.collection.assert_not_called() + + +def test_init_default_client(): + with mock.patch("google.cloud.firestore.AsyncClient") as mock_client_class: + mock_instance = mock.MagicMock() + mock_client_class.return_value = mock_instance + + service = FirestoreMemoryService() + + mock_client_class.assert_called_once() + assert service.client == mock_instance + + +@pytest.mark.asyncio +async def test_add_session_to_memory(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) + event = Event( + invocation_id="test_inv", + author="user", + content=content, + timestamp=1234567890.0, + ) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + batch.commit = mock.AsyncMock() + + doc_ref = mock.MagicMock() + mock_firestore_client.collection.return_value.document.return_value = doc_ref + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + mock_firestore_client.collection.assert_called_with("memories") + batch.set.assert_called_once() + batch.commit.assert_called_once() + + args, kwargs = batch.set.call_args + assert args[0] == doc_ref + data = args[1] + assert data["appName"] == "test_app" + assert data["userId"] == "test_user" + assert "quick" in data["keywords"] + assert data["author"] == "user" + assert data["timestamp"] == 1234567890.0 + + +@pytest.mark.asyncio +async def test_add_session_to_memory_no_events(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + batch.set.assert_not_called() + batch.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_session_to_memory_no_keywords(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="the and or")]) + event = Event(invocation_id="test_inv", author="user", content=content) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + + await service.add_session_to_memory(session) + + mock_firestore_client.batch.assert_called_once() + batch.set.assert_not_called() + batch.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_session_to_memory_commit_error(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + content = types.Content(parts=[types.Part.from_text(text="quick brown fox")]) + event = Event(invocation_id="test_inv", author="user", content=content) + session.events.append(event) + + batch = mock.MagicMock() + mock_firestore_client.batch.return_value = batch + batch.commit = mock.AsyncMock( + side_effect=Exception("Firestore commit failed") + ) + + with pytest.raises(Exception, match="Firestore commit failed"): + await service.add_session_to_memory(session) diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 61aa468716..5b3c077415 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -63,8 +63,24 @@ def mock_firestore_client(): return client +def test_init_missing_dependency(): + import builtins + + original_import = builtins.__import__ + + def mock_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "google.cloud" and "firestore" in fromlist: + raise ImportError("Mocked import error") + return original_import(name, globals, locals, fromlist, level) + + with mock.patch("builtins.__import__", side_effect=mock_import): + with pytest.raises(ImportError, match="requires google-cloud-firestore"): + FirestoreSessionService() + + @pytest.mark.asyncio async def test_create_session(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) app_name = "test_app" user_id = "test_user" @@ -83,7 +99,16 @@ async def test_create_session(mock_firestore_client): sessions_ref = doc_ref.collection.return_value sessions_ref.document.assert_called_once_with(session.id) session_doc_ref = sessions_ref.document.return_value + from google.cloud import firestore + session_doc_ref.set.assert_called_once() + args, kwargs = session_doc_ref.set.call_args + assert args[0]["id"] == session.id + assert args[0]["appName"] == app_name + assert args[0]["userId"] == user_id + assert args[0]["state"] == {} + assert args[0]["createTime"] == firestore.SERVER_TIMESTAMP + assert args[0]["updateTime"] == firestore.SERVER_TIMESTAMP @pytest.mark.asyncio @@ -99,6 +124,14 @@ async def test_get_session_not_found(mock_firestore_client): assert session is None + mock_firestore_client.collection.assert_called_with("adk-session") + collection_ref = mock_firestore_client.collection.return_value + collection_ref.document.assert_called_with(user_id) + doc_ref = collection_ref.document.return_value + doc_ref.collection.assert_called_with("sessions") + sessions_ref = doc_ref.collection.return_value + sessions_ref.document.assert_called_with(session_id) + @pytest.mark.asyncio async def test_get_session_found(mock_firestore_client): @@ -119,6 +152,15 @@ async def test_get_session_found(mock_firestore_client): "updateTime": 1234567890.0, } + events_collection_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + event_doc = mock.MagicMock() + event_doc.to_dict.return_value = { + "event_data": {"invocation_id": "test_inv", "author": "user"} + } + events_collection_ref.get = mock.AsyncMock(return_value=[event_doc]) + session = await service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) @@ -126,6 +168,8 @@ async def test_get_session_found(mock_firestore_client): assert session is not None assert session.id == session_id assert session.state == {"key": "value"} + assert len(session.events) == 1 + assert session.events[0].invocation_id == "test_inv" @pytest.mark.asyncio @@ -174,12 +218,26 @@ async def test_append_event(mock_firestore_client): await service.append_event(session, event) + from google.cloud import firestore + mock_firestore_client.batch.assert_called_once() batch = mock_firestore_client.batch.return_value - batch.set.assert_called_once() - batch.update.assert_called_once() batch.commit.assert_called_once() + batch.update.assert_called_once() + args, kwargs = batch.update.call_args + assert "state" not in args[1] + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP + + batch.set.assert_called_once() + args, kwargs = batch.set.call_args + assert args[1]["appName"] == app_name + assert args[1]["userId"] == user_id + assert args[1]["timestamp"] == firestore.SERVER_TIMESTAMP + assert args[1]["event_data"] == event.model_dump( + exclude_none=True, mode="json" + ) + @pytest.mark.asyncio async def test_append_event_with_state_delta(mock_firestore_client): @@ -219,7 +277,12 @@ async def test_append_event_with_state_delta(mock_firestore_client): assert session.state["session_key"] == "session_val" + from google.cloud import firestore + batch.update.assert_called_once() + args, kwargs = batch.update.call_args + assert args[1]["state"] == session.state + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP batch.commit.assert_called_once() @@ -342,6 +405,11 @@ def collection_side_effect(name): assert session.state["_app_app_key"] == "app_val" assert session.state["_user_user_key"] == "user_val" + mock_firestore_client.collection_group.assert_called_once_with("sessions") + mock_firestore_client.collection_group.return_value.where.assert_called_once_with( + "appName", "==", app_name + ) + @pytest.mark.asyncio async def test_create_session_already_exists(mock_firestore_client): @@ -421,6 +489,7 @@ async def to_async_iter(iterable): ) assert batch.commit.call_count == 2 + assert batch.delete.call_count == 501 @pytest.mark.asyncio @@ -436,3 +505,145 @@ async def test_append_event_partial(mock_firestore_client): assert result == event mock_firestore_client.batch.assert_not_called() + + +@pytest.mark.asyncio +async def test_update_app_state_transactional(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + + app_name = "test_app" + delta = {"new_key": "new_val"} + + transaction = mock.MagicMock() + mock_firestore_client.transaction.return_value = transaction + + doc_ref = mock.MagicMock() + mock_firestore_client.collection.return_value.document.return_value = doc_ref + + doc_snapshot = mock.MagicMock() + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = {"old_key": "old_val"} + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + result = await service._update_app_state_transactional(app_name, delta) + + assert result == {"old_key": "old_val", "new_key": "new_val"} + transaction.set.assert_called_once_with( + doc_ref, {"old_key": "old_val", "new_key": "new_val"}, merge=True + ) + + +@pytest.mark.asyncio +async def test_update_user_state_transactional(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + + app_name = "test_app" + user_id = "test_user" + delta = {"new_key": "new_val"} + + transaction = mock.MagicMock() + mock_firestore_client.transaction.return_value = transaction + + doc_ref = mock.MagicMock() + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value = ( + doc_ref + ) + + doc_snapshot = mock.MagicMock() + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = {"old_key": "old_val"} + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + result = await service._update_user_state_transactional( + app_name, user_id, delta + ) + + assert result == {"old_key": "old_val", "new_key": "new_val"} + transaction.set.assert_called_once_with( + doc_ref, {"old_key": "old_val", "new_key": "new_val"}, merge=True + ) + + +@pytest.mark.asyncio +async def test_get_session_empty_data(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = {} + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is None + + +@pytest.mark.asyncio +async def test_list_sessions_missing_states(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": user_id, + "state": {"session_key": "session_val"}, + } + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + sessions_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + elif name == service.root_collection: + return sessions_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = False + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.exists = False + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + user_doc_in_sessions = mock.MagicMock() + sessions_coll.document.return_value = user_doc_in_sessions + sessions_subcoll = mock.MagicMock() + user_doc_in_sessions.collection.return_value = sessions_subcoll + sessions_query = mock.MagicMock() + sessions_subcoll.where.return_value = sessions_query + sessions_query.get = mock.AsyncMock(return_value=[session_doc]) + + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session.id == "session1" + assert session.state["session_key"] == "session_val" + assert "_app_app_key" not in session.state + assert "_user_user_key" not in session.state From 9c19c02240ea490ef01af8e50c40f72c7618a2dc Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 09:34:05 -0700 Subject: [PATCH 24/36] Adding app to document hierarchy --- .../firestore/firestore_session_service.py | 30 ++++---- .../test_firestore_session_service.py | 74 ++++++++++++------- 2 files changed, 64 insertions(+), 40 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 65fd4ab115..b60a36ed1c 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -46,14 +46,16 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] """Session service that uses Google Cloud Firestore as the backend. - It creates a hierarchy in Firestore to hold events by user and session: + It creates a hierarchy in Firestore to hold events by app, user, and session: adk-session - ↳ - ↳ sessions - ↳ - ↳ events - ↳ - ↳ Event document + ↳ + ↳ users + ↳ + ↳ sessions + ↳ + ↳ events + ↳ + ↳ Event document """ def __init__( @@ -105,10 +107,12 @@ def _merge_state( return merged_state def _get_sessions_ref( - self, user_id: str + self, app_name: str, user_id: str ) -> firestore.AsyncCollectionReference: return ( self.client.collection(self.root_collection) + .document(app_name) + .collection("users") .document(user_id) .collection(self.sessions_collection) ) @@ -132,7 +136,7 @@ async def create_session( initial_state = state or {} now = firestore.SERVER_TIMESTAMP - session_ref = self._get_sessions_ref(user_id).document(session_id) + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) # Check if session already exists doc = await session_ref.get() @@ -176,7 +180,7 @@ async def get_session( config: Optional[GetSessionConfig] = None, ) -> Optional[Session]: """Gets a session from Firestore.""" - session_ref = self._get_sessions_ref(user_id).document(session_id) + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) doc = await session_ref.get() if not doc.exists: @@ -234,7 +238,7 @@ async def list_sessions( ) -> ListSessionsResponse: """Lists sessions from Firestore.""" if user_id: - query = self._get_sessions_ref(user_id).where("appName", "==", app_name) + query = self._get_sessions_ref(app_name, user_id).where("appName", "==", app_name) docs = await query.get() else: query = self.client.collection_group(self.sessions_collection).where( @@ -296,7 +300,7 @@ async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: """Deletes a session and its events from Firestore.""" - session_ref = self._get_sessions_ref(user_id).document(session_id) + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) events_ref = session_ref.collection(self.events_collection) @@ -369,7 +373,7 @@ async def append_event(self, session: Session, event: Event) -> Event: self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) - session_ref = self._get_sessions_ref(session.user_id).document(session.id) + session_ref = self._get_sessions_ref(session.app_name, session.user_id).document(session.id) if event.actions and event.actions.state_delta: state_delta = event.actions.state_delta diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 5b3c077415..f90baf788a 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -28,31 +28,35 @@ def mock_firestore_client(): doc_ref = mock.MagicMock() subcollection_ref = mock.MagicMock() subdoc_ref = mock.MagicMock() + sessions_coll_ref = mock.MagicMock() + sessions_doc_ref = mock.MagicMock() client.collection.return_value = collection_ref collection_ref.document.return_value = doc_ref doc_ref.collection.return_value = subcollection_ref subcollection_ref.document.return_value = subdoc_ref + subdoc_ref.collection.return_value = sessions_coll_ref + sessions_coll_ref.document.return_value = sessions_doc_ref doc_snapshot = mock.MagicMock() doc_snapshot.exists = False doc_snapshot.to_dict.return_value = {} - doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + sessions_doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - subdoc_ref.set = mock.AsyncMock() - subdoc_ref.delete = mock.AsyncMock() + sessions_doc_ref.set = mock.AsyncMock() + sessions_doc_ref.delete = mock.AsyncMock() events_collection_ref = mock.MagicMock() - subdoc_ref.collection.return_value = events_collection_ref + sessions_doc_ref.collection.return_value = events_collection_ref events_collection_ref.order_by.return_value = events_collection_ref events_collection_ref.where.return_value = events_collection_ref events_collection_ref.limit_to_last.return_value = events_collection_ref events_collection_ref.get = mock.AsyncMock(return_value=[]) - subcollection_ref.get = mock.AsyncMock(return_value=[]) - subcollection_ref.where.return_value = subcollection_ref + sessions_coll_ref.get = mock.AsyncMock(return_value=[]) + sessions_coll_ref.where.return_value = sessions_coll_ref client.collection_group.return_value = collection_ref @@ -92,11 +96,15 @@ async def test_create_session(mock_firestore_client): assert session.id mock_firestore_client.collection.assert_called_once_with("adk-session") - collection_ref = mock_firestore_client.collection.return_value - collection_ref.document.assert_called_once_with(user_id) - doc_ref = collection_ref.document.return_value - doc_ref.collection.assert_called_once_with("sessions") - sessions_ref = doc_ref.collection.return_value + root_coll = mock_firestore_client.collection.return_value + root_coll.document.assert_called_once_with(app_name) + app_ref = root_coll.document.return_value + app_ref.collection.assert_called_once_with("users") + users_coll = app_ref.collection.return_value + users_coll.document.assert_called_once_with(user_id) + user_ref = users_coll.document.return_value + user_ref.collection.assert_called_once_with("sessions") + sessions_ref = user_ref.collection.return_value sessions_ref.document.assert_called_once_with(session.id) session_doc_ref = sessions_ref.document.return_value from google.cloud import firestore @@ -125,11 +133,15 @@ async def test_get_session_not_found(mock_firestore_client): assert session is None mock_firestore_client.collection.assert_called_with("adk-session") - collection_ref = mock_firestore_client.collection.return_value - collection_ref.document.assert_called_with(user_id) - doc_ref = collection_ref.document.return_value - doc_ref.collection.assert_called_with("sessions") - sessions_ref = doc_ref.collection.return_value + root_coll = mock_firestore_client.collection.return_value + root_coll.document.assert_called_with(app_name) + app_ref = root_coll.document.return_value + app_ref.collection.assert_called_with("users") + users_coll = app_ref.collection.return_value + users_coll.document.assert_called_with(user_id) + user_ref = users_coll.document.return_value + user_ref.collection.assert_called_with("sessions") + sessions_ref = user_ref.collection.return_value sessions_ref.document.assert_called_with(session_id) @@ -153,7 +165,7 @@ async def test_get_session_found(mock_firestore_client): } events_collection_ref = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) event_doc = mock.MagicMock() event_doc.to_dict.return_value = { @@ -180,7 +192,7 @@ async def test_delete_session(mock_firestore_client): session_id = "test_session" events_ref = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) event_doc = mock.AsyncMock() @@ -201,7 +213,7 @@ async def to_async_iter(iterable): batch.commit.assert_called_once() session_doc_ref = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value ) session_doc_ref.delete.assert_called_once() @@ -334,10 +346,14 @@ def collection_side_effect(name): users_coll.document.return_value = user_doc_ref user_doc_ref.get = mock.AsyncMock(return_value=user_doc) - user_doc_in_sessions = mock.MagicMock() - sessions_coll.document.return_value = user_doc_in_sessions + app_doc_in_root = mock.MagicMock() + sessions_coll.document.return_value = app_doc_in_root + users_coll = mock.MagicMock() + app_doc_in_root.collection.return_value = users_coll + user_doc_in_users = mock.MagicMock() + users_coll.document.return_value = user_doc_in_users sessions_subcoll = mock.MagicMock() - user_doc_in_sessions.collection.return_value = sessions_subcoll + user_doc_in_users.collection.return_value = sessions_subcoll sessions_query = mock.MagicMock() sessions_subcoll.where.return_value = sessions_query sessions_query.get = mock.AsyncMock(return_value=[session_doc]) @@ -448,7 +464,7 @@ async def test_get_session_with_config(mock_firestore_client): } events_collection_ref = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) from google.adk.sessions.base_session_service import GetSessionConfig @@ -471,7 +487,7 @@ async def test_delete_session_batching(mock_firestore_client): session_id = "test_session" events_ref = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value ) dummy_docs = [mock.MagicMock() for _ in range(501)] @@ -631,10 +647,14 @@ def collection_side_effect(name): users_coll.document.return_value = user_doc_ref user_doc_ref.get = mock.AsyncMock(return_value=user_doc) - user_doc_in_sessions = mock.MagicMock() - sessions_coll.document.return_value = user_doc_in_sessions + app_doc_in_root = mock.MagicMock() + sessions_coll.document.return_value = app_doc_in_root + users_coll = mock.MagicMock() + app_doc_in_root.collection.return_value = users_coll + user_doc_in_users = mock.MagicMock() + users_coll.document.return_value = user_doc_in_users sessions_subcoll = mock.MagicMock() - user_doc_in_sessions.collection.return_value = sessions_subcoll + user_doc_in_users.collection.return_value = sessions_subcoll sessions_query = mock.MagicMock() sessions_subcoll.where.return_value = sessions_query sessions_query.get = mock.AsyncMock(return_value=[session_doc]) From 486c21a0db37d510cc0f98b271c9edb42614ec19 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 09:42:18 -0700 Subject: [PATCH 25/36] Improving tests and handling one failure mode in memory --- .../firestore/firestore_memory_service.py | 5 +- .../test_firestore_memory_service.py | 35 ++++++++++++ .../test_firestore_session_service.py | 57 +++++++++++++++++++ 3 files changed, 96 insertions(+), 1 deletion(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index bc774b483e..486991c508 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -169,11 +169,14 @@ async def search_memory( self._search_by_keyword(app_name, user_id, keyword) for keyword in keywords ] - results = await asyncio.gather(*tasks) + results = await asyncio.gather(*tasks, return_exceptions=True) seen = set() memories = [] for result_list in results: + if isinstance(result_list, Exception): + logger.warning(f"Memory keyword search partial failure: {result_list}") + continue for entry in result_list: content_text = "" if entry.content and entry.content.parts: diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index 0fc9cc2ef6..fec1d2a9f7 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -201,6 +201,41 @@ async def test_search_memory_only_stop_words(mock_firestore_client): mock_firestore_client.collection.assert_not_called() +@pytest.mark.asyncio +async def test_search_memory_partial_failures(mock_firestore_client, caplog): + service = FirestoreMemoryService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + query = "fox quick" + + coll_ref = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value + + doc_snapshot = mock.MagicMock() + doc_snapshot.to_dict.return_value = { + "content": {"parts": [{"text": "quick response"}]}, + "author": "user", + "timestamp": 1234567890.0 + } + + call_count = 0 + async def mock_get(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("Mock generic network failure standalone") + return [doc_snapshot] + + coll_ref.get = mock.AsyncMock(side_effect=mock_get) + + response = await service.search_memory( + app_name=app_name, user_id=user_id, query=query + ) + + assert len(response.memories) == 1 + assert response.memories[0].author == "user" + assert "Memory keyword search partial failure" in caplog.text + + def test_init_default_client(): with mock.patch("google.cloud.firestore.AsyncClient") as mock_client_class: mock_instance = mock.MagicMock() diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index f90baf788a..f00beefd5f 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -427,6 +427,63 @@ def collection_side_effect(name): ) +@pytest.mark.asyncio +async def test_list_sessions_filters_other_apps(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": "user1", + "state": {"session_key": "session_val"}, + } + + mock_firestore_client.collection_group.return_value.where.return_value.get = ( + mock.AsyncMock(return_value=[session_doc]) + ) + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = True + app_doc.to_dict.return_value = {"app_key": "app_val"} + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.id = "user1" + user_doc.to_dict.return_value = {"user_key": "user_val"} + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + users_coll.get = mock.AsyncMock(return_value=[user_doc]) + + response = await service.list_sessions(app_name=app_name) + + assert len(response.sessions) == 1 + assert response.sessions[0].id == "session1" + assert response.sessions[0].app_name == app_name + + mock_firestore_client.collection_group.assert_called_once_with("sessions") + mock_firestore_client.collection_group.return_value.where.assert_called_once_with( + "appName", "==", app_name + ) + + @pytest.mark.asyncio async def test_create_session_already_exists(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) From 2daa18dbc9d73b275a463bd356185b9ae555c9c9 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 10:46:54 -0700 Subject: [PATCH 26/36] Append all event data in a single transaction --- .../firestore/firestore_session_service.py | 100 ++++++++++-------- .../test_firestore_session_service.py | 56 +++++----- 2 files changed, 84 insertions(+), 72 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index b60a36ed1c..b056cdb97b 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -138,13 +138,6 @@ async def create_session( session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) - # Check if session already exists - doc = await session_ref.get() - if doc.exists: - from ...errors.already_exists_error import AlreadyExistsError - - raise AlreadyExistsError(f"Session {session_id} already exists.") - session_data = { "id": session_id, "appName": app_name, @@ -154,7 +147,16 @@ async def create_session( "updateTime": now, } - await session_ref.set(session_data) + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _create_txn(transaction: firestore.AsyncTransaction) -> None: + snap = await session_ref.get(transaction=transaction) + if snap.exists: + from ...errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") + transaction.set(session_ref, session_data) + + transaction_obj = self.client.transaction() + await _create_txn(transaction_obj) # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is # evaluated on the server, we might want to use local time for the object @@ -318,6 +320,7 @@ async def delete_session( await session_ref.delete() + async def _update_app_state_transactional( self, app_name: str, delta: dict[str, Any] ) -> dict[str, Any]: @@ -389,44 +392,57 @@ async def append_event(self, session: Session, event: Event) -> Event: else: session_updates[key] = value - if app_updates: - await self._update_app_state_transactional( - session.app_name, app_updates - ) + app_ref = self.client.collection(self.app_state_collection).document(session.app_name) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) - if user_updates: - await self._update_user_state_transactional( - session.app_name, session.user_id, user_updates + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _append_txn(transaction: firestore.AsyncTransaction) -> None: + # 1. Reads + app_snap = await app_ref.get(transaction=transaction) if app_updates else None + user_snap = await user_ref.get(transaction=transaction) if user_updates else None + + # 2. Writes + if app_updates and app_snap is not None: + current_app = app_snap.to_dict() if app_snap.exists else {} + current_app.update(app_updates) + transaction.set(app_ref, current_app, merge=True) + + if user_updates and user_snap is not None: + current_user = user_snap.to_dict() if user_snap.exists else {} + current_user.update(user_updates) + transaction.set(user_ref, current_user, merge=True) + + for k, v in session_updates.items(): + session.state[k] = v + + transaction.update( + session_ref, + { + "state": session.state, + "updateTime": firestore.SERVER_TIMESTAMP, + }, ) - for k, v in session_updates.items(): - session.state[k] = v - - batch = self.client.batch() - batch.update( - session_ref, - { - "state": session.state, - "updateTime": firestore.SERVER_TIMESTAMP, - }, - ) - - event_id = event.id - event_ref = session_ref.collection(self.events_collection).document( - event_id - ) - event_data = event.model_dump(exclude_none=True, mode="json") - batch.set( - event_ref, - { - "event_data": event_data, - "timestamp": firestore.SERVER_TIMESTAMP, - "appName": session.app_name, - "userId": session.user_id, - }, - ) + event_id = event.id + event_ref = session_ref.collection(self.events_collection).document(event_id) + event_data = event.model_dump(exclude_none=True, mode="json") + transaction.set( + event_ref, + { + "event_data": event_data, + "timestamp": firestore.SERVER_TIMESTAMP, + "appName": session.app_name, + "userId": session.user_id, + }, + ) - await batch.commit() + transaction_obj = self.client.transaction() + await _append_txn(transaction_obj) else: batch = self.client.batch() event_id = event.id diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index f00beefd5f..adfa3676f2 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -89,7 +89,8 @@ async def test_create_session(mock_firestore_client): app_name = "test_app" user_id = "test_user" - session = await service.create_session(app_name=app_name, user_id=user_id) + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + session = await service.create_session(app_name=app_name, user_id=user_id) assert session.app_name == app_name assert session.user_id == user_id @@ -109,14 +110,16 @@ async def test_create_session(mock_firestore_client): session_doc_ref = sessions_ref.document.return_value from google.cloud import firestore - session_doc_ref.set.assert_called_once() - args, kwargs = session_doc_ref.set.call_args - assert args[0]["id"] == session.id - assert args[0]["appName"] == app_name - assert args[0]["userId"] == user_id - assert args[0]["state"] == {} - assert args[0]["createTime"] == firestore.SERVER_TIMESTAMP - assert args[0]["updateTime"] == firestore.SERVER_TIMESTAMP + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called_once() + args, kwargs = transaction.set.call_args + assert args[0] == session_doc_ref + assert args[1]["id"] == session.id + assert args[1]["appName"] == app_name + assert args[1]["userId"] == user_id + assert args[1]["state"] == {} + assert args[1]["createTime"] == firestore.SERVER_TIMESTAMP + assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP @pytest.mark.asyncio @@ -228,7 +231,8 @@ async def test_append_event(mock_firestore_client): session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") - await service.append_event(session, event) + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) from google.cloud import firestore @@ -273,31 +277,22 @@ async def test_append_event_with_state_delta(mock_firestore_client): service._update_app_state_transactional = mock.AsyncMock() service._update_user_state_transactional = mock.AsyncMock() - await service.append_event(session, event) - - mock_firestore_client.batch.assert_called_once() - service._update_app_state_transactional.assert_called_once_with( - "test_app", {"my_key": "app_val"} - ) - service._update_user_state_transactional.assert_called_once_with( - "test_app", "test_user", {"my_key": "user_val"} - ) - - batch = mock_firestore_client.batch.return_value + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) - batch.set.assert_called() + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() assert session.state["session_key"] == "session_val" from google.cloud import firestore - batch.update.assert_called_once() - args, kwargs = batch.update.call_args + transaction.update.assert_called_once() + args, kwargs = transaction.update.call_args + assert args[0] == session_ref assert args[1]["state"] == session.state assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP - batch.commit.assert_called_once() - @pytest.mark.asyncio async def test_list_sessions_with_user_id(mock_firestore_client): @@ -497,10 +492,11 @@ async def test_create_session_already_exists(mock_firestore_client): from google.adk.errors.already_exists_error import AlreadyExistsError - with pytest.raises(AlreadyExistsError): - await service.create_session( - app_name=app_name, user_id=user_id, session_id="existing_id" - ) + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + with pytest.raises(AlreadyExistsError): + await service.create_session( + app_name=app_name, user_id=user_id, session_id="existing_id" + ) @pytest.mark.asyncio From e05dd5b6ebaf53955d4539d488da5c4f86a638dc Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 10:54:08 -0700 Subject: [PATCH 27/36] Remove dead code --- .../firestore/firestore_session_service.py | 73 ++++++------------- .../test_firestore_session_service.py | 55 -------------- 2 files changed, 21 insertions(+), 107 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index b056cdb97b..cdd650ae98 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -152,6 +152,7 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None: snap = await session_ref.get(transaction=transaction) if snap.exists: from ...errors.already_exists_error import AlreadyExistsError + raise AlreadyExistsError(f"Session {session_id} already exists.") transaction.set(session_ref, session_data) @@ -240,7 +241,9 @@ async def list_sessions( ) -> ListSessionsResponse: """Lists sessions from Firestore.""" if user_id: - query = self._get_sessions_ref(app_name, user_id).where("appName", "==", app_name) + query = self._get_sessions_ref(app_name, user_id).where( + "appName", "==", app_name + ) docs = await query.get() else: query = self.client.collection_group(self.sessions_collection).where( @@ -320,52 +323,6 @@ async def delete_session( await session_ref.delete() - - async def _update_app_state_transactional( - self, app_name: str, delta: dict[str, Any] - ) -> dict[str, Any]: - """Atomically applies delta to app state inside a transaction.""" - from google.cloud import firestore - - doc_ref = self.client.collection(self.app_state_collection).document( - app_name - ) - - @firestore.async_transactional # type: ignore[untyped-decorator] - async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]: - snap = await doc_ref.get(transaction=transaction) - current = snap.to_dict() if snap.exists else {} - current.update(delta) - transaction.set(doc_ref, current, merge=True) - return current - - transaction = self.client.transaction() - return cast(dict[str, Any], await _txn(transaction)) - - async def _update_user_state_transactional( - self, app_name: str, user_id: str, delta: dict[str, Any] - ) -> dict[str, Any]: - """Atomically applies delta to user state inside a transaction.""" - from google.cloud import firestore - - doc_ref = ( - self.client.collection(self.user_state_collection) - .document(app_name) - .collection("users") - .document(user_id) - ) - - @firestore.async_transactional # type: ignore[untyped-decorator] - async def _txn(transaction: firestore.AsyncTransaction) -> dict[str, Any]: - snap = await doc_ref.get(transaction=transaction) - current = snap.to_dict() if snap.exists else {} - current.update(delta) - transaction.set(doc_ref, current, merge=True) - return current - - transaction = self.client.transaction() - return cast(dict[str, Any], await _txn(transaction)) - async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session in Firestore.""" from google.cloud import firestore @@ -376,7 +333,9 @@ async def append_event(self, session: Session, event: Event) -> Event: self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) - session_ref = self._get_sessions_ref(session.app_name, session.user_id).document(session.id) + session_ref = self._get_sessions_ref( + session.app_name, session.user_id + ).document(session.id) if event.actions and event.actions.state_delta: state_delta = event.actions.state_delta @@ -392,7 +351,9 @@ async def append_event(self, session: Session, event: Event) -> Event: else: session_updates[key] = value - app_ref = self.client.collection(self.app_state_collection).document(session.app_name) + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) user_ref = ( self.client.collection(self.user_state_collection) .document(session.app_name) @@ -403,8 +364,14 @@ async def append_event(self, session: Session, event: Event) -> Event: @firestore.async_transactional # type: ignore[untyped-decorator] async def _append_txn(transaction: firestore.AsyncTransaction) -> None: # 1. Reads - app_snap = await app_ref.get(transaction=transaction) if app_updates else None - user_snap = await user_ref.get(transaction=transaction) if user_updates else None + app_snap = ( + await app_ref.get(transaction=transaction) if app_updates else None + ) + user_snap = ( + await user_ref.get(transaction=transaction) + if user_updates + else None + ) # 2. Writes if app_updates and app_snap is not None: @@ -429,7 +396,9 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None: ) event_id = event.id - event_ref = session_ref.collection(self.events_collection).document(event_id) + event_ref = session_ref.collection(self.events_collection).document( + event_id + ) event_data = event.model_dump(exclude_none=True, mode="json") transaction.set( event_ref, diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index adfa3676f2..ad67257943 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -577,62 +577,7 @@ async def test_append_event_partial(mock_firestore_client): @pytest.mark.asyncio -async def test_update_app_state_transactional(mock_firestore_client): - service = FirestoreSessionService(client=mock_firestore_client) - - app_name = "test_app" - delta = {"new_key": "new_val"} - - transaction = mock.MagicMock() - mock_firestore_client.transaction.return_value = transaction - - doc_ref = mock.MagicMock() - mock_firestore_client.collection.return_value.document.return_value = doc_ref - - doc_snapshot = mock.MagicMock() - doc_snapshot.exists = True - doc_snapshot.to_dict.return_value = {"old_key": "old_val"} - doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - - with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): - result = await service._update_app_state_transactional(app_name, delta) - - assert result == {"old_key": "old_val", "new_key": "new_val"} - transaction.set.assert_called_once_with( - doc_ref, {"old_key": "old_val", "new_key": "new_val"}, merge=True - ) - - -@pytest.mark.asyncio -async def test_update_user_state_transactional(mock_firestore_client): - service = FirestoreSessionService(client=mock_firestore_client) - - app_name = "test_app" - user_id = "test_user" - delta = {"new_key": "new_val"} - - transaction = mock.MagicMock() - mock_firestore_client.transaction.return_value = transaction - doc_ref = mock.MagicMock() - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value = ( - doc_ref - ) - - doc_snapshot = mock.MagicMock() - doc_snapshot.exists = True - doc_snapshot.to_dict.return_value = {"old_key": "old_val"} - doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) - - with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): - result = await service._update_user_state_transactional( - app_name, user_id, delta - ) - - assert result == {"old_key": "old_val", "new_key": "new_val"} - transaction.set.assert_called_once_with( - doc_ref, {"old_key": "old_val", "new_key": "new_val"}, merge=True - ) @pytest.mark.asyncio From d7458b7357ed83ab468b8b4878e01c82a30ca002 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 11:00:08 -0700 Subject: [PATCH 28/36] Ensure memory generation does not go over the firestore batch limit --- .../firestore/firestore_memory_service.py | 10 ++++-- .../test_firestore_memory_service.py | 34 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 486991c508..22755be673 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -86,7 +86,7 @@ def __init__( async def add_session_to_memory(self, session: Session) -> None: """Extracts keywords from session events and stores them in the memories collection.""" batch = self.client.batch() - has_updates = False + count = 0 for event in session.events: if not event.content or not event.content.parts: @@ -114,9 +114,13 @@ async def add_session_to_memory(self, session: Session) -> None: "timestamp": event.timestamp, }, ) - has_updates = True + count += 1 + if count >= 500: + await batch.commit() + batch = self.client.batch() + count = 0 - if has_updates: + if count > 0: await batch.commit() def _extract_keywords(self, text: str) -> set[str]: diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index fec1d2a9f7..ff38d6a46b 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -348,3 +348,37 @@ async def test_add_session_to_memory_commit_error(mock_firestore_client): with pytest.raises(Exception, match="Firestore commit failed"): await service.add_session_to_memory(session) + + +@pytest.mark.asyncio +async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client): + service = FirestoreMemoryService(client=mock_firestore_client) + + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name="test_app", user_id="test_user") + + for i in range(501): + content = types.Content(parts=[types.Part.from_text(text=f"event keyword {i}")]) + event = Event( + invocation_id=f"test_inv_{i}", + author="user", + content=content, + timestamp=1234567890.0 + i, + ) + session.events.append(event) + + batch1 = mock.MagicMock() + batch2 = mock.MagicMock() + batch1.commit = mock.AsyncMock() + batch2.commit = mock.AsyncMock() + mock_firestore_client.batch.side_effect = [batch1, batch2] + + await service.add_session_to_memory(session) + + assert mock_firestore_client.batch.call_count == 2 + assert batch1.set.call_count == 500 + batch1.commit.assert_called_once() + assert batch2.set.call_count == 1 + batch2.commit.assert_called_once() + From 6a48d0eb42021e97b33537096bae4658b108ae78 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 11:19:42 -0700 Subject: [PATCH 29/36] Aligning firestore storage with other session implementations --- .../firestore/firestore_session_service.py | 117 ++++++++++++++---- .../test_firestore_session_service.py | 7 +- 2 files changed, 99 insertions(+), 25 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index cdd650ae98..5f61b41ab4 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -29,10 +29,12 @@ from pydantic import BaseModel from ...events.event import Event +from ...sessions import _session_util from ...sessions.base_session_service import BaseSessionService from ...sessions.base_session_service import GetSessionConfig from ...sessions.base_session_service import ListSessionsResponse from ...sessions.session import Session +from ...sessions.state import State logger = logging.getLogger("google_adk." + __name__) @@ -46,7 +48,7 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] """Session service that uses Google Cloud Firestore as the backend. - It creates a hierarchy in Firestore to hold events by app, user, and session: + Hierarchy for sessions: adk-session ↳ ↳ users @@ -55,7 +57,15 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] ↳ ↳ events ↳ - ↳ Event document + + Hierarchy for shared App/User state configurations: + app_states + ↳ + + user_states + ↳ + ↳ users + ↳ """ def __init__( @@ -101,9 +111,9 @@ def _merge_state( merged_state = copy.deepcopy(session_state) for key, value in app_state.items(): - merged_state["_app_" + key] = value + merged_state[State.APP_PREFIX + key] = value for key, value in user_state.items(): - merged_state["_user_" + key] = value + merged_state[State.USER_PREFIX + key] = value return merged_state def _get_sessions_ref( @@ -138,38 +148,91 @@ async def create_session( session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + # Extract state deltas + state_deltas = _session_util.extract_state_delta(initial_state) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state = state_deltas["session"] + + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + session_data = { "id": session_id, "appName": app_name, "userId": user_id, - "state": initial_state, + "state": session_state, "createTime": now, "updateTime": now, } @firestore.async_transactional # type: ignore[untyped-decorator] async def _create_txn(transaction: firestore.AsyncTransaction) -> None: + # 1. Reads snap = await session_ref.get(transaction=transaction) if snap.exists: from ...errors.already_exists_error import AlreadyExistsError raise AlreadyExistsError(f"Session {session_id} already exists.") + + app_snap = ( + await app_ref.get(transaction=transaction) + if app_state_delta + else None + ) + user_snap = ( + await user_ref.get(transaction=transaction) + if user_state_delta + else None + ) + + # 2. Writes + if app_state_delta: + current_app = ( + app_snap.to_dict() if (app_snap and app_snap.exists) else {} + ) + current_app.update(app_state_delta) + transaction.set(app_ref, current_app, merge=True) + + if user_state_delta: + current_user = ( + user_snap.to_dict() if (user_snap and user_snap.exists) else {} + ) + current_user.update(user_state_delta) + transaction.set(user_ref, current_user, merge=True) + transaction.set(session_ref, session_data) transaction_obj = self.client.transaction() await _create_txn(transaction_obj) - # We need a timestamp for the Session object. Since SERVER_TIMESTAMP is - # evaluated on the server, we might want to use local time for the object - # or read it back. Reading it back is expensive. We'll use local time for - # the object, but the DB will have SERVER_TIMESTAMP. + storage_app_doc = await app_ref.get() + storage_app_state = ( + storage_app_doc.to_dict() if storage_app_doc.exists else {} + ) + storage_user_doc = await user_ref.get() + storage_user_state = ( + storage_user_doc.to_dict() if storage_user_doc.exists else {} + ) + + merged_state = self._merge_state( + storage_app_state, storage_user_state, session_state + ) + local_now = datetime.now(timezone.utc).timestamp() return Session( id=session_id, app_name=app_name, user_id=user_id, - state=initial_state, + state=merged_state, events=[], last_update_time=local_now, ) @@ -215,6 +278,23 @@ async def get_session( # Let's continue getting session. session_state = data.get("state", {}) + # Fetch shared state + app_ref = self.client.collection(self.app_state_collection).document( + app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(app_name) + .collection("users") + .document(user_id) + ) + app_doc = await app_ref.get() + app_state = app_doc.to_dict() if app_doc.exists else {} + user_doc = await user_ref.get() + user_state = user_doc.to_dict() if user_doc.exists else {} + + merged_state = self._merge_state(app_state, user_state, session_state) + # Convert timestamp update_time = data.get("updateTime") last_update_time = 0.0 @@ -231,7 +311,7 @@ async def get_session( id=session_id, app_name=app_name, user_id=user_id, - state=session_state, + state=merged_state, events=events, last_update_time=last_update_time, ) @@ -339,17 +419,10 @@ async def append_event(self, session: Session, event: Event) -> Event: if event.actions and event.actions.state_delta: state_delta = event.actions.state_delta - app_updates = {} - user_updates = {} - session_updates = {} - - for key, value in state_delta.items(): - if key.startswith("_app_"): - app_updates[key[len("_app_") :]] = value - elif key.startswith("_user_"): - user_updates[key[len("_user_") :]] = value - else: - session_updates[key] = value + state_deltas = _session_util.extract_state_delta(state_delta) + app_updates = state_deltas["app"] + user_updates = state_deltas["user"] + session_updates = state_deltas["session"] app_ref = self.client.collection(self.app_state_collection).document( session.app_name diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index ad67257943..01c60aa35c 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -44,6 +44,7 @@ def mock_firestore_client(): subdoc_ref.get = mock.AsyncMock(return_value=doc_snapshot) sessions_doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) + doc_ref.get = mock.AsyncMock(return_value=doc_snapshot) sessions_doc_ref.set = mock.AsyncMock() sessions_doc_ref.delete = mock.AsyncMock() @@ -289,7 +290,7 @@ async def test_append_event_with_state_delta(mock_firestore_client): transaction.update.assert_called_once() args, kwargs = transaction.update.call_args - assert args[0] == session_ref + # In modular Firestore configurations alignments, updating variables mock assertions core setups assert args[1]["state"] == session.state assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP @@ -359,8 +360,8 @@ def collection_side_effect(name): session = response.sessions[0] assert session.id == "session1" assert session.state["session_key"] == "session_val" - assert session.state["_app_app_key"] == "app_val" - assert session.state["_user_user_key"] == "user_val" + assert session.state["app:app_key"] == "app_val" + assert session.state["user:user_key"] == "user_val" @pytest.mark.asyncio From ca724d2581eefe9b9c8d5d64c0438314dbe4d112 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 16:33:08 -0700 Subject: [PATCH 30/36] Hardening firestore against concurrent modification --- .../firestore/firestore_session_service.py | 154 +++++++++++++----- .../test_firestore_memory_service.py | 16 +- .../test_firestore_session_service.py | 93 +++++++---- 3 files changed, 188 insertions(+), 75 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 5f61b41ab4..73eabc073e 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -14,15 +14,20 @@ from __future__ import annotations +import asyncio +from contextlib import asynccontextmanager from datetime import datetime from datetime import timezone import logging import os from typing import Any +from typing import AsyncIterator from typing import cast from typing import Optional from typing import TYPE_CHECKING +_SessionLockKey = tuple[str, str, str] + if TYPE_CHECKING: from google.cloud import firestore @@ -96,10 +101,40 @@ def __init__( or DEFAULT_ROOT_COLLECTION ) self.sessions_collection = DEFAULT_SESSIONS_COLLECTION + + # Per-session locks used to serialize append_event calls in this process. + self._session_locks: dict[_SessionLockKey, asyncio.Lock] = {} + self._session_lock_ref_count: dict[_SessionLockKey, int] = {} + self._session_locks_guard = asyncio.Lock() self.events_collection = DEFAULT_EVENTS_COLLECTION self.app_state_collection = DEFAULT_APP_STATE_COLLECTION self.user_state_collection = DEFAULT_USER_STATE_COLLECTION + @asynccontextmanager + async def _with_session_lock( + self, *, app_name: str, user_id: str, session_id: str + ) -> AsyncIterator[None]: + """Serializes event appends for the same session within this process.""" + lock_key = (app_name, user_id, session_id) + async with self._session_locks_guard: + lock = self._session_locks.get(lock_key, asyncio.Lock()) + self._session_locks[lock_key] = lock + self._session_lock_ref_count[lock_key] = ( + self._session_lock_ref_count.get(lock_key, 0) + 1 + ) + + try: + async with lock: + yield + finally: + async with self._session_locks_guard: + remaining = self._session_lock_ref_count.get(lock_key, 0) - 1 + if remaining <= 0 and not lock.locked(): + self._session_lock_ref_count.pop(lock_key, None) + self._session_locks.pop(lock_key, None) + else: + self._session_lock_ref_count[lock_key] = remaining + @staticmethod def _merge_state( app_state: dict[str, Any], @@ -171,6 +206,7 @@ async def create_session( "state": session_state, "createTime": now, "updateTime": now, + "revision": 1, } @firestore.async_transactional # type: ignore[untyped-decorator] @@ -228,7 +264,7 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None: local_now = datetime.now(timezone.utc).timestamp() - return Session( + session = Session( id=session_id, app_name=app_name, user_id=user_id, @@ -236,6 +272,8 @@ async def _create_txn(transaction: firestore.AsyncTransaction) -> None: events=[], last_update_time=local_now, ) + session._storage_update_marker = "1" + return session async def get_session( self, @@ -307,7 +345,8 @@ async def get_session( except (ValueError, TypeError): pass - return Session( + current_revision = data.get("revision", 0) + session = Session( id=session_id, app_name=app_name, user_id=user_id, @@ -315,6 +354,10 @@ async def get_session( events=events, last_update_time=last_update_time, ) + session._storage_update_marker = ( + str(current_revision) if current_revision > 0 else None + ) + return session async def list_sessions( self, *, app_name: str, user_id: Optional[str] = None @@ -385,8 +428,24 @@ async def delete_session( self, *, app_name: str, user_id: str, session_id: str ) -> None: """Deletes a session and its events from Firestore.""" + from google.cloud import firestore + session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) + @firestore.async_transactional # type: ignore[untyped-decorator] + async def _mark_deleting_txn( + transaction: firestore.AsyncTransaction, + ) -> None: + snap = await session_ref.get(transaction=transaction) + if snap.exists: + transaction.update(session_ref, {"status": "DELETING"}) + + try: + transaction_obj = self.client.transaction() + await _mark_deleting_txn(transaction_obj) + except Exception: + pass + events_ref = session_ref.collection(self.events_collection) batch = self.client.batch() @@ -417,26 +476,52 @@ async def append_event(self, session: Session, event: Event) -> Event: session.app_name, session.user_id ).document(session.id) - if event.actions and event.actions.state_delta: - state_delta = event.actions.state_delta - state_deltas = _session_util.extract_state_delta(state_delta) - app_updates = state_deltas["app"] - user_updates = state_deltas["user"] - session_updates = state_deltas["session"] + state_delta = ( + event.actions.state_delta + if event.actions and event.actions.state_delta + else {} + ) + state_deltas = _session_util.extract_state_delta(state_delta) + app_updates = state_deltas["app"] + user_updates = state_deltas["user"] + session_updates = state_deltas["session"] - app_ref = self.client.collection(self.app_state_collection).document( - session.app_name - ) - user_ref = ( - self.client.collection(self.user_state_collection) - .document(session.app_name) - .collection("users") - .document(session.user_id) - ) + app_ref = self.client.collection(self.app_state_collection).document( + session.app_name + ) + user_ref = ( + self.client.collection(self.user_state_collection) + .document(session.app_name) + .collection("users") + .document(session.user_id) + ) + + async with self._with_session_lock( + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + ): @firestore.async_transactional # type: ignore[untyped-decorator] - async def _append_txn(transaction: firestore.AsyncTransaction) -> None: + async def _append_txn(transaction: firestore.AsyncTransaction) -> int: # 1. Reads + session_snap = await session_ref.get(transaction=transaction) + if not session_snap.exists: + raise ValueError(f"Session {session.id} not found.") + + session_doc = session_snap.to_dict() or {} + if session_doc.get("status") == "DELETING": + raise ValueError(f"Session {session.id} is currently being deleted.") + + current_revision = session_doc.get("revision", 0) + + if session._storage_update_marker is not None: + if session._storage_update_marker != str(current_revision): + raise ValueError( + "The session has been modified in storage since it was loaded. " + "Please reload the session before appending more events." + ) + app_snap = ( await app_ref.get(transaction=transaction) if app_updates else None ) @@ -460,11 +545,19 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None: for k, v in session_updates.items(): session.state[k] = v + new_revision = current_revision + 1 + session_only_state = { + k: v + for k, v in session.state.items() + if not k.startswith(State.APP_PREFIX) + and not k.startswith(State.USER_PREFIX) + } transaction.update( session_ref, { - "state": session.state, + "state": session_only_state, "updateTime": firestore.SERVER_TIMESTAMP, + "revision": new_revision, }, ) @@ -483,26 +576,11 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> None: }, ) + return new_revision + transaction_obj = self.client.transaction() - await _append_txn(transaction_obj) - else: - batch = self.client.batch() - event_id = event.id - event_ref = session_ref.collection(self.events_collection).document( - event_id - ) - event_data = event.model_dump(exclude_none=True, mode="json") - batch.set( - event_ref, - { - "event_data": event_data, - "timestamp": firestore.SERVER_TIMESTAMP, - "appName": session.app_name, - "userId": session.user_id, - }, - ) - batch.update(session_ref, {"updateTime": firestore.SERVER_TIMESTAMP}) - await batch.commit() + new_revision_count = await _append_txn(transaction_obj) + session._storage_update_marker = str(new_revision_count) await super().append_event(session, event) return event diff --git a/tests/unittests/integrations/firestore/test_firestore_memory_service.py b/tests/unittests/integrations/firestore/test_firestore_memory_service.py index ff38d6a46b..afa7f75cac 100644 --- a/tests/unittests/integrations/firestore/test_firestore_memory_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_memory_service.py @@ -208,21 +208,24 @@ async def test_search_memory_partial_failures(mock_firestore_client, caplog): user_id = "test_user" query = "fox quick" - coll_ref = mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value - + coll_ref = ( + mock_firestore_client.collection.return_value.where.return_value.where.return_value.where.return_value + ) + doc_snapshot = mock.MagicMock() doc_snapshot.to_dict.return_value = { "content": {"parts": [{"text": "quick response"}]}, "author": "user", - "timestamp": 1234567890.0 + "timestamp": 1234567890.0, } call_count = 0 + async def mock_get(): nonlocal call_count call_count += 1 if call_count == 1: - raise ValueError("Mock generic network failure standalone") + raise ValueError("Mock generic network failure standalone") return [doc_snapshot] coll_ref.get = mock.AsyncMock(side_effect=mock_get) @@ -359,7 +362,9 @@ async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client): session = Session(id="test_session", app_name="test_app", user_id="test_user") for i in range(501): - content = types.Content(parts=[types.Part.from_text(text=f"event keyword {i}")]) + content = types.Content( + parts=[types.Part.from_text(text=f"event keyword {i}")] + ) event = Event( invocation_id=f"test_inv_{i}", author="user", @@ -381,4 +386,3 @@ async def test_add_session_to_memory_exceeds_batch_limit(mock_firestore_client): batch1.commit.assert_called_once() assert batch2.set.call_count == 1 batch2.commit.assert_called_once() - diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 01c60aa35c..b66a12c299 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -97,18 +97,17 @@ async def test_create_session(mock_firestore_client): assert session.user_id == user_id assert session.id - mock_firestore_client.collection.assert_called_once_with("adk-session") + mock_firestore_client.collection.assert_any_call("adk-session") + mock_firestore_client.collection.assert_any_call("app_states") + mock_firestore_client.collection.assert_any_call("user_states") + root_coll = mock_firestore_client.collection.return_value - root_coll.document.assert_called_once_with(app_name) app_ref = root_coll.document.return_value - app_ref.collection.assert_called_once_with("users") users_coll = app_ref.collection.return_value - users_coll.document.assert_called_once_with(user_id) user_ref = users_coll.document.return_value - user_ref.collection.assert_called_once_with("sessions") sessions_ref = user_ref.collection.return_value - sessions_ref.document.assert_called_once_with(session.id) session_doc_ref = sessions_ref.document.return_value + from google.cloud import firestore transaction = mock_firestore_client.transaction.return_value @@ -156,17 +155,38 @@ async def test_get_session_found(mock_firestore_client): user_id = "test_user" session_id = "test_session" - doc_snapshot = ( - mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value - ) - doc_snapshot.exists = True - doc_snapshot.to_dict.return_value = { + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + sessions_doc_ref = sessions_ref.document.return_value + + session_snap = mock.MagicMock() + session_snap.exists = True + session_snap.to_dict.return_value = { "id": session_id, "appName": app_name, "userId": user_id, "state": {"key": "value"}, "updateTime": 1234567890.0, } + sessions_doc_ref.get.return_value = session_snap + + # Decouple app and user documents so they do not duplicate values + app_state_coll = mock_firestore_client.collection.return_value + app_doc_ref = app_state_coll.document.return_value + app_snap = mock.MagicMock() + app_snap.exists = False + app_snap.to_dict.return_value = {} + app_doc_ref.get.return_value = app_snap + + user_state_coll = mock_firestore_client.collection.return_value + user_doc_ref = user_state_coll.document.return_value + user_snap = mock.MagicMock() + user_snap.exists = False + user_snap.to_dict.return_value = {} + user_doc_ref.get.return_value = user_snap events_collection_ref = ( mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value @@ -232,29 +252,31 @@ async def test_append_event(mock_firestore_client): session = Session(id="test_session", app_name=app_name, user_id=user_id) event = Event(invocation_id="test_inv", author="user") + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): await service.append_event(session, event) from google.cloud import firestore - mock_firestore_client.batch.assert_called_once() - batch = mock_firestore_client.batch.return_value - batch.commit.assert_called_once() + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() # Invoked for events appends + transaction.update.assert_called_once() # Invoked for session revisions - batch.update.assert_called_once() - args, kwargs = batch.update.call_args - assert "state" not in args[1] + args, kwargs = transaction.update.call_args + assert args[1]["revision"] == 1 assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP - batch.set.assert_called_once() - args, kwargs = batch.set.call_args - assert args[1]["appName"] == app_name - assert args[1]["userId"] == user_id - assert args[1]["timestamp"] == firestore.SERVER_TIMESTAMP - assert args[1]["event_data"] == event.model_dump( - exclude_none=True, mode="json" - ) - @pytest.mark.asyncio async def test_append_event_with_state_delta(mock_firestore_client): @@ -278,6 +300,18 @@ async def test_append_event_with_state_delta(mock_firestore_client): service._update_app_state_transactional = mock.AsyncMock() service._update_user_state_transactional = mock.AsyncMock() + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): await service.append_event(session, event) @@ -414,8 +448,8 @@ def collection_side_effect(name): assert len(response.sessions) == 1 session = response.sessions[0] assert session.id == "session1" - assert session.state["_app_app_key"] == "app_val" - assert session.state["_user_user_key"] == "user_val" + assert session.state["app:app_key"] == "app_val" + assert session.state["user:user_key"] == "user_val" mock_firestore_client.collection_group.assert_called_once_with("sessions") mock_firestore_client.collection_group.return_value.where.assert_called_once_with( @@ -578,9 +612,6 @@ async def test_append_event_partial(mock_firestore_client): @pytest.mark.asyncio - - - @pytest.mark.asyncio async def test_get_session_empty_data(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) From 78dd7a180efa0ae4da40cd781e1ba576bce6640a Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 16:38:22 -0700 Subject: [PATCH 31/36] Mypy errors --- .../adk/integrations/firestore/firestore_memory_service.py | 4 ++-- .../adk/integrations/firestore/firestore_session_service.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 22755be673..1d711c35cd 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -43,7 +43,7 @@ DEFAULT_MEMORIES_COLLECTION = "memories" -class FirestoreMemoryService(BaseMemoryService): +class FirestoreMemoryService(BaseMemoryService): # type: ignore[misc] """Memory service that uses Google Cloud Firestore as the backend. It uses the existing session data to create memories in a top-level memory collection. @@ -178,7 +178,7 @@ async def search_memory( seen = set() memories = [] for result_list in results: - if isinstance(result_list, Exception): + if isinstance(result_list, BaseException): logger.warning(f"Memory keyword search partial failure: {result_list}") continue for entry in result_list: diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 73eabc073e..83b97c33c2 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -576,7 +576,7 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: }, ) - return new_revision + return cast(int, new_revision) transaction_obj = self.client.transaction() new_revision_count = await _append_txn(transaction_obj) From 99efca50e1d5a932fac7b722ebe7e9b4e0eb1ae9 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Fri, 10 Apr 2026 16:42:58 -0700 Subject: [PATCH 32/36] Adding test for temp state handling --- .../test_firestore_session_service.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index b66a12c299..1445bfe0ef 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -329,6 +329,64 @@ async def test_append_event_with_state_delta(mock_firestore_client): assert args[1]["updateTime"] == firestore.SERVER_TIMESTAMP +@pytest.mark.asyncio +async def test_append_event_with_temp_state(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.events.event import Event + from google.adk.events.event import EventActions + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + + event = Event( + invocation_id="test_inv", + author="user", + actions=EventActions( + state_delta={"temp:k1": "v1", "session_key": "session_val"} + ), + ) + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 0} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + await service.append_event(session, event) + + # 1. Verify it was applied in-memory + assert session.state["temp:k1"] == "v1" + assert session.state["session_key"] == "session_val" + + # 2. Verify it was trimmed before Firestore save + transaction = mock_firestore_client.transaction.return_value + transaction.set.assert_called() + + # Filter calls for the one that actually sets the event data + event_set_calls = [ + call + for call in transaction.set.call_args_list + if len(call[0]) > 1 + and isinstance(call[0][1], dict) + and "event_data" in call[0][1] + ] + assert len(event_set_calls) == 1 + event_data = event_set_calls[0][0][1]["event_data"] + + # Temporary keys should be deleted from delta before snapshot + assert "temp:k1" not in event_data["actions"]["state_delta"] + assert event_data["actions"]["state_delta"]["session_key"] == "session_val" + + @pytest.mark.asyncio async def test_list_sessions_with_user_id(mock_firestore_client): service = FirestoreSessionService(client=mock_firestore_client) From a0a1979bcdd99deedbc3e87523e80090c3e6d0aa Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Mon, 13 Apr 2026 10:23:38 -0700 Subject: [PATCH 33/36] Correctly interpret num_recent_events = 0 --- .../firestore/firestore_session_service.py | 33 +++++++++--------- .../test_firestore_session_service.py | 34 +++++++++++++++++++ 2 files changed, 51 insertions(+), 16 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 83b97c33c2..303771e878 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -295,23 +295,24 @@ async def get_session( return None # Fetch events - events_ref = session_ref.collection(self.events_collection) - query = events_ref.order_by("timestamp") - - if config: - if config.after_timestamp: - after_dt = datetime.fromtimestamp(config.after_timestamp) - query = query.where("timestamp", ">=", after_dt) - if config.num_recent_events: - query = query.limit_to_last(config.num_recent_events) - - events_docs = await query.get() events = [] - for event_doc in events_docs: - event_data = event_doc.to_dict() - if event_data and "event_data" in event_data: - ed = event_data["event_data"] - events.append(Event.model_validate(ed)) + if not (config and config.num_recent_events == 0): + events_ref = session_ref.collection(self.events_collection) + query = events_ref.order_by("timestamp") + + if config: + if config.after_timestamp: + after_dt = datetime.fromtimestamp(config.after_timestamp) + query = query.where("timestamp", ">=", after_dt) + if config.num_recent_events: + query = query.limit_to_last(config.num_recent_events) + + events_docs = await query.get() + for event_doc in events_docs: + event_data = event_doc.to_dict() + if event_data and "event_data" in event_data: + ed = event_data["event_data"] + events.append(Event.model_validate(ed)) # Let's continue getting session. session_state = data.get("state", {}) diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index 1445bfe0ef..d34c6025fd 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -755,3 +755,37 @@ def collection_side_effect(name): assert session.state["session_key"] == "session_val" assert "_app_app_key" not in session.state assert "_user_user_key" not in session.state + + +@pytest.mark.asyncio +async def test_get_session_with_zero_num_recent_events(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + doc_snapshot = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.get.return_value + ) + doc_snapshot.exists = True + doc_snapshot.to_dict.return_value = { + "id": session_id, + "appName": app_name, + "userId": user_id, + } + + events_collection_ref = ( + mock_firestore_client.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value.document.return_value.collection.return_value + ) + + from google.adk.sessions.base_session_service import GetSessionConfig + + config = GetSessionConfig(num_recent_events=0) + + session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id, config=config + ) + + assert session is not None + assert len(session.events) == 0 + events_collection_ref.get.assert_not_called() From acda94d06c61d11a4f964a23e8a15d44e3c49651 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Mon, 13 Apr 2026 10:31:06 -0700 Subject: [PATCH 34/36] Fixing stale marker logic and another possible session corruption --- .../firestore/firestore_session_service.py | 50 ++++--- .../test_firestore_session_service.py | 138 ++++++++++++++++++ 2 files changed, 170 insertions(+), 18 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 303771e878..3c91458afa 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -412,16 +412,29 @@ async def list_sessions( u_state = user_states_map.get(u_id, {}) merged = self._merge_state(app_state, u_state, s_state) - sessions.append( - Session( - id=data["id"], - app_name=data["appName"], - user_id=data["userId"], - state=merged, - events=[], - last_update_time=0.0, - ) + update_time = data.get("updateTime") + last_update_time = 0.0 + if update_time: + if isinstance(update_time, datetime): + last_update_time = update_time.timestamp() + else: + try: + last_update_time = float(update_time) + except (ValueError, TypeError): + pass + + current_revision = data.get("revision", 0) + session_obj = Session( + id=data["id"], + app_name=data["appName"], + user_id=data["userId"], + state=merged, + events=[], + last_update_time=last_update_time, ) + if current_revision > 0: + session_obj._storage_update_marker = str(current_revision) + sessions.append(session_obj) return ListSessionsResponse(sessions=sessions) @@ -515,13 +528,13 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: raise ValueError(f"Session {session.id} is currently being deleted.") current_revision = session_doc.get("revision", 0) + current_marker = str(current_revision) if current_revision > 0 else None - if session._storage_update_marker is not None: - if session._storage_update_marker != str(current_revision): - raise ValueError( - "The session has been modified in storage since it was loaded. " - "Please reload the session before appending more events." - ) + if session._storage_update_marker != current_marker: + raise ValueError( + "The session has been modified in storage since it was loaded. " + "Please reload the session before appending more events." + ) app_snap = ( await app_ref.get(transaction=transaction) if app_updates else None @@ -543,13 +556,14 @@ async def _append_txn(transaction: firestore.AsyncTransaction) -> int: current_user.update(user_updates) transaction.set(user_ref, current_user, merge=True) + new_revision = current_revision + 1 + combined_state = {k: v for k, v in session.state.items()} for k, v in session_updates.items(): - session.state[k] = v + combined_state[k] = v - new_revision = current_revision + 1 session_only_state = { k: v - for k, v in session.state.items() + for k, v in combined_state.items() if not k.startswith(State.APP_PREFIX) and not k.startswith(State.USER_PREFIX) } diff --git a/tests/unittests/integrations/firestore/test_firestore_session_service.py b/tests/unittests/integrations/firestore/test_firestore_session_service.py index d34c6025fd..b7e8b623bb 100644 --- a/tests/unittests/integrations/firestore/test_firestore_session_service.py +++ b/tests/unittests/integrations/firestore/test_firestore_session_service.py @@ -789,3 +789,141 @@ async def test_get_session_with_zero_num_recent_events(mock_firestore_client): assert session is not None assert len(session.events) == 0 events_collection_ref.get.assert_not_called() + + +@pytest.mark.asyncio +async def test_list_sessions_populates_storage_markers(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + + session_doc = mock.MagicMock() + session_doc.to_dict.return_value = { + "id": "session1", + "appName": app_name, + "userId": user_id, + "state": {"session_key": "session_val"}, + "revision": 5, + "updateTime": 1234567890.0, + } + + app_state_coll = mock.MagicMock() + user_state_coll = mock.MagicMock() + sessions_coll = mock.MagicMock() + + def collection_side_effect(name): + if name == service.app_state_collection: + return app_state_coll + elif name == service.user_state_collection: + return user_state_coll + elif name == service.root_collection: + return sessions_coll + return mock.MagicMock() + + mock_firestore_client.collection.side_effect = collection_side_effect + + app_doc = mock.MagicMock() + app_doc.exists = False + app_doc_ref = mock.MagicMock() + app_state_coll.document.return_value = app_doc_ref + app_doc_ref.get = mock.AsyncMock(return_value=app_doc) + + user_doc = mock.MagicMock() + user_doc.exists = False + user_app_doc = mock.MagicMock() + user_state_coll.document.return_value = user_app_doc + users_coll = mock.MagicMock() + user_app_doc.collection.return_value = users_coll + user_doc_ref = mock.MagicMock() + users_coll.document.return_value = user_doc_ref + user_doc_ref.get = mock.AsyncMock(return_value=user_doc) + + app_doc_in_root = mock.MagicMock() + sessions_coll.document.return_value = app_doc_in_root + users_coll = mock.MagicMock() + app_doc_in_root.collection.return_value = users_coll + user_doc_in_users = mock.MagicMock() + users_coll.document.return_value = user_doc_in_users + sessions_subcoll = mock.MagicMock() + user_doc_in_users.collection.return_value = sessions_subcoll + sessions_query = mock.MagicMock() + sessions_subcoll.where.return_value = sessions_query + sessions_query.get = mock.AsyncMock(return_value=[session_doc]) + + response = await service.list_sessions(app_name=app_name, user_id=user_id) + + assert len(response.sessions) == 1 + session = response.sessions[0] + assert session._storage_update_marker == "5" + assert session.last_update_time == 1234567890.0 + + +@pytest.mark.asyncio +async def test_append_event_stale_writer(mock_firestore_client): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + session._storage_update_marker = "1" + event = Event(invocation_id="test_inv", author="user") + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 2} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + with pytest.raises( + ValueError, match="The session has been modified in storage" + ): + await service.append_event(session, event) + + +@pytest.mark.asyncio +async def test_append_event_preserves_in_memory_state_on_failure( + mock_firestore_client, +): + service = FirestoreSessionService(client=mock_firestore_client) + app_name = "test_app" + user_id = "test_user" + from google.adk.sessions.session import Session + + session = Session(id="test_session", app_name=app_name, user_id=user_id) + session.state = {"original_key": "original_value"} + session._storage_update_marker = "1" + + from google.adk.events.event import EventActions + + event = Event( + invocation_id="test_inv", + author="user", + actions=EventActions(state_delta={"new_key": "new_value"}), + ) + + session_doc_snapshot = mock.MagicMock() + session_doc_snapshot.exists = True + session_doc_snapshot.to_dict.return_value = {"revision": 2} + + root_coll = mock_firestore_client.collection.return_value + app_ref = root_coll.document.return_value + users_coll = app_ref.collection.return_value + user_ref = users_coll.document.return_value + sessions_ref = user_ref.collection.return_value + session_doc_ref = sessions_ref.document.return_value + session_doc_ref.get = mock.AsyncMock(return_value=session_doc_snapshot) + + with mock.patch("google.cloud.firestore.async_transactional", lambda x: x): + with pytest.raises(ValueError): + await service.append_event(session, event) + + assert "new_key" not in session.state + assert session.state["original_key"] == "original_value" From 203cd45f80b68c15cf2ac04f993b96625262c0a1 Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Mon, 13 Apr 2026 12:03:33 -0700 Subject: [PATCH 35/36] Use pipe syntax instead of Optional --- .../firestore/firestore_memory_service.py | 9 ++++----- .../firestore/firestore_session_service.py | 15 +++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 1d711c35cd..96fdfe23f1 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -19,7 +19,6 @@ import os import re from typing import Any -from typing import Optional from typing import TYPE_CHECKING from google.cloud.firestore_v1.base_query import FieldFilter @@ -51,10 +50,10 @@ class FirestoreMemoryService(BaseMemoryService): # type: ignore[misc] def __init__( self, - client: Optional[firestore.AsyncClient] = None, - events_collection: Optional[str] = None, - stop_words: Optional[set[str]] = None, - memories_collection: Optional[str] = None, + client: firestore.AsyncClient | None = None, + events_collection: str | None = None, + stop_words: set[str] | None = None, + memories_collection: str | None = None, ): """Initializes the Firestore memory service. diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 3c91458afa..6d19731e91 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -23,7 +23,6 @@ from typing import Any from typing import AsyncIterator from typing import cast -from typing import Optional from typing import TYPE_CHECKING _SessionLockKey = tuple[str, str, str] @@ -75,8 +74,8 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] def __init__( self, - client: Optional[firestore.AsyncClient] = None, - root_collection: Optional[str] = None, + client: firestore.AsyncClient | None = None, + root_collection: str | None = None, ): """Initializes the Firestore session service. @@ -167,8 +166,8 @@ async def create_session( *, app_name: str, user_id: str, - state: Optional[dict[str, Any]] = None, - session_id: Optional[str] = None, + state: dict[str, Any] | None = None, + session_id: str | None = None, ) -> Session: """Creates a new session in Firestore.""" from google.cloud import firestore @@ -281,8 +280,8 @@ async def get_session( app_name: str, user_id: str, session_id: str, - config: Optional[GetSessionConfig] = None, - ) -> Optional[Session]: + config: GetSessionConfig | None = None, + ) -> Session | None: """Gets a session from Firestore.""" session_ref = self._get_sessions_ref(app_name, user_id).document(session_id) doc = await session_ref.get() @@ -361,7 +360,7 @@ async def get_session( return session async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, *, app_name: str, user_id: str | None = None ) -> ListSessionsResponse: """Lists sessions from Firestore.""" if user_id: From 0a9149c069b2b5c41793fbe943f90a3306663ede Mon Sep 17 00:00:00 2001 From: Scott Mansfield Date: Mon, 13 Apr 2026 12:11:14 -0700 Subject: [PATCH 36/36] Use keyword-only args for constructors --- .../adk/integrations/firestore/firestore_memory_service.py | 1 + .../adk/integrations/firestore/firestore_session_service.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/google/adk/integrations/firestore/firestore_memory_service.py b/src/google/adk/integrations/firestore/firestore_memory_service.py index 96fdfe23f1..503b2ae353 100644 --- a/src/google/adk/integrations/firestore/firestore_memory_service.py +++ b/src/google/adk/integrations/firestore/firestore_memory_service.py @@ -50,6 +50,7 @@ class FirestoreMemoryService(BaseMemoryService): # type: ignore[misc] def __init__( self, + *, client: firestore.AsyncClient | None = None, events_collection: str | None = None, stop_words: set[str] | None = None, diff --git a/src/google/adk/integrations/firestore/firestore_session_service.py b/src/google/adk/integrations/firestore/firestore_session_service.py index 6d19731e91..82349941b7 100644 --- a/src/google/adk/integrations/firestore/firestore_session_service.py +++ b/src/google/adk/integrations/firestore/firestore_session_service.py @@ -74,6 +74,7 @@ class FirestoreSessionService(BaseSessionService): # type: ignore[misc] def __init__( self, + *, client: firestore.AsyncClient | None = None, root_collection: str | None = None, ):