diff --git a/MULTIMODAL_PLAN.md b/MULTIMODAL_PLAN.md deleted file mode 100644 index 5a25701..0000000 --- a/MULTIMODAL_PLAN.md +++ /dev/null @@ -1,936 +0,0 @@ -# Agentflow Multimodal Support — Master Plan - -## Research Summary - -### How ADK Handles Multimodal - -Google ADK uses `google.genai.types.Part` as the universal content unit: -- **Inline data**: `types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg")` — for files < 20MB -- **File API upload**: `client.files.upload(file="path.jpg")` → returns a `types.File` reference — for large/reusable files -- **PIL Image**: Pillow Image objects are auto-converted by the SDK -- **Artifacts**: Binary data stored via `ArtifactService` (in-memory or GCS), versioned, identified by filename + namespace (session or user scope), represented as `types.Part(inline_data=types.Blob(data=bytes, mime_type="..."))` -- **Supported formats**: PNG, JPEG, WEBP, HEIC, HEIF for images; PDF via document processing; audio via PCM blobs - -### How LangChain Handles Multimodal - -LangChain v1 uses standard content blocks in `HumanMessage.content`: -```python -# Image via URL -{"type": "image", "url": "https://example.com/image.jpg"} -# Image via base64 -{"type": "image", "base64": "...", "mime_type": "image/jpeg"} -# Image via file_id -{"type": "image", "file_id": "file-abc123"} -# PDF document -{"type": "file", "base64": "...", "mime_type": "application/pdf"} -# Audio -{"type": "audio", "base64": "...", "mime_type": "audio/wav"} -``` -Provider-native formats also supported (OpenAI's `image_url` type, etc.) - -### How OpenAI API Expects Multimodal - -```python -messages = [{ - "role": "user", - "content": [ - {"type": "text", "text": "What is in this image?"}, - {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{base64_data}"}}, - # OR - {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}, - ] -}] -# PDFs: use file_search or pass base64 as image -# Audio input: {"type": "input_audio", "input_audio": {"data": "...", "format": "wav"}} -``` - -### How Google GenAI API Expects Multimodal - -```python -from google.genai import types -contents = [ - types.Content(role="user", parts=[ - types.Part(text="What is this?"), - types.Part.from_bytes(data=img_bytes, mime_type="image/jpeg"), - # OR via File API - types.Part(file_data=types.FileData(file_uri="...", mime_type="...")), - ]) -] -``` - ---- - -## Current State Analysis - -### What Exists ✅ -| Component | Status | Notes | -|-----------|--------|-------| -| `MessageBlock` types | ✅ Defined | `ImageBlock`, `AudioBlock`, `VideoBlock`, `DocumentBlock`, `DataBlock` all exist in `message_block.py` | -| `MediaRef` model | ✅ Defined | Supports `url`, `file_id`, `data_base64` with mime_type, size, dimensions | -| `Message.attach_media()` | ✅ Exists | Can append media blocks to message content | -| `ContentBlock` union | ✅ Defined | Discriminated union of all block types | -| Response converters | ✅ Partial | `OpenAIConverter` and `GoogleGenAIConverter` can extract images from responses | -| `TokenUsages.image_tokens` | ✅ Defined | Field exists for multimodal token tracking | - -### What's Missing ❌ -| Component | Gap | Impact | -|-----------|-----|--------| -| **`_convert_dict()` in converter.py** | Only extracts `.text()` — all media blocks are **silently dropped** | Images/audio/docs in messages are never sent to LLM | -| **`_handle_regular_message()` in google.py** | Only wraps text in `types.Part(text=...)` — no media | Google provider ignores all multimodal content | -| **OpenAI message format** | Messages only contain string `content` — no `content: [...]` array | OpenAI provider ignores all multimodal content | -| **File upload API endpoint** | No endpoint to upload files (images, PDFs, docs) | No way for clients to send files | -| **Document extraction** | No PDF/DOCX text extraction utilities | Can't read documents to pass as text | -| **Media storage backend** | No binary file storage service — blobs would be inlined in DB | Uploaded files have nowhere to persist | -| **Multimodal config** | No per-agent config for how to handle images (base64 vs PIL vs file_id) | No flexibility | -| **Document processing config** | No config for whether to extract text from PDFs or pass raw to AI | No flexibility | -| **Input message creation helpers** | No convenience API to create multimodal messages easily | Poor DX | -| **MediaStore → MediaRef pipeline** | No mechanism to store blobs externally and reference them in messages | Blobs would be inlined into state/checkpointer, bloating DB | - -### Critical Problem: State & Checkpointer with Binary Data ⚠️ - -Current serialization paths that would be affected by naive inline base64: - -``` -Path 1: AgentState → PG states table (JSONB) - state.model_dump() → json.dumps() → INSERT INTO states(state_data JSONB) - ⚠️ AgentState.context = list[Message] → all messages serialized into ONE JSONB blob - ⚠️ Each state save RE-SERIALIZES ALL messages including ALL images from ALL turns - ⚠️ A 10-turn conversation with 3 images ≈ 30MB+ in a single JSONB cell, growing every turn - -Path 2: Messages → PG messages table (TEXT) - [block.model_dump(mode="json") for block in message.content] → json.dumps() → INSERT INTO messages(content TEXT) - ⚠️ Each image block with data_base64 ≈ 1-5MB stored as TEXT per message row - -Path 3: AgentState → Redis cache - state.model_dump() → json.dumps() → Redis SETEX - ⚠️ Same massive state with all images goes into Redis, evicted by TTL = 24h - -Path 4: InMemoryCheckpointer → Python dict - self._states[key] = state (holds full Python objects in memory) - ⚠️ All images kept as Python objects in process memory forever -``` - -**Result**: One 1MB image creates 3+ copies across PG JSONB, PG TEXT, and Redis. -A realistic 20-message conversation with 5 images = 60-100MB database footprint, re-serialized on every state save. - ---- - -## Architecture Design - -### Design Principles -1. **Never store blobs in the database** — Binary data goes to `MediaStore`; only tiny `MediaRef` references live in messages/state -2. **Provider-agnostic content model** — `Message` with `ContentBlock` types is the universal format -3. **Configurable processing** — Developer controls image handling (base64/url/file_id) and document handling (extract text vs pass raw) -4. **Maximum flexibility** — Support all input methods: URL, base64, file path, PIL Image, bytes, file_id -5. **Lazy conversion** — Content stays as `ContentBlock` until the last mile (provider call), then converts to provider-specific format -6. **Checkpointer stays unchanged** — The fix is what goes INTO messages, not how they're stored - -### The Reference Pattern: How State & Checkpointer Work - -The `MediaRef` model already has the right design. The fix is a `BaseMediaStore` layer that stores blobs externally and converts them to lightweight references BEFORE they enter the message. - -``` -┌─────────────────────────────────────────────────────────────────────┐ -│ INGEST BOUNDARY (where blobs enter the system) │ -│ │ -│ Option A: API upload │ -│ POST /v1/files/upload (multipart) │ -│ → MediaStore.store(bytes, mime_type) → storage_key "abc123" │ -│ → MediaRef(kind="url", url="agentflow://media/abc123") │ -│ │ -│ Option B: SDK usage │ -│ msg = Message.with_image(bytes, mime_type, media_store=store) │ -│ → MediaStore.store(bytes, mime_type) → storage_key "abc123" │ -│ → MediaRef(kind="url", url="agentflow://media/abc123") │ -│ │ -│ Option C: External URL (no storage needed) │ -│ MediaRef(kind="url", url="https://cdn.example.com/img.jpg") │ -│ │ -│ Option D: Small inline (< threshold, e.g. 50KB) │ -│ MediaRef(kind="data", data_base64="...", mime_type="image/png") │ -└─────────────────────────────────────────────────────────────────────┘ - │ - ▼ - Message.content = [TextBlock("describe this"), ImageBlock(media=ref)] - │ - ┌────────────────────────┼────────────────────────┐ - ▼ ▼ ▼ - ┌──────────────┐ ┌───────────────┐ ┌──────────────────┐ - │ PG states │ │ PG messages │ │ Redis cache │ - │ table JSONB │ │ table TEXT │ │ │ - │ │ │ │ │ │ - │ MediaRef is │ │ MediaRef is │ │ MediaRef is │ - │ ~100 bytes: │ │ ~100 bytes │ │ ~100 bytes │ - │ {kind:"url", │ │ │ │ │ - │ url:"ag:// │ │ NOT 1-5MB │ │ NOT 1-5MB │ - │ media/abc"} │ │ base64! │ │ base64! │ - └──────────────┘ └───────────────┘ └──────────────────┘ - │ - ▼ - ┌────────────────────────────────────────────────┐ - │ LLM BOUNDARY (converter resolves references) │ - │ │ - │ MediaRef(kind="url", url="agentflow://...") │ - │ → MediaStore.retrieve(key) → raw bytes │ - │ → OpenAI: base64 data URL │ - │ → Google: types.Part.from_bytes(bytes) │ - │ │ - │ MediaRef(kind="url", url="https://...") │ - │ → OpenAI: pass URL directly │ - │ → Google: types.Part.from_uri(uri) │ - │ │ - │ MediaRef(kind="data", data_base64="...") │ - │ → OpenAI: data:mime;base64,{data} │ - │ → Google: types.Part.from_bytes(decoded) │ - └────────────────────────────────────────────────┘ -``` - -### BaseMediaStore Interface - -```python -class BaseMediaStore(ABC): - """Abstract interface for storing binary media outside the message system.""" - - @abstractmethod - async def store(self, data: bytes, mime_type: str, metadata: dict | None = None) -> str: - """Store binary data, return a storage_key (opaque string).""" - - @abstractmethod - async def retrieve(self, storage_key: str) -> tuple[bytes, str]: - """Retrieve binary data and mime_type by storage_key.""" - - @abstractmethod - async def delete(self, storage_key: str) -> bool: - """Delete stored media. Returns True if deleted.""" - - @abstractmethod - async def exists(self, storage_key: str) -> bool: - """Check if media exists in store.""" - - def to_media_ref(self, storage_key: str, mime_type: str, **kwargs) -> MediaRef: - """Convert a storage key to a MediaRef for embedding in messages.""" - return MediaRef( - kind="url", - url=f"agentflow://media/{storage_key}", - mime_type=mime_type, - **kwargs, - ) -``` - -### MediaStore Implementations - -| Implementation | Backend | Use Case | -|----------------|---------|----------| -| `InMemoryMediaStore` | Python `dict[str, tuple[bytes, str]]` | Tests, ephemeral scripts | -| `LocalFileMediaStore` | Filesystem (configurable base path) | Dev, single-server deployments | -| `S3MediaStore` / `CloudMediaStore` | S3/MinIO/GCS compatible via `cloud-storage-manager` | Production, multi-instance | -| `PgBlobStore` | PostgreSQL `bytea` in separate `media` table | PG-only deployments (avoids S3) | - -Even `PgBlobStore` stores blobs in a **separate `media` table** with `bytea` column — never inside the `states` or `messages` JSONB. This keeps the core tables lean and the media data separately manageable (can be vacuumed, archived, or migrated to S3 later). - -### MediaRef Resolution Strategy - -The converter layer resolves `MediaRef` → provider format at LLM call time: - -```python -class MediaRefResolver: - """Resolves MediaRef objects to actual binary data or URLs for provider APIs.""" - - def __init__(self, media_store: BaseMediaStore | None = None): - self.media_store = media_store - - async def resolve_for_openai(self, ref: MediaRef) -> dict: - """Convert MediaRef to OpenAI content part.""" - if ref.kind == "url" and ref.url and ref.url.startswith("agentflow://media/"): - key = ref.url.removeprefix("agentflow://media/") - data, mime = await self.media_store.retrieve(key) - b64 = base64.b64encode(data).decode() - return {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}} - elif ref.kind == "url" and ref.url: - return {"type": "image_url", "image_url": {"url": ref.url}} - elif ref.kind == "data" and ref.data_base64: - return {"type": "image_url", "image_url": {"url": f"data:{ref.mime_type};base64,{ref.data_base64}"}} - elif ref.kind == "file_id": - # Provider-native file reference - return {"type": "image_url", "image_url": {"url": ref.url or ref.file_id}} - - async def resolve_for_google(self, ref: MediaRef) -> Any: - """Convert MediaRef to Google types.Part.""" - from google.genai import types - if ref.kind == "url" and ref.url and ref.url.startswith("agentflow://media/"): - key = ref.url.removeprefix("agentflow://media/") - data, mime = await self.media_store.retrieve(key) - return types.Part.from_bytes(data=data, mime_type=mime) - elif ref.kind == "url" and ref.url: - return types.Part.from_uri(file_uri=ref.url, mime_type=ref.mime_type) - elif ref.kind == "data" and ref.data_base64: - data = base64.b64decode(ref.data_base64) - return types.Part.from_bytes(data=data, mime_type=ref.mime_type) - elif ref.kind == "file_id": - return types.Part(file_data=types.FileData(file_uri=ref.file_id, mime_type=ref.mime_type)) -``` - -### Inline Data Guard (Optional Safety Net) - -For safety, add an optional pre-save hook that prevents large inline base64 from accidentally entering the checkpointer: - -```python -class MediaOffloadPolicy(str, Enum): - NEVER = "never" # Allow inline base64 (testing/small images) - THRESHOLD = "threshold" # Offload if > max_inline_bytes (default) - ALWAYS = "always" # Always offload to MediaStore - -class MultimodalConfig(BaseModel): - # ...existing fields... - offload_policy: MediaOffloadPolicy = MediaOffloadPolicy.THRESHOLD - max_inline_bytes: int = 50_000 # ~50KB — below this, inline is fine -``` - -If a message enters the system with a large `data_base64` and a `MediaStore` is configured, the system can: -1. Log a warning ("Large inline media detected, consider using MediaStore") -2. Optionally auto-offload: store to `MediaStore`, replace `MediaRef(kind="data")` with `MediaRef(kind="url")` - -This is NOT done in the checkpointer itself — it happens at the message ingestion boundary. - -### What Changes in Checkpointer: **Nothing** - -| Checkpointer Component | Change Needed? | Why | -|------------------------|---------------|-----| -| `BaseCheckpointer` | ❌ No change | Interface stays the same | -| `InMemoryCheckpointer` | ❌ No change | Python objects, just holds references | -| `PgCheckpointer.aput_state()` | ❌ No change | `state.model_dump()` → JSONB still works; MediaRef is tiny | -| `PgCheckpointer.aput_messages()` | ❌ No change | `block.model_dump()` works; ImageBlock/MediaRef serialize fine | -| `PgCheckpointer._row_to_message()` | ❌ No change | Pydantic `model_validate` already deserializes ImageBlock/MediaRef | -| Redis cache | ❌ No change | Same tiny JSON references | -| DB schema | ❌ No change* | *Optional: add `media` table for metadata, not required | - -The key insight: **fix the input, not the storage.** If only references enter messages, the existing serialization pipeline handles everything perfectly. - -### Content Flow - -``` -Client Upload PyAgenity Core Provider API -───────────── ────────────── ──────────── - -image/pdf/docx ──► API endpoint ──► MediaProcessor ──► Message - (FastAPI) - validate │ - - file upload - store binary │ - - base64 - create MediaRef │ - - URL │ - ▼ - ContentBlock - (ImageBlock, - DocumentBlock, - etc.) - │ - ┌──────────────────────┤ - ▼ ▼ - convert_dict() Google format - (OpenAI format) (types.Part) - │ │ - ▼ ▼ - OpenAI API Gemini API -``` - -### Design Philosophy: Library vs API - -| Layer | Extraction? | Reason | -|-------|-------------|--------| -| **PyAgenity (core library)** | ❌ No extraction | Keeps the library lightweight; developers choose their own tools for document handling when using the SDK directly | -| **pyagenity-api (platform)** | ✅ Auto-extracts | When developers use the hosted API, extraction is handled transparently using `textxtract`. They just upload files and get back AI responses. | - -> **Rule**: If you're using PyAgenity as a library, you control how documents are converted to text and pass the result as a `TextBlock`. If you're using the API platform, upload the file and the API handles extraction automatically via `textxtract`. - -### New Components - -``` -agentflow/ # PyAgenity core — NO extraction logic -├── media/ # NEW: Media processing & storage module -│ ├── __init__.py -│ ├── config.py # MultimodalConfig, ImageHandling, DocumentHandling, MediaOffloadPolicy -│ ├── processor.py # MediaProcessor: validate mime type, size, resize images -│ ├── resolver.py # MediaRefResolver: resolve MediaRef → provider format at LLM call time -│ └── storage/ # Binary blob storage backends (NOT in checkpointer DB) -│ ├── __init__.py -│ ├── base.py # BaseMediaStore: store/retrieve/delete/exists/to_media_ref -│ ├── memory_store.py # InMemoryMediaStore — dict-based (dev/test) -│ ├── local_store.py # LocalFileMediaStore — filesystem (dev/single-server) -│ ├── cloud_store.py # CloudMediaStore — S3/GCS via cloud-storage-manager (production) -│ └── pg_store.py # PgBlobStore — separate PG BYTEA table (PG-only deployments) - -agentflow_cli/ # pyagenity-api — document extraction lives HERE -├── media/ # NEW: API-side media handling -│ ├── __init__.py -│ ├── extractor.py # DocumentExtractor: wraps textxtract AsyncTextExtractor -│ └── pipeline.py # DocumentPipeline: upload → extract → inject into message -``` - ---- - -## Sprint Plan - -### Sprint 1: Core Multimodal Pipeline (PyAgenity) — Foundation -**Goal**: Make images work end-to-end through the agent pipeline - -- [x] **1.1** Create `agentflow/media/__init__.py` and `agentflow/media/config.py` - - `MultimodalConfig` pydantic model: - ```python - class ImageHandling(str, Enum): - BASE64 = "base64" # Inline base64 in API call - URL = "url" # Pass URL reference - FILE_ID = "file_id" # Use provider file upload API - - class DocumentHandling(str, Enum): - EXTRACT_TEXT = "extract_text" # Read PDF/DOCX → pass as text - PASS_RAW = "pass_raw" # Pass binary to AI (if provider supports) - SKIP = "skip" # Don't send documents to AI - - class MultimodalConfig(BaseModel): - image_handling: ImageHandling = ImageHandling.BASE64 - document_handling: DocumentHandling = DocumentHandling.EXTRACT_TEXT - max_image_size_mb: float = 10.0 - max_image_dimension: int = 2048 # Resize if larger - supported_image_types: set[str] = {"image/jpeg", "image/png", "image/webp", "image/gif"} - supported_doc_types: set[str] = {"application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"} - ``` - -- [x] **1.2** Update `_convert_dict()` in `agentflow/utils/converter.py` - - Convert `ImageBlock` → OpenAI's `{"type": "image_url", "image_url": {"url": "..."}}` format - - Convert `AudioBlock` → `{"type": "input_audio", "input_audio": {"data": "...", "format": "..."}}` - - Convert `DocumentBlock` → text block (if extract_text mode) or image_url (if pass_raw mode for PDF) - - Return `content` as a **list of content parts** (not a string) when multimodal blocks are present - - Keep backward compat: text-only messages still return `{"content": "string"}` - -- [x] **1.3** Update `_handle_regular_message()` in `agentflow/graph/agent_internal/google.py` - - When message `content` is a list of parts (from convert_dict), convert each to `types.Part`: - - Text → `types.Part(text=...)` - - Image URL → `types.Part.from_uri(file_uri=url, mime_type=...)` - - Image base64 → `types.Part.from_bytes(data=decoded_bytes, mime_type=...)` - - File ID → `types.Part(file_data=types.FileData(file_uri=..., mime_type=...))` - - Handle the case where content comes as a list of dicts (multimodal) - -- [x] **1.4** Update OpenAI message handling - - In `_call_openai()` / `_call_openai_responses()`, messages with list content already work with OpenAI SDK - - Ensure the content array format `[{"type": "text", "text": "..."}, {"type": "image_url", ...}]` is passed through correctly - -- [x] **1.5** Add `multimodal_config` parameter to `Agent.__init__()` - - Optional `MultimodalConfig` parameter on Agent - - Pass config through to converter functions - -- [x] **1.6** Add convenience constructors to `Message` - - `Message.image_message(image_url=..., text=..., role="user")` - - `Message.multimodal_message(content_blocks=[...], role="user")` - - `Message.from_file(file_path, mime_type=None, text=None)` — auto-detect type, create appropriate blocks - -- [x] **1.7** Write tests - - Test `_convert_dict` with ImageBlock (base64, URL, file_id) - - Test Google format conversion with images - - Test OpenAI format conversion with images - - Test end-to-end: Message with image → provider call format - - Test backward compatibility: text-only still works - -### Sprint 2: Document Processing & Extraction (pyagenity-api) -**Goal**: Automatic PDF, DOCX, and other document extraction in the API platform using `textxtract`. The PyAgenity core library does **not** include any extraction logic — if developers use the SDK directly, they extract text themselves and pass it as a `TextBlock`. - -**Library used**: [`textxtract`](https://10xhub.github.io/textxtract/) — supports async, works from file path or raw bytes, handles PDF, DOCX, DOC, RTF, HTML, CSV, JSON, XML, MD, TXT, ZIP. - -```python -from textxtract import AsyncTextExtractor -from textxtract.core.exceptions import FileTypeNotSupportedError, ExtractionError - -extractor = AsyncTextExtractor() -text = await extractor.extract(file_bytes, "document.pdf") -``` - -- [x] **2.1** Add `textxtract` to `pyagenity-api` dependencies - - In `pyagenity-api/pyproject.toml`, add: - - `textxtract[pdf]` → PyMuPDF for PDF support - - `textxtract[docx]` → python-docx for Word support - - `textxtract[html]` → beautifulsoup4 for HTML support - - `textxtract[xml]` → lxml for XML support - - `textxtract[md]` → markdown for Markdown support - - Or use `textxtract[pdf,docx,html,xml,md]` combined extras - - Text, CSV, JSON, ZIP are supported built-in (no extras needed) - -- [x] **2.2** Create `DocumentExtractor` service in `agentflow_cli/media/extractor.py` - - Wraps `AsyncTextExtractor` from `textxtract` - - Single method: `async def extract(data: bytes, filename: str) -> str` - - Maps MIME type → filename extension when only bytes + mime_type is known - - Handles `FileTypeNotSupportedError` → returns `None` (unsupported → pass raw) - - Handles `ExtractionError` → raises `400 Bad Request` with clear message - - Example: - ```python - from textxtract import AsyncTextExtractor - from textxtract.core.exceptions import FileTypeNotSupportedError, ExtractionError - - class DocumentExtractor: - def __init__(self): - self._extractor = AsyncTextExtractor() - - async def extract(self, data: bytes, filename: str) -> str | None: - try: - return await self._extractor.extract(data, filename) - except FileTypeNotSupportedError: - return None # caller decides how to handle unsupported types - except ExtractionError as e: - raise ValueError(f"Failed to extract text from {filename}: {e}") from e - ``` - -- [x] **2.3** Create `DocumentPipeline` in `agentflow_cli/media/pipeline.py` - - Orchestrates: receive uploaded file → extract text via `DocumentExtractor` → return `TextBlock` or `DocumentBlock` - - When extraction succeeds → returns `TextBlock(text=extracted_text)` with original filename as metadata - - When file type not supported for extraction (e.g. images) → return `DocumentBlock` (raw, to be handled by provider) - - Respects `DocumentHandling` config: - - `EXTRACT_TEXT` → always attempt extraction, raise if fails - - `PASS_RAW` → skip extraction, return `DocumentBlock` with base64/media_ref - - `SKIP` → return `None` (drop document from message) - -- [x] **2.4** Wire `DocumentPipeline` into the file upload endpoint (Sprint 4) - - When `POST /v1/files/upload` receives a document (non-image, non-audio): - - Store binary in `MediaStore` - - Also run `DocumentExtractor.extract()` and cache the extracted text - - Return both `file_id` and optionally `extracted_text` in the response - - When a `DocumentBlock` arrives in a message at invoke/stream time: - - If `file_id` references an already-extracted file → inject `TextBlock` with cached text - - If inline bytes/base64 → run extraction on-the-fly - -- [x] **2.5** Update `MediaProcessor` in PyAgenity (`agentflow/media/processor.py`) - - `MediaProcessor` handles only **images**: validate mime type, check file size, optionally resize - - **No document extraction logic** — documents are processed by `DocumentPipeline` in the API layer - - Clearly document this in docstring: - > `MediaProcessor` handles image validation and resizing only. Document text extraction is the responsibility of the caller (API layer uses `DocumentPipeline`; SDK users extract text themselves). - -- [x] **2.6** Add optional image dependency to PyAgenity `pyproject.toml` - - `pip install 10xscale-agentflow[images]` → `Pillow` (for image resizing/processing) - - Remove any pdf/docx extras from PyAgenity — those belong in pyagenity-api - - `pip install 10xscale-agentflow[all]` → `Pillow` only (no extraction deps) - -- [x] **2.7** Write tests - - `DocumentExtractor`: mock `AsyncTextExtractor`, test success, unsupported type, extraction error - - `DocumentPipeline`: test all three `DocumentHandling` modes - - Integration: upload PDF via API → text extracted → injected into agent message - - Verify PyAgenity core has zero `textxtract` or extraction imports - -### Sprint 3: Media Storage Layer (PyAgenity) — The Database Problem Solution -**Goal**: Binary data NEVER touches the checkpointer/state DB. Only lightweight `MediaRef` references are stored. - -- [x] **3.1** Create `BaseMediaStore` abstract interface - - `agentflow/media/storage/base.py`: - ```python - class BaseMediaStore(ABC): - async def store(self, data: bytes, mime_type: str, metadata: dict | None = None) -> str # returns storage_key - async def retrieve(self, storage_key: str) -> tuple[bytes, str] # returns (data, mime_type) - async def delete(self, storage_key: str) -> bool - async def exists(self, storage_key: str) -> bool - def to_media_ref(self, storage_key: str, mime_type: str, **kwargs) -> MediaRef - ``` - - `to_media_ref()` creates `MediaRef(kind="url", url="agentflow://media/{key}", mime_type=...)` - - Storage keys are opaque strings (UUID-based), no user input in keys - -- [x] **3.2** Implement `InMemoryMediaStore` - - `agentflow/media/storage/memory_store.py` - - `dict[str, tuple[bytes, str, dict]]` — key → (data, mime_type, metadata) - - For testing and ephemeral scripts - - Auto-cleanup via TTL or max size (optional) - -- [x] **3.3** Implement `LocalFileMediaStore` - - `agentflow/media/storage/local_store.py` - - Configurable base directory (default `./agentflow_media/`) - - Storage layout: `{base_dir}/{key[:2]}/{key[2:4]}/{key}.{ext}` (sharded to avoid too many files per dir) - - Metadata stored in sidecar `{key}.json` file - - Retrieve reads from disk; delete removes both files - - Security: path traversal prevention, validate key format - -- [x] **3.4** Implement `PgBlobStore` (for PG-only deployments) - - `agentflow/media/storage/pg_store.py` - - Uses **separate** `media_blobs` table (NOT in states/messages JSONB): - ```sql - CREATE TABLE media_blobs ( - storage_key VARCHAR(255) PRIMARY KEY, - data BYTEA NOT NULL, - mime_type VARCHAR(100) NOT NULL, - size_bytes BIGINT, - thread_id VARCHAR(255), -- optional, for cleanup - created_at TIMESTAMPTZ DEFAULT NOW(), - metadata JSONB DEFAULT '{}' - ); - CREATE INDEX idx_media_blobs_thread ON media_blobs(thread_id); - ``` - - Stores actual bytes in `BYTEA` column — separate from message/state JSONB - - Messages/state still only contain `MediaRef(kind="url", url="agentflow://media/key")` - - Benefits: same PG infra, transactional consistency, no extra service - - Trade-off: PG not ideal for large blobs; fine for <10MB typical images - -- [x] **3.5** Create `MediaRefResolver` — resolves references at LLM call time - - `agentflow/media/resolver.py` - - `resolve_for_openai(ref: MediaRef) -> dict` — converts to OpenAI content part format - - `resolve_for_google(ref: MediaRef) -> types.Part` — converts to Google Part - - Handles all `MediaRef.kind` values: `"url"` (internal + external), `"data"`, `"file_id"` - - For `agentflow://media/{key}` URLs: calls `MediaStore.retrieve()` to get bytes - - For `https://` URLs: passes through to provider directly - - For `data` kind: uses inline base64 directly (small payloads) - - For `file_id` kind: uses provider's native file reference - -- [x] **3.6** Add inline data guard / auto-offload hook - - In `MediaProcessor` or as a standalone utility: - ```python - async def ensure_media_offloaded(message: Message, store: BaseMediaStore, max_inline: int = 50_000) -> Message: - """Replace large inline data_base64 with MediaStore references.""" - for block in message.content: - if hasattr(block, 'media') and block.media.kind == "data": - raw_size = len(block.media.data_base64 or "") * 3 // 4 # approx decoded size - if raw_size > max_inline: - data = base64.b64decode(block.media.data_base64) - key = await store.store(data, block.media.mime_type or "application/octet-stream") - block.media = store.to_media_ref(key, block.media.mime_type) - return message - ``` - - Callable at API ingestion boundary, NOT in checkpointer - - Configurable via `MultimodalConfig.offload_policy` and `max_inline_bytes` - -- [x] **3.7** Inject `MediaStore` into graph compilation - - `graph.compile(media_store=LocalFileMediaStore("./uploads"))` or - - `graph.compile(media_store=InMemoryMediaStore())` - - Store reference available to Agent and converters during execution - - Similar to how ADK's `ArtifactService` is injected - -- [x] **3.8** Wire `MediaRefResolver` into converter pipeline - - `_convert_dict()` (OpenAI) → uses resolver for ImageBlock/AudioBlock/DocumentBlock - - `_handle_regular_message()` (Google) → uses resolver for same - - Resolver is instantiated with the `MediaStore` from graph config - -- [x] **3.9** Convenience helpers on `Message` - - `Message.with_image(data: bytes, mime_type: str, store: BaseMediaStore) -> Message` - - `Message.with_file(path: str, store: BaseMediaStore) -> Message` - - These store-then-reference: `store.store(data) → to_media_ref() → ImageBlock(media=ref)` - - Also support direct URL/file_id for cases where MediaStore isn't needed - -- [x] **3.10** Write comprehensive tests - - InMemoryMediaStore: store/retrieve/delete roundtrip - - LocalFileMediaStore: same + path traversal prevention + cleanup - - PgBlobStore: same + verify blobs NOT in states/messages tables - - MediaRefResolver: all MediaRef kinds → correct OpenAI/Google format - - Auto-offload: large inline → auto-replaced with store reference - - End-to-end: upload image → store → message → checkpointer save → reload → resolve for LLM - - Verify: after checkpointer roundtrip, states JSONB is small (no base64 blobs) - -- [x] **3.11** Implement `CloudMediaStore` (S3/GCS) via `cloud-storage-manager` - - `agentflow/media/storage/cloud_store.py` - - Uses `cloud-storage-manager` package (`pip install cloud-storage-manager`) - - Supports both AWS S3 and GCS through unified `CloudStorageFactory` interface - - Blob + sidecar metadata JSON stored in bucket with sharded layout - - Signed URL download via `httpx` (async) with `urllib` fallback - - Temp file upload (bytes → tempfile → upload → cleanup) - - `get_public_url()` bonus method for direct browser/client access - - Added `cloud-storage` optional dependency: `pip install 10xscale-agentflow[cloud-storage]` - -- [x] **3.12** Fix OpenAI Responses API multimodal input format - - Added `_to_responses_content()` helper in `agent_internal/openai.py` - - Converts Chat Completions content parts → Responses API format: - - `text` → `input_text`, `image_url` → `input_image` (flattened URL), `input_audio` preserved - - Wired into `_call_openai_responses()` for all message content - -- [x] **3.13** Add multimodal response handling to OpenAI Responses converter - - `_extract_media_from_message_item()`: handles `output_image` and `output_audio` entries - - `_extract_image_generation()`: handles `image_generation_call` items (DALL-E etc.) - - Both non-streaming `convert_response` and streaming paths updated - -- [x] **3.14** Add Document & Video type support across all providers - - Updated `_document_block_to_openai()` → proper `{"type": "document", "document": {...}}` format - - Created `_video_block_to_openai()` → `{"type": "video", "video": {...}}` - - Updated `_build_content()` to handle `VideoBlock` in multimodal branch - - Updated `_to_responses_content()` → document (`input_text`/`input_file`) + video (`input_text` ref) - - Updated `_content_parts_to_google()` → document (`Part(text=…)`/`Part.from_bytes()`/`Part.from_uri()`) + video - -- [x] **3.15** Multi-agent media stripping for text-only agents - - Created `strip_media_blocks()` in `converter.py` — removes non-text content parts from message dicts - - Wired into `Agent.execute()` in `execution.py` — auto-strips when `multimodal_config is None` - - Collapses single remaining text part back to plain string for maximum compatibility - -- [x] **3.16** Streaming media extraction for OpenAI Responses converter - - Added `output_item.done` handlers for `message` type → `_extract_media_from_message_item()` - - Added `image_generation_call` / `image_generation` handlers → `_extract_image_generation()` - - Both sync and async streaming paths updated - -- [x] **3.17** Streaming media extraction for Google GenAI converter - - Added `_process_inline_media_part()` and `_process_file_media_part()` calls in `_extract_delta_content_blocks()` - - Streaming chunks with images/audio/video now extracted as `ContentBlock`s - -- [x] **3.18** Comprehensive multimodal end-to-end tests (`tests/test_multimodal_e2e.py`) - - 73 tests across 11 test classes covering: - - `_build_content` with all media types (image, audio, document, video) - - `strip_media_blocks` for multi-agent workflows - - `_to_responses_content` (OpenAI Responses input) - - `_content_parts_to_google` (Google GenAI input) - - OpenAI Chat / Responses / Google GenAI converter output - - Multi-agent image stripping integration - - Edge cases and full pipeline integration - - All 73 tests passing; full suite: 2279 passed, 0 failed - -### How the Pieces Fit Together (State & Checkpointer Summary) - -``` -BEFORE (broken): - Image bytes → base64 → MediaRef(kind="data", data_base64="...1MB...") - → Message → AgentState.context → PgCheckpointer → 1MB in JSONB + 1MB in TEXT + 1MB in Redis - -AFTER (fixed): - Image bytes → MediaStore.store(bytes) → key "abc123" - → MediaRef(kind="url", url="agentflow://media/abc123") ← ~100 bytes - → Message → AgentState.context → PgCheckpointer → 100 bytes in JSONB + 100 bytes in TEXT + 100 bytes in Redis - - Actual binary stored ONCE in: - - InMemoryMediaStore: Python dict (testing) - - LocalFileMediaStore: filesystem (dev) - - PgBlobStore: media_blobs BYTEA table (PG-only deployment) - - CloudMediaStore: S3 / GCS bucket via cloud-storage-manager (production) -``` - -**Checkpointer itself: ZERO changes needed.** -**DB schema for states/messages: ZERO changes needed.** -**The fix is entirely at the ingestion boundary and the LLM conversion boundary.** - -### Sprint 4: API Layer — File Upload Endpoints (pyagenity-api) -**Goal**: REST API support for multimodal messages. Document extraction (Sprint 2) is already wired in; this sprint adds the upload endpoints, invoke/stream multimodal support, and wires everything together. - -- [x] **4.1** Add file upload endpoint - ``` - POST /v1/files/upload - Content-Type: multipart/form-data - - Response: { - "file_id": "file_abc123", - "mime_type": "image/jpeg", - "size_bytes": 102400, - "filename": "photo.jpg", - "extracted_text": "... (populated for supported document types, null for images/binary)", - "url": "/v1/files/file_abc123" - } - ``` - - For images/audio: store binary in `MediaStore`, return `file_id` - - For documents (PDF, DOCX, etc.): store binary **and** run `DocumentPipeline.extract()`, return both `file_id` and `extracted_text` - - Enforce `MEDIA_MAX_SIZE_MB` limit - -- [x] **4.2** Add file retrieval endpoint - ``` - GET /v1/files/{file_id} - → Returns file binary with correct Content-Type - - GET /v1/files/{file_id}/info - → Returns file metadata (filename, mime_type, size_bytes, extracted_text if available) - ``` - -- [x] **4.3** Update graph invoke/stream endpoints to accept multimodal messages - - `GraphInputSchema.messages` already accepts `Message` with `ContentBlock` — no schema change needed - - Ensure JSON deserialization of `ImageBlock`, `DocumentBlock` etc. works correctly in API request - - When a `DocumentBlock` with `file_id` is present in an incoming message: - - Look up cached extracted text (from upload in 4.1) and substitute `TextBlock` - - If no cached extraction, run `DocumentPipeline` on-the-fly - - Client sends: - ```json - { - "messages": [{ - "role": "user", - "content": [ - {"type": "text", "text": "What is in this image?"}, - {"type": "image", "media": {"kind": "url", "url": "https://..."}} - ] - }] - } - ``` - - Or with uploaded file: - ```json - { - "messages": [{ - "role": "user", - "content": [ - {"type": "text", "text": "Analyze this PDF"}, - {"type": "document", "media": {"kind": "file_id", "file_id": "file_abc123", "mime_type": "application/pdf"}} - ] - }] - } - ``` - -- [x] **4.4** Add multimodal config endpoint - ``` - GET /v1/config/multimodal → returns current config - PUT /v1/config/multimodal → update config (admin) - ``` - -- [x] **4.5** Wire up `MediaProcessor`, `MediaStore`, and `DocumentPipeline` in API server startup - - Configure via environment variables / settings: - - `MEDIA_STORAGE_TYPE=local|memory|s3|gcs|cloud` - - `MEDIA_STORAGE_PATH=./uploads` - - `MEDIA_MAX_SIZE_MB=25` - - `DOCUMENT_HANDLING=extract_text|pass_raw|skip` - - `DocumentPipeline` instantiated once at startup, injected via FastAPI dependency - -- [x] **4.6** Write API tests - -### Sprint 5: Client SDK Support (agentflow-react) -**Goal**: TypeScript client support for multimodal - -- [x] **5.1** Update TypeScript message types - - Add `ImageBlock`, `AudioBlock`, `DocumentBlock` types matching Python models - - Update `ContentBlock` union type - -- [x] **5.2** Add file upload client methods - ```typescript - client.files.upload(file: File | Blob): Promise - client.files.get(fileId: string): Promise - ``` - -- [x] **5.3** Add multimodal message helpers - ```typescript - Message.withImage(text: string, imageUrl: string): Message - Message.withFile(text: string, file: FileRef): Message - ``` - -- [x] **5.4** Update playground/UI components - - File upload button in chat input - - Image preview in message bubbles - - Document icon/preview for PDFs - - Drag & drop support - -- [x] **5.5** Write client tests - -### Sprint 6: Advanced Features & Polish -**Goal**: Production readiness - -- [x] **6.1** Image processing utilities - - Auto-resize large images before sending - - Thumbnail generation for storage - - PIL-based processing option (convert to JPEG, optimize) - - EXIF rotation handling - -- [x] **6.2** Provider-specific optimizations - - Google: Use File API for large files (>20MB) - - OpenAI: Use file_search for PDFs when available - - Caching: Don't re-upload same file to provider - -- [x] **6.3** Streaming support for multimodal - - Ensure streaming responses with images work correctly - - Handle image generation streaming (progressive) - -- [x] **6.4** Security hardening - - File type validation (magic bytes, not just extension) - - Max file size enforcement - - Virus scanning hook (optional) - - Rate limiting on uploads - - Sanitize filenames - -- [x] **6.5** Documentation - - Multimodal usage guide - - API reference for file endpoints - - Configuration guide - - Examples: image analysis, document Q&A, multimodal agent - ---- - -## Configuration Reference - -### Agent-Level Config -```python -from agentflow.media.config import MultimodalConfig, ImageHandling, DocumentHandling - -agent = Agent( - model="gpt-4o", - provider="openai", - multimodal_config=MultimodalConfig( - image_handling=ImageHandling.BASE64, # base64 | url | file_id - document_handling=DocumentHandling.EXTRACT_TEXT, # extract_text | pass_raw | skip - max_image_size_mb=10.0, - max_image_dimension=2048, - ), -) -``` - -### API-Level Config (pyagenity-api) -```python -# In API settings / .env -MULTIMODAL_IMAGE_HANDLING=base64 # How API stores/passes images -MULTIMODAL_DOCUMENT_HANDLING=extract_text # How API handles documents -MEDIA_STORAGE_TYPE=local # local | memory | cloud (s3/gcs) -MEDIA_STORAGE_PATH=./uploads # For local storage -MEDIA_MAX_SIZE_MB=25 # Max file upload size -``` - -### Per-Request Override (via API) -```json -{ - "messages": [...], - "config": { - "multimodal": { - "document_handling": "pass_raw" - } - } -} -``` - ---- - -## Provider Support Matrix - -| Feature | OpenAI | Google Gemini | Notes | -|---------|--------|---------------|-------| -| Image (base64) | ✅ gpt-4o, gpt-4o-mini | ✅ All Gemini | Most universal | -| Image (URL) | ✅ All vision models | ✅ All Gemini | Requires public URL | -| Image (file_id) | ✅ Via Assistants API | ✅ File API | Provider-managed | -| PDF (raw) | ✅ gpt-4o (as images) | ✅ Gemini (native) | Google has better native PDF support | -| PDF (extract text) | ✅ All text models | ✅ All text models | Universal fallback | -| DOCX (extract text) | ✅ All text models | ✅ All text models | Always extract | -| Audio input | ✅ gpt-4o-audio | ✅ Gemini Live | Limited model support | -| Video input | ❌ Not supported | ✅ Gemini native | Google only | - ---- - -## Priority & Dependencies - -``` -Sprint 1 (Foundation) ← Start here, unblocks everything - │ - ├─► Sprint 2 (Documents) ← Independent from Sprint 3 - │ - ├─► Sprint 3 (MediaStore) ← Independent from Sprint 2, CRITICAL for production - │ │ (without this, images bloat the DB) - │ │ - └───┬───┘ - │ - └─► Sprint 4 (API) ← Depends on Sprint 1 + Sprint 3 (file upload needs MediaStore) - │ - └─► Sprint 5 (Client) ← Depends on Sprint 4 - │ - └─► Sprint 6 (Polish) ← After everything works - -Sprint 1 is mandatory first. -Sprints 2 and 3 can run in parallel after Sprint 1. -Sprint 4 REQUIRES Sprint 3 (file uploads need MediaStore). -For dev/testing, Sprint 1 alone enables multimodal with inline base64 using InMemoryCheckpointer. -For production with PgCheckpointer, Sprint 3 is required BEFORE serving real traffic. -``` - ---- - -## Estimated Complexity - -| Sprint | Files Changed | New Files | Complexity | -|--------|--------------|-----------|------------| -| Sprint 1 | 4-5 modified | 2 new | Medium — core pipeline changes | -| Sprint 2 | 2-3 modified | 5 new | Medium — new module, optional deps | -| Sprint 3 | 3-4 modified | 7 new | High — storage backends, resolver, auto-offload, wiring | -| Sprint 4 | 3-4 modified | 3 new | Medium — API endpoints + wiring | -| Sprint 5 | 5-6 modified | 2 new | Medium — TypeScript types + UI | -| Sprint 6 | 3-4 modified | 1-2 new | Low-Medium — polish & optimization | - -## Key Architectural Decision: Why Checkpointer Doesn't Change - -The temptation is to add media-aware serialization in the checkpointer (e.g., intercept `model_dump()`, detect large base64, store separately). **Don't do this.** Reasons: - -1. **Violates single responsibility** — Checkpointer's job is state persistence, not media management -2. **Creates hidden coupling** — Checkpointer would need a MediaStore reference, complicating DI -3. **Breaks deserialization** — If checkpointer strips data on save, it needs to re-inject on load; this is fragile -4. **Not where the problem is** — The problem is at the *ingestion* boundary, not the *persistence* boundary - -Instead, the fix follows the **"clean input" principle**: ensure that by the time data reaches `AgentState.context`, all large binary payloads have already been offloaded to `MediaStore` and replaced with lightweight `MediaRef` references. The checkpointer then serializes these references like any other Pydantic model — no special handling needed. - -The only optional guard is the `ensure_media_offloaded()` function that can be called at the API boundary as a safety net. If somehow a large inline base64 slips through, it warns (or auto-offloads) before the message enters the state. diff --git a/agentflow/__init__.py b/agentflow/__init__.py index 8431364..e69de29 100644 --- a/agentflow/__init__.py +++ b/agentflow/__init__.py @@ -1,291 +0,0 @@ -""" -10xScale Agentflow: A lightweight Python framework for building intelligent -agents and multi-agent workflows. - -Quick start:: - - from agentflow import StateGraph, Agent, Message, START, END - - graph = StateGraph(AgentState) - graph.add_node("agent", Agent(model="gpt-4o")) - graph.add_edge(START, "agent") - graph.add_edge("agent", END) - app = graph.compile() -""" - -from __future__ import annotations - -# --------------------------------------------------------------------------- -# Sub-packages (namespace imports for deep access) -# --------------------------------------------------------------------------- -from . import core, prebuilt, qa, runtime, utils -from .core import exceptions, graph, skills, state - -# --------------------------------------------------------------------------- -# Exceptions -# --------------------------------------------------------------------------- -from .core.exceptions import ( - GraphError, - GraphRecursionError, - MetricsError, - NodeError, - ResourceNotFoundError, - SchemaVersionError, - SerializationError, - StorageError, - TransientStorageError, -) - -# --------------------------------------------------------------------------- -# Graph / Workflow Engine -# --------------------------------------------------------------------------- -from .core.graph import ( - Agent, - BaseAgent, - CompiledGraph, - Edge, - Node, - RetryConfig, - StateGraph, - ToolNode, -) - -# --------------------------------------------------------------------------- -# Skills -# --------------------------------------------------------------------------- -from .core.skills import ( - SkillConfig, - SkillMeta, - SkillsRegistry, -) - -# --------------------------------------------------------------------------- -# State Management -# --------------------------------------------------------------------------- -from .core.state import ( - AgentState, - AnnotationBlock, - AnnotationRef, - AudioBlock, - # Context managers - BaseContextManager, - ContentBlock, - DataBlock, - DocumentBlock, - ErrorBlock, - ExecutionState, - ExecutionStatus, - ImageBlock, - MediaRef, - Message, - MessageContextManager, - ReasoningBlock, - StreamChunk, - StreamEvent, - # Content blocks - TextBlock, - TokenUsages, - ToolCallBlock, - # Tool results - ToolResult, - ToolResultBlock, - VideoBlock, - # Reducers - add_messages, - append_items, - remove_tool_messages, - replace_messages, - replace_value, -) - -# --------------------------------------------------------------------------- -# Prebuilt Agents -# --------------------------------------------------------------------------- -from .prebuilt import ( - RAGAgent, - ReactAgent, - RouterAgent, - create_handoff_tool, -) -from .storage import checkpointer, media, store - -# --------------------------------------------------------------------------- -# Storage (common classes re-exported for convenience) -# --------------------------------------------------------------------------- -from .storage.checkpointer import ( - BaseCheckpointer, - InMemoryCheckpointer, - PgCheckpointer, -) -from .storage.media import ( - BaseMediaStore, - CloudMediaStore, - DocumentHandling, - ImageHandling, - InMemoryMediaStore, - LocalFileMediaStore, - MediaOffloadPolicy, - MediaProcessor, - MediaRefResolver, - MultimodalConfig, - ProviderMediaCache, - enforce_file_size, - sanitize_filename, - validate_magic_bytes, -) -from .storage.store import ( - BaseEmbedding, - BaseStore, - GoogleEmbedding, - Mem0Store, - MemoryIntegration, - MemoryRecord, - MemorySearchResult, - OpenAIEmbedding, - QdrantStore, - ReadMode, - create_cloud_qdrant_store, - create_local_qdrant_store, - create_mem0_store, -) - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- -# --------------------------------------------------------------------------- -# Utilities (most commonly used) -# --------------------------------------------------------------------------- -from .utils import ( - END, - START, - CallbackContext, - CallbackManager, - Command, - ResponseGranularity, - ThreadInfo, - convert_messages, - get_tool_metadata, - tool, -) - - -__all__ = [ - # Sub-packages - "core", - "exceptions", - "graph", - "prebuilt", - "qa", - "runtime", - "skills", - "state", - "storage", - "utils", - "checkpointer", - "store", - "media", - # Graph - "Agent", - "BaseAgent", - "CompiledGraph", - "Edge", - "Node", - "RetryConfig", - "StateGraph", - "ToolNode", - # State - "AgentState", - "ExecutionState", - "ExecutionStatus", - "Message", - "StreamChunk", - "StreamEvent", - "TokenUsages", - "TextBlock", - "ImageBlock", - "AudioBlock", - "VideoBlock", - "DocumentBlock", - "DataBlock", - "ErrorBlock", - "ReasoningBlock", - "ToolCallBlock", - "ToolResultBlock", - "AnnotationBlock", - "ContentBlock", - "AnnotationRef", - "MediaRef", - "BaseContextManager", - "MessageContextManager", - "ToolResult", - # Reducers - "add_messages", - "append_items", - "replace_messages", - "replace_value", - "remove_tool_messages", - # Constants - "START", - "END", - # Utilities - "Command", - "CallbackManager", - "CallbackContext", - "ResponseGranularity", - "ThreadInfo", - "tool", - "get_tool_metadata", - "convert_messages", - # Exceptions - "GraphError", - "NodeError", - "GraphRecursionError", - "StorageError", - "TransientStorageError", - "ResourceNotFoundError", - "SerializationError", - "SchemaVersionError", - "MetricsError", - # Prebuilt - "ReactAgent", - "RAGAgent", - "RouterAgent", - "create_handoff_tool", - # Skills - "SkillConfig", - "SkillMeta", - "SkillsRegistry", - # Checkpointer - "BaseCheckpointer", - "InMemoryCheckpointer", - "PgCheckpointer", - # Store - "BaseStore", - "QdrantStore", - "Mem0Store", - "BaseEmbedding", - "OpenAIEmbedding", - "GoogleEmbedding", - "MemoryIntegration", - "ReadMode", - "MemoryRecord", - "MemorySearchResult", - "create_local_qdrant_store", - "create_cloud_qdrant_store", - "create_mem0_store", - # Media - "BaseMediaStore", - "InMemoryMediaStore", - "LocalFileMediaStore", - "CloudMediaStore", - "MediaRefResolver", - "MediaProcessor", - "MultimodalConfig", - "DocumentHandling", - "ImageHandling", - "MediaOffloadPolicy", - "ProviderMediaCache", - "enforce_file_size", - "sanitize_filename", - "validate_magic_bytes", -] diff --git a/agentflow/core/graph/agent.py b/agentflow/core/graph/agent.py index b4d6f82..2010426 100644 --- a/agentflow/core/graph/agent.py +++ b/agentflow/core/graph/agent.py @@ -5,8 +5,7 @@ """ import logging -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from agentflow.core.graph.base_agent import BaseAgent from agentflow.core.graph.tool_node import ToolNode @@ -17,11 +16,16 @@ from .agent_internal.constants import DEFAULT_RETRY_CONFIG, REASONING_DEFAULT, RetryConfig from .agent_internal.execution import AgentExecutionMixin from .agent_internal.google import AgentGoogleMixin +from .agent_internal.memory import AgentMemoryMixin from .agent_internal.openai import AgentOpenAIMixin from .agent_internal.providers import AgentProviderMixin from .agent_internal.skills import AgentSkillsMixin +if TYPE_CHECKING: + from agentflow.storage.store.memory_config import MemoryConfig + + logger = logging.getLogger("agentflow.agent") @@ -31,6 +35,7 @@ class Agent( AgentOpenAIMixin, AgentProviderMixin, AgentSkillsMixin, + AgentMemoryMixin, BaseAgent, ): """A smart node function wrapper for LLM interactions. @@ -47,26 +52,27 @@ class Agent( Example: ```python - # Create an agent node with OpenAI + # Create an agent node with a ToolNode + tool_node = ToolNode([weather_tool]) agent = Agent( model="gpt-4o", provider="openai", - system_prompt="You are a helpful assistant", - tools=[weather_tool], + system_prompt=[{"role": "system", "content": "You are a helpful assistant"}], + tool_node=tool_node, ) # Use it in a graph graph = StateGraph() - graph.add_node("MAIN", agent) # Agent acts as a node function - graph.add_node("TOOL", agent.get_tool_node()) - # ... setup edges + graph.add_node("MAIN", agent) + graph.add_node("TOOL", tool_node) + # ... setup conditional edges ``` Attributes: model: Model identifier (e.g., "gpt-4o", "gemini-2.0-flash") provider: Provider name ("openai", "google") system_prompt: System prompt string or list of message dicts - tools: List of tool functions or ToolNode instance + tool_node: ToolNode instance or name of an existing TOOL graph node (str) client: Optional custom client instance (escape hatch for power users) temperature: LLM sampling temperature max_tokens: Maximum tokens to generate @@ -77,16 +83,16 @@ def __init__( # noqa: PLR0913 self, model: str, provider: str | None = None, - output_type: str = "text", # NEW: Explicit output type + output_type: str = "text", system_prompt: list[dict[str, Any]] | None = None, - tools: list[Callable] | ToolNode | None = None, - tool_node_name: str | None = None, + tool_node: "str | ToolNode | None" = None, extra_messages: list[Message] | None = None, trim_context: bool = False, tools_tags: set[str] | None = None, api_style: str = "chat", reasoning_config: dict[str, Any] | bool | None = REASONING_DEFAULT, # type: ignore skills: "SkillConfig | None" = None, + memory: "MemoryConfig | None" = None, retry_config: RetryConfig | bool | None = True, fallback_models: list[str | tuple[str, str]] | None = None, multimodal_config: MultimodalConfig | None = None, @@ -121,18 +127,30 @@ class MyState(AgentState): }] ) # At runtime, placeholders are replaced with state values - tools: List of tool functions, ToolNode instance, or None. - If list is provided, will be converted to ToolNode internally. - tool_node_name: Name of the existing ToolNode. You can send list of tools - or provide ToolNode instance via `tools` parameter instead. + tool_node: A ``ToolNode`` instance containing the tools this agent may call, + **or** a ``str`` naming an existing graph node whose ``func`` is a + ``ToolNode`` (resolved at execution time via the DI container). + Pass ``None`` when the agent needs no tools. + + Examples:: + + # Inline ToolNode — agent owns the tools + tool_node = ToolNode([get_weather, search]) + agent = Agent(model="gpt-4o", tool_node=tool_node) + + # Named reference — ToolNode lives as a separate graph node + agent = Agent(model="gpt-4o", tool_node="TOOL") + extra_messages: Additional messages to include in every interaction. trim_context: Whether to trim context using context manager. tools_tags: Optional tags to filter tools. base_url (via **kwargs): Optional base URL for OpenAI-compatible APIs (ollama, vllm, openrouter, deepseek, etc.). Default: ``None``. - api_style (via **kwargs): API style for OpenAI provider. ``"chat"`` uses + api_style: API style for OpenAI provider. ``"chat"`` uses Chat Completions, ``"responses"`` uses the Responses API. Default: ``"chat"``. + memory: Optional ``MemoryConfig`` enabling agent-level long-term + memory tools and system prompts. reasoning_config: Unified reasoning control for all providers. Default is ``{"effort": "medium"}`` (on). Pass ``None`` to turn off. ``effort`` applies to both providers; ``summary`` is OpenAI-only; @@ -176,35 +194,27 @@ class MyState(AgentState): Example: ```python - # Text generation (default - no need to specify output_type) + # Text generation with inline ToolNode + tool_node = ToolNode([weather_tool, calculator]) text_agent = Agent( model="openai/gpt-4o", - system_prompt="You are a helpful assistant", - tools=[weather_tool, calculator], + system_prompt=[{"role": "system", "content": "You are a helpful assistant"}], + tool_node=tool_node, temperature=0.8, ) - # Image generation (explicit) - image_agent = Agent( - model="openai/dall-e-3", - output_type="image", - ) - - # Video generation (explicit) - video_agent = Agent( - model="google/veo-2.0", - provider="google", - output_type="video", + # Text generation with named ToolNode in graph + agent = Agent( + model="google/gemini-2.5-flash", + tool_node="TOOL", # references graph node named "TOOL" ) - # Multi-modal workflow (Google ADK style) - prompt_agent = Agent( - model="google/gemini-2.0-flash-exp", - system_prompt="Generate detailed image prompts", - ) + # No tools + agent = Agent(model="gpt-4o") - imagen_agent = Agent( - model="google/imagen-3.0-generate-001", + # Image generation + image_agent = Agent( + model="openai/dall-e-3", output_type="image", ) @@ -215,30 +225,27 @@ class MyState(AgentState): base_url="https://api.qwen.com/v1", ) - ollama_agent = Agent( - model="llama3:70b", - provider="openai", - base_url="http://localhost:11434/v1", - ) - # With retry and fallback resilient_agent = Agent( model="gemini-2.5-flash", provider="google", retry_config=RetryConfig(max_retries=5, initial_delay=2.0), fallback_models=[ - "gemini-2.0-flash", # same provider - ("gpt-4o-mini", "openai"), # cross-provider fallback + "gemini-2.0-flash", + ("gpt-4o-mini", "openai"), ], ) ``` """ # Pop kwargs-only params before passing to parent base_url: str | None = kwargs.pop("base_url", None) - # Note: api_style is already a named parameter, don't pop from kwargs # Call parent constructor super().__init__( - model=model, system_prompt=system_prompt or [], tools=tools, base_url=base_url, **kwargs + model=model, + system_prompt=system_prompt or [], + tool_node=tool_node, + base_url=base_url, + **kwargs, ) # check user sending model and provider as prefix, if provider is not explicitly provided @@ -267,7 +274,7 @@ class MyState(AgentState): self.extra_messages = extra_messages self.trim_context = trim_context self.tools_tags = tools_tags - self.tool_node_name = tool_node_name + self.tool_node_name = None # may be set to a str by _setup_tools() # Internal setup self._tool_node = self._setup_tools() @@ -307,6 +314,10 @@ class MyState(AgentState): f"output_type={self.output_type}, has_tools={self._tool_node is not None}" ) + # Memory setup (via mixin) runs before skills so a memory-only Agent can + # lazily create the internal ToolNode that both systems append to. + self._setup_memory(memory) + # Skills setup (via mixin) self._setup_skills(skills) diff --git a/agentflow/core/graph/agent_internal/execution.py b/agentflow/core/graph/agent_internal/execution.py index 81ec092..2c71c8a 100644 --- a/agentflow/core/graph/agent_internal/execution.py +++ b/agentflow/core/graph/agent_internal/execution.py @@ -28,31 +28,29 @@ class AgentExecutionMixin: """Execution flow, tool resolution, and provider dispatch helpers.""" def _setup_tools(self) -> ToolNode | None: - """Normalize the tools input to a ToolNode instance.""" - if self.tools is None: - logger.debug("No tools provided") + """Normalize the tool_node input and wire internal state. + + - ``ToolNode`` → stored as ``self._tool_node``; ``tool_node_name`` remains ``None``. + - ``str`` → stored as ``self.tool_node_name`` for lazy lookup via the DI + container at execution time; returns ``None``. + - ``None`` → no tools; both attributes remain ``None``. + """ + tn = self.tool_node # str | ToolNode | None + if tn is None: + logger.debug("No tool_node provided") return None - if isinstance(self.tools, ToolNode): - logger.debug("Tools already a ToolNode instance") - return self.tools + if isinstance(tn, str): + logger.debug("tool_node is a named graph-node reference: '%s'", tn) + self.tool_node_name = tn + return None - logger.debug("Converting %d tool functions to ToolNode", len(self.tools)) - return ToolNode(self.tools) + logger.debug("tool_node is a ToolNode instance") + return tn def get_tool_node(self) -> ToolNode | None: - """Return the agent's internal ToolNode. - - Use this public method instead of accessing ``agent._tool_node`` - directly when wiring the tool node into the graph. When skills are - enabled, the returned ToolNode already contains the ``set_skill`` tool. - - Example:: - - agent = Agent(model="gpt-4o", tools=[my_tool], skills=SkillConfig(...)) - graph.add_node("TOOL", agent.get_tool_node()) - """ - return self._tool_node + """Return the agent-owned ``ToolNode`` when one is configured.""" + return getattr(self, "_tool_node", None) async def _trim_context( self, @@ -305,6 +303,9 @@ async def execute( if hasattr(self, "_build_skill_prompts") and callable(self._build_skill_prompts): effective_system_prompt = self._build_skill_prompts(state, self.system_prompt) + if hasattr(self, "_build_memory_prompts") and callable(self._build_memory_prompts): + effective_system_prompt.extend(await self._build_memory_prompts(state, config)) + messages = convert_messages( system_prompts=effective_system_prompt, state=state, @@ -499,15 +500,25 @@ async def _resolve_tools(self, container: InjectQ) -> list[dict[str, Any]]: try: node = container.call_factory("get_node", self.tool_node_name) - except (KeyError, DependencyNotFoundError): - logger.warning( - "ToolNode with name '%s' not found in InjectQ registry.", - self.tool_node_name, + except (KeyError, DependencyNotFoundError) as exc: + raise RuntimeError( + f"ToolNode named '{self.tool_node_name}' was not found in the compiled graph. " + "Register the named ToolNode in the graph before executing the Agent." + ) from exc + + if node is None: + raise RuntimeError( + f"ToolNode named '{self.tool_node_name}' was not found in the compiled graph. " + "Register the named ToolNode in the graph before executing the Agent." + ) + + if not isinstance(node.func, ToolNode): + raise RuntimeError( + f"Graph node '{self.tool_node_name}' is not a ToolNode. " + "Pass a ToolNode instance or register the proper graph node." ) - return tools - if node and isinstance(node.func, ToolNode): - return await node.func.all_tools(tags=self.tools_tags) + tools.extend(await node.func.all_tools(tags=self.tools_tags)) return tools def _extract_prompt(self, messages: list[dict[Any, Any]]) -> str: diff --git a/agentflow/core/graph/agent_internal/memory.py b/agentflow/core/graph/agent_internal/memory.py new file mode 100644 index 0000000..88165fb --- /dev/null +++ b/agentflow/core/graph/agent_internal/memory.py @@ -0,0 +1,168 @@ +"""Agent-level memory support.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from agentflow.core.graph.tool_node import ToolNode + + +if TYPE_CHECKING: + from agentflow.storage.store.memory_config import MemoryConfig + + +logger = logging.getLogger("agentflow.agent") + + +class AgentMemoryMixin: + """Memory registration helpers for Agent.""" + + _memory_config: MemoryConfig | None + _memory_integration: Any | None + _memory_prompt: dict[str, Any] | None + _tool_node: ToolNode | None + + def _setup_memory(self, memory: MemoryConfig | None) -> None: + """Initialize agent-level memory from ``MemoryConfig``.""" + self._memory_config = None + self._memory_integration = None + self._memory_prompt = None + + if memory is None: + return + + from agentflow.storage.store.long_term_memory import ( + MemoryIntegration, + get_agent_memory_system_prompt, + ) + from agentflow.storage.store.memory_config import MemoryConfig + + if not isinstance(memory, MemoryConfig): + raise TypeError(f"Expected MemoryConfig, got {type(memory)}") + + self._memory_config = memory + + if memory.inject_system_prompt: + self._memory_prompt = { + "role": "system", + "content": get_agent_memory_system_prompt(memory), + } + self.system_prompt.append(self._memory_prompt) + + memory_tools = memory.model_facing_tools() + if memory_tools and self._tool_node is None: + raise RuntimeError( + "Memory requires an existing ToolNode when model-facing memory tools are enabled. " + "Provide a ToolNode to the Agent or register the memory tools manually." + ) + + if self._tool_node is not None: + for memory_tool in memory_tools: + self._tool_node.add_tool(memory_tool) + + default_store = ( + memory.store + or (memory.user_memory.store if memory.user_memory else None) + or (memory.agent_memory.store if memory.agent_memory else None) + ) + if default_store is not None: + self._memory_integration = MemoryIntegration( + store=default_store, + retrieval_mode=memory.retrieval_mode, + limit=memory.limit, + score_threshold=memory.score_threshold, + max_tokens=memory.max_tokens, + ) + + logger.info( + "Memory enabled: user=%s agent=%s", + bool(memory.user_memory and memory.user_memory.enabled), + bool(memory.agent_memory and memory.agent_memory.enabled), + ) + + async def _build_memory_prompts( + self, + state: Any, + config: dict[str, Any], + ) -> list[dict[str, Any]]: + """Load and format memory context for preload mode.""" + from agentflow.prebuilt.tools.memory import ( + _memory_scope_config, + _memory_scope_limit, + _memory_scope_score_threshold, + _memory_scope_store, + ) + from agentflow.storage.store.long_term_memory import ( + ReadMode, + _format_search_results, + _strip_thread_id, + _validate_memory_type, + ) + + memory = getattr(self, "_memory_config", None) + if memory is None or memory.retrieval_mode != ReadMode.PRELOAD: + return [] + + query = self._latest_user_memory_query(state) + if not query: + return [] + + sections: list[str] = [] + scopes = [ + ("User memory", "user", memory.user_memory), + ("Agent memory", "agent", memory.agent_memory), + ] + for label, scope_name, scope_config in scopes: + if scope_config is None or not scope_config.enabled: + continue + store = _memory_scope_store(memory, scope_config, None) + if store is None: + continue + + search_config = _memory_scope_config( + config, + memory, + scope_config, + scope=scope_name, + ) + if scope_name == "user": + search_config = _strip_thread_id(search_config) + + try: + results = await store.asearch( + search_config, + query, + memory_type=_validate_memory_type(scope_config.memory_type), + category=scope_config.category, + limit=_memory_scope_limit(memory, scope_config, None), + score_threshold=_memory_scope_score_threshold(memory, scope_config), + **({"max_tokens": memory.max_tokens} if memory.max_tokens else {}), + ) + except Exception: + logger.exception("Memory preload search failed for %s", scope_name) + continue + + formatted = _format_search_results(results) + if not formatted: + continue + lines = [f"- {item['content']} (relevance: {item['score']})" for item in formatted] + sections.append(f"{label}:\n" + "\n".join(lines)) + + if not sections: + return [] + + return [ + { + "role": "system", + "content": "[Long-term Memory Context]\n" + "\n\n".join(sections), + } + ] + + @staticmethod + def _latest_user_memory_query(state: Any) -> str: + for msg in reversed(getattr(state, "context", []) or []): + if getattr(msg, "role", None) == "user": + text = msg.text() if hasattr(msg, "text") else str(msg) + return text.strip() + return "" diff --git a/agentflow/core/graph/agent_internal/skills.py b/agentflow/core/graph/agent_internal/skills.py index 589dc3c..0ff69b4 100644 --- a/agentflow/core/graph/agent_internal/skills.py +++ b/agentflow/core/graph/agent_internal/skills.py @@ -57,11 +57,13 @@ def _setup_skills(self, skills: SkillConfig | None) -> None: hot_reload=self._skills_config.hot_reload, ) - # Add skill tool to the tool node + # Add skill tool to the tool node; require an existing ToolNode. if self._tool_node is None: - self._tool_node = ToolNode([set_skill_fn]) - else: - self._tool_node.add_tool(set_skill_fn) + raise RuntimeError( + "Skills require an existing ToolNode when skills are enabled. " + "Provide a ToolNode to the Agent before configuring skills." + ) + self._tool_node.add_tool(set_skill_fn) # Build and cache trigger-table prompt once during setup. if self._skills_config.inject_trigger_table: diff --git a/agentflow/core/graph/base_agent.py b/agentflow/core/graph/base_agent.py index 05b45d1..3e8e8e6 100644 --- a/agentflow/core/graph/base_agent.py +++ b/agentflow/core/graph/base_agent.py @@ -7,10 +7,12 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from agentflow.core.graph.tool_node.base import ToolNode -from agentflow.core.graph.tool_node.base import ToolNode from agentflow.core.state import AgentState from agentflow.core.state.message import Message @@ -31,7 +33,7 @@ class BaseAgent(ABC): Attributes: model: LLM model identifier system_prompt: System prompt configuration - tools: Optional tool configuration + tool_node: ToolNode instance or named graph-node reference (str) kwargs: Additional configuration parameters Example: @@ -49,8 +51,7 @@ def __init__( model: str, provider: str | None = None, system_prompt: list[dict[str, Any]] | None = None, - tools: list[Callable] | ToolNode | None = None, - tool_node_name: str | None = None, + tool_node: "str | ToolNode | None" = None, extra_messages: list[Message] | None = None, client: Any = None, # Escape hatch: allow custom client base_url: str | None = None, # For OpenAI-compatible APIs (ollama, vllm, etc.) @@ -63,12 +64,13 @@ def __init__( Args: model: LLM model identifier (e.g., "gpt-4", "gemini/gemini-2.0-flash") system_prompt: System prompt as list of message dicts - tools: Optional list of tools or tool configuration + tool_node: ToolNode instance, or a string naming an existing graph node + whose ``func`` is a ToolNode. Pass ``None`` for no tools. **kwargs: Additional LLM or agent configuration parameters """ self.model = model self.system_prompt = system_prompt or [] - self.tools = tools + self.tool_node = tool_node self.llm_kwargs = llm_kwargs @abstractmethod @@ -119,7 +121,8 @@ def get_tool_node(self) -> "ToolNode | None": Example:: - agent = Agent(model="gpt-4o", tools=[my_tool]) + tool_node = ToolNode([my_tool]) + agent = Agent(model="gpt-4o", tool_node=tool_node) graph.add_node("TOOL", agent.get_tool_node()) """ return None diff --git a/agentflow/core/graph/tool_node/__init__.py b/agentflow/core/graph/tool_node/__init__.py index 5d089dd..fc18232 100644 --- a/agentflow/core/graph/tool_node/__init__.py +++ b/agentflow/core/graph/tool_node/__init__.py @@ -4,8 +4,6 @@ - ToolNode - HAS_FASTMCP, HAS_MCP - -Backwards-compatible import path: ``from agentflow.graph.tool_node import ToolNode`` """ from agentflow.core.state.tool_result import ToolResult diff --git a/agentflow/core/skills/registry.py b/agentflow/core/skills/registry.py index d8c5f9d..a724ca9 100644 --- a/agentflow/core/skills/registry.py +++ b/agentflow/core/skills/registry.py @@ -188,6 +188,6 @@ def build_trigger_table(self, tags: set[str] | None = None) -> str: def build_set_skill_tool(self, hot_reload: bool = True) -> Any: """Convenience — delegates to :func:`activation.make_set_skill_tool`.""" - from .activation import make_set_skill_tool + from agentflow.core.skills.activation import make_set_skill_tool return make_set_skill_tool(self, hot_reload=hot_reload) diff --git a/agentflow/prebuilt/__init__.py b/agentflow/prebuilt/__init__.py index b070de0..b142c90 100644 --- a/agentflow/prebuilt/__init__.py +++ b/agentflow/prebuilt/__init__.py @@ -1,58 +1,82 @@ -"""Prebuilt agents and tools for Agentflow. +"""Prebuilt tools and agent packages for Agentflow. -This package provides ready-to-use agent patterns and utility tools: - -- ``agentflow.prebuilt.agent`` — prebuilt agent implementations (ReactAgent, RAGAgent, ...) -- ``agentflow.prebuilt.tools`` — prebuilt tools (handoff, ...) +Import concrete agent implementations from ``agentflow.prebuilt.agent`` and +tool helpers from ``agentflow.prebuilt.tools``. """ from __future__ import annotations -from importlib import import_module -from typing import TYPE_CHECKING, Any - +from importlib import import_module as _import_module +from typing import Any as _Any -if TYPE_CHECKING: - from . import agent, tools - from .agent import RAGAgent, ReactAgent, RouterAgent - from .tools import create_handoff_tool, is_handoff_tool -__all__ = [ - # Agents +_AGENT_EXPORTS = { "RAGAgent", "ReactAgent", "RouterAgent", - # Submodules - "agent", - # Tools +} + +_TOOL_EXPORTS = { "create_handoff_tool", + "fetch_url", + "file_read", + "file_search", + "file_write", + "google_web_search", "is_handoff_tool", - "tools", -] - -_LAZY_EXPORTS: dict[str, tuple[str, str | None]] = { - "agent": (".agent", None), - "tools": (".tools", None), - "RAGAgent": (".agent", "RAGAgent"), - "ReactAgent": (".agent", "ReactAgent"), - "RouterAgent": (".agent", "RouterAgent"), - "create_handoff_tool": (".tools", "create_handoff_tool"), - "is_handoff_tool": (".tools", "is_handoff_tool"), + "make_agent_memory_tool", + "make_user_memory_tool", + "memory_tool", + "safe_calculator", + "vertex_ai_search", } -def __getattr__(name: str) -> Any: - """Lazily expose prebuilt agents and tools.""" - try: - module_name, attribute_name = _LAZY_EXPORTS[name] - except KeyError as exc: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc +def __getattr__(name: str) -> _Any: + """Load prebuilt agents and tools only when explicitly requested.""" + if name == "agent": + module = _import_module(f"{__name__}.agent") + globals()[name] = module + return module + + if name == "tools": + module = _import_module(f"{__name__}.tools") + globals()[name] = module + return module + + if name in _AGENT_EXPORTS: + agent_module = _import_module(f"{__name__}.agent") + value = getattr(agent_module, name) + globals()[name] = value + return value - module = import_module(module_name, __name__) - value = module if attribute_name is None else getattr(module, attribute_name) - globals()[name] = value - return value + if name in _TOOL_EXPORTS: + tools_module = _import_module(f"{__name__}.tools") + value = getattr(tools_module, name) + globals()[name] = value + return value + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -def __dir__() -> list[str]: - return sorted(set(globals()) | set(__all__)) + +__all__ = [ + "RAGAgent", + "ReactAgent", + "RouterAgent", + # Agents + "agent", + # Tools + "create_handoff_tool", + "fetch_url", + "file_read", + "file_search", + "file_write", + "google_web_search", + "is_handoff_tool", + "make_agent_memory_tool", + "make_user_memory_tool", + "memory_tool", + "safe_calculator", + "tools", + "vertex_ai_search", +] diff --git a/agentflow/prebuilt/tools/__init__.py b/agentflow/prebuilt/tools/__init__.py index e71dbcd..99f0389 100644 --- a/agentflow/prebuilt/tools/__init__.py +++ b/agentflow/prebuilt/tools/__init__.py @@ -1,9 +1,24 @@ """Prebuilt tools for agentflow graphs.""" +from .calculator import safe_calculator +from .fetch import fetch_url +from .files import file_read, file_search, file_write from .handoff import create_handoff_tool, is_handoff_tool +from .memory import make_agent_memory_tool, make_user_memory_tool, memory_tool +from .search import google_web_search, vertex_ai_search __all__ = [ "create_handoff_tool", + "fetch_url", + "file_read", + "file_search", + "file_write", + "google_web_search", "is_handoff_tool", + "make_agent_memory_tool", + "make_user_memory_tool", + "memory_tool", + "safe_calculator", + "vertex_ai_search", ] diff --git a/agentflow/prebuilt/tools/calculator.py b/agentflow/prebuilt/tools/calculator.py new file mode 100644 index 0000000..bee63c9 --- /dev/null +++ b/agentflow/prebuilt/tools/calculator.py @@ -0,0 +1,92 @@ +"""Safe arithmetic tools for AgentFlow agents.""" + +from __future__ import annotations + +import ast +import json +import math +import operator +from typing import Any + +from agentflow.utils.decorators import tool + + +_BINARY_OPERATORS: dict[type[ast.operator], Any] = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, +} +_UNARY_OPERATORS: dict[type[ast.unaryop], Any] = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, +} +_MAX_EXPRESSION_LENGTH = 500 +_MAX_ABS_VALUE = 10**12 +_MAX_POWER_EXPONENT = 12 + + +def _evaluate_node(node: ast.AST) -> int | float: + if isinstance(node, ast.Expression): + return _evaluate_node(node.body) + + if ( + isinstance(node, ast.Constant) + and isinstance(node.value, int | float) + and not isinstance(node.value, bool) + ): + value = node.value + if not math.isfinite(float(value)) or abs(value) > _MAX_ABS_VALUE: + raise ValueError("numeric value is outside the allowed range") + return value + + if isinstance(node, ast.UnaryOp): + op = _UNARY_OPERATORS.get(type(node.op)) + if op is None: + raise ValueError("unsupported unary operator") + return op(_evaluate_node(node.operand)) + + if isinstance(node, ast.BinOp): + op = _BINARY_OPERATORS.get(type(node.op)) + if op is None: + raise ValueError("unsupported binary operator") + left = _evaluate_node(node.left) + right = _evaluate_node(node.right) + if isinstance(node.op, ast.Pow) and abs(right) > _MAX_POWER_EXPONENT: + raise ValueError("power exponent is outside the allowed range") + result = op(left, right) + if not math.isfinite(float(result)) or abs(result) > _MAX_ABS_VALUE: + raise ValueError("result is outside the allowed range") + return result + + raise ValueError(f"unsupported expression element: {type(node).__name__}") + + +@tool( + name="safe_calculator", + description=( + "Safely evaluate a basic arithmetic expression. Supports numbers, parentheses, " + "and +, -, *, /, //, %, and ** with conservative size limits." + ), + tags=["math", "calculator"], + capabilities=["calculate"], +) +def safe_calculator(expression: str, precision: int | None = None) -> str: + """Evaluate a basic arithmetic expression safely.""" + if not expression or not expression.strip(): + return json.dumps({"error": "expression is required"}) + if len(expression) > _MAX_EXPRESSION_LENGTH: + return json.dumps({"error": "expression is too long"}) + + try: + tree = ast.parse(expression, mode="eval") + result = _evaluate_node(tree) + if precision is not None and isinstance(result, float): + safe_precision = max(0, min(int(precision), 12)) + result = round(result, safe_precision) + return json.dumps({"result": result}) + except Exception as exc: + return json.dumps({"error": str(exc)}) diff --git a/agentflow/prebuilt/tools/fetch.py b/agentflow/prebuilt/tools/fetch.py new file mode 100644 index 0000000..2847534 --- /dev/null +++ b/agentflow/prebuilt/tools/fetch.py @@ -0,0 +1,128 @@ +"""Network fetch tools for AgentFlow agents.""" + +from __future__ import annotations + +import asyncio +import ipaddress +import json +import socket +from html.parser import HTMLParser +from urllib import request +from urllib.error import HTTPError, URLError +from urllib.parse import urlparse + +from agentflow.utils.decorators import tool + + +_DEFAULT_MAX_CHARS = 20_000 +_USER_AGENT = "agentflow-prebuilt-tools/1.0" + + +class _HTMLTextParser(HTMLParser): + def __init__(self) -> None: + super().__init__() + self._skip_depth = 0 + self.parts: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag in {"script", "style", "noscript"}: + self._skip_depth += 1 + if tag in {"p", "br", "div", "li", "h1", "h2", "h3", "h4", "h5", "h6"}: + self.parts.append("\n") + + def handle_endtag(self, tag: str) -> None: + if tag in {"script", "style", "noscript"} and self._skip_depth: + self._skip_depth -= 1 + if tag in {"p", "div", "li"}: + self.parts.append("\n") + + def handle_data(self, data: str) -> None: + if not self._skip_depth: + text = data.strip() + if text: + self.parts.append(text) + + def text(self) -> str: + return " ".join(" ".join(self.parts).split()) + + +def _is_public_hostname(hostname: str | None) -> bool: + if not hostname: + return False + try: + addresses = socket.getaddrinfo(hostname, None) + except socket.gaierror: + return False + + for addr in addresses: + ip = ipaddress.ip_address(addr[4][0]) + if ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + ): + return False + return True + + +def _html_to_text(html: str) -> str: + parser = _HTMLTextParser() + parser.feed(html) + return parser.text() + + +def _fetch_sync(url: str, timeout: float, max_chars: int) -> dict[str, object]: + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + return {"error": "only http and https URLs are supported"} + if not _is_public_hostname(parsed.hostname): + return {"error": "URL host is not public or could not be resolved"} + + req = request.Request(url, headers={"User-Agent": _USER_AGENT}) # noqa: S310 + try: + with request.urlopen( # noqa: S310 # nosec B310 + req, timeout=max(1.0, min(float(timeout), 30.0)) + ) as response: + raw = response.read(max_chars + 1) + status_code = response.getcode() + final_url = response.geturl() + content_type = response.headers.get("content-type", "") + except HTTPError as exc: + return {"error": f"HTTP error: {exc.code}", "status_code": exc.code} + except URLError as exc: + return {"error": f"URL error: {exc.reason}"} + + truncated = len(raw) > max_chars + body = raw[:max_chars].decode("utf-8", errors="replace") + if "html" in content_type.lower(): + body = _html_to_text(body) + if len(body) > max_chars: + body = body[:max_chars] + truncated = True + + return { + "url": final_url, + "status_code": status_code, + "content_type": content_type, + "content": body, + "truncated": truncated, + } + + +@tool( + name="fetch_url", + description=( + "Fetch a public http/https URL and return text content. Blocks private/local hosts, " + "applies a timeout, and truncates long responses." + ), + tags=["web", "fetch", "network"], + capabilities=["network_access"], +) +async def fetch_url(url: str, timeout: float = 10.0, max_chars: int = _DEFAULT_MAX_CHARS) -> str: + """Fetch a public URL and return normalized text content.""" + safe_max_chars = max(1, min(int(max_chars), _DEFAULT_MAX_CHARS)) + result = await asyncio.to_thread(_fetch_sync, url, timeout, safe_max_chars) + return json.dumps(result) diff --git a/agentflow/prebuilt/tools/files.py b/agentflow/prebuilt/tools/files.py new file mode 100644 index 0000000..6f485d1 --- /dev/null +++ b/agentflow/prebuilt/tools/files.py @@ -0,0 +1,293 @@ +"""Workspace-scoped file tools for AgentFlow agents.""" + +from __future__ import annotations + +import fnmatch +import json +from pathlib import Path +from typing import Any, Literal + +from agentflow.utils.decorators import tool + + +_DEFAULT_MAX_READ_CHARS = 20_000 +_DEFAULT_MAX_WRITE_CHARS = 200_000 +_DEFAULT_MAX_SEARCH_RESULTS = 20 +_MAX_SEARCH_FILE_SIZE = 1_000_000 +_MAX_SEARCH_PREVIEW_CHARS = 240 +_SEARCH_PREVIEW_ELLIPSIS_CHARS = 3 +_SKIP_DIRS = { + ".git", + ".hg", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".tox", + ".venv", + "__pycache__", + "build", + "dist", + "htmlcov", + "node_modules", + "venv", +} + + +def _configured_root(config: dict[str, Any] | None) -> Path: + cfg = config or {} + root = cfg.get("file_tool_root") or cfg.get("workspace_root") or "." + return Path(str(root)).expanduser().resolve() + + +def _resolve_under_root(path: str, root: Path) -> Path: + if not path or not path.strip(): + raise ValueError("path is required") + + candidate = Path(path).expanduser() + if not candidate.is_absolute(): + candidate = root / candidate + resolved = candidate.resolve() + + try: + resolved.relative_to(root) + except ValueError: + raise ValueError(f"path must stay within the configured root: {root}") from None + + return resolved + + +def _relative(path: Path, root: Path) -> str: + try: + return path.relative_to(root).as_posix() + except ValueError: + return path.as_posix() + + +def _is_probably_text(path: Path) -> bool: + try: + with path.open("rb") as handle: + chunk = handle.read(2048) + except OSError: + return False + return b"\x00" not in chunk + + +def _write_file_content( + target: Path, content: str, mode: Literal["create", "overwrite", "append"] +) -> None: + if mode == "append": + with target.open("a", encoding="utf-8") as handle: + handle.write(content) + return + if mode in {"overwrite", "create"}: + target.write_text(content, encoding="utf-8") + return + raise ValueError(f"unsupported mode: {mode}") + + +def _is_search_candidate(candidate: Path, glob: str) -> bool: + if any(part in _SKIP_DIRS for part in candidate.parts): + return False + if not candidate.is_file(): + return False + if candidate.stat().st_size > _MAX_SEARCH_FILE_SIZE: + return False + return glob == "**/*" or fnmatch.fnmatch(candidate.name, Path(glob).name) + + +def _trim_search_preview(line: str) -> str: + preview = line.strip() + if len(preview) > _MAX_SEARCH_PREVIEW_CHARS: + preview_limit = _MAX_SEARCH_PREVIEW_CHARS - _SEARCH_PREVIEW_ELLIPSIS_CHARS + return f"{preview[:preview_limit]}..." + return preview + + +def _append_content_matches( + results: list[dict[str, Any]], + candidate: Path, + relative_path: str, + query_lower: str, + limit: int, +) -> None: + try: + for line_no, line in enumerate( + candidate.read_text(encoding="utf-8", errors="replace").splitlines(), + start=1, + ): + if query_lower not in line.lower(): + continue + results.append( + { + "path": relative_path, + "match_type": "content", + "line": line_no, + "preview": _trim_search_preview(line), + } + ) + if len(results) >= limit: + break + except OSError: + return + + +@tool( + name="file_read", + description=( + "Read a UTF-8 text file from the configured workspace root. Supports optional " + "1-based start_line/end_line and truncates long output." + ), + tags=["file", "filesystem", "read"], + capabilities=["read_files"], +) +def file_read( + path: str, + start_line: int = 1, + end_line: int = 0, + max_chars: int = _DEFAULT_MAX_READ_CHARS, + config: dict[str, Any] | None = None, +) -> str: + """Read a workspace-scoped text file.""" + root = _configured_root(config) + try: + target = _resolve_under_root(path, root) + if not target.exists(): + return json.dumps({"error": "file does not exist", "path": _relative(target, root)}) + if not target.is_file(): + return json.dumps({"error": "path is not a file", "path": _relative(target, root)}) + if not _is_probably_text(target): + return json.dumps( + {"error": "file appears to be binary", "path": _relative(target, root)} + ) + + lines = target.read_text(encoding="utf-8", errors="replace").splitlines() + start = max(1, int(start_line)) + end = int(end_line) if end_line else len(lines) + if end < start: + return json.dumps({"error": "end_line must be greater than or equal to start_line"}) + + selected = lines[start - 1 : end] + text = "\n".join(selected) + limit = max(1, min(int(max_chars), _DEFAULT_MAX_READ_CHARS)) + truncated = len(text) > limit + if truncated: + text = text[:limit] + + return json.dumps( + { + "path": _relative(target, root), + "start_line": start, + "end_line": min(end, len(lines)), + "content": text, + "truncated": truncated, + } + ) + except Exception as exc: + return json.dumps({"error": str(exc)}) + + +@tool( + name="file_write", + description=( + "Write UTF-8 text to a file under the configured workspace root. Use mode='create' " + "to avoid overwriting, mode='overwrite' to replace, or mode='append' to append." + ), + tags=["file", "filesystem", "write"], + capabilities=["write_files"], +) +def file_write( + path: str, + content: str, + mode: Literal["create", "overwrite", "append"] = "create", + create_dirs: bool = False, + config: dict[str, Any] | None = None, +) -> str: + """Write UTF-8 text to a workspace-scoped file.""" + root = _configured_root(config) + try: + target = _resolve_under_root(path, root) + if len(content) > _DEFAULT_MAX_WRITE_CHARS: + result = {"error": "content is too large"} + elif target.exists() and not target.is_file(): + result = {"error": "path exists and is not a file", "path": _relative(target, root)} + elif target.exists() and mode == "create": + result = {"error": "file already exists", "path": _relative(target, root)} + elif mode not in {"create", "overwrite", "append"}: + result = {"error": f"unsupported mode: {mode}"} + elif not target.parent.exists() and not create_dirs: + result = {"error": "parent directory does not exist"} + else: + if not target.parent.exists(): + target.parent.mkdir(parents=True, exist_ok=True) + _write_file_content(target, content, mode) + result = { + "status": "written", + "path": _relative(target, root), + "bytes": len(content.encode("utf-8")), + "mode": mode, + } + return json.dumps(result) + except Exception as exc: + return json.dumps({"error": str(exc)}) + + +@tool( + name="file_search", + description=( + "Search text files under the configured workspace root by filename and content. " + "Returns relative paths, line numbers, and short previews." + ), + tags=["file", "filesystem", "search"], + capabilities=["read_files"], +) +def file_search( + query: str, + path: str = "", + glob: str = "**/*", + max_results: int = _DEFAULT_MAX_SEARCH_RESULTS, + config: dict[str, Any] | None = None, +) -> str: + """Search workspace-scoped text files by filename and content.""" + if not query: + return json.dumps({"error": "query is required"}) + + root = _configured_root(config) + try: + search_root = _resolve_under_root(path or ".", root) + if not search_root.exists(): + return json.dumps({"error": "search path does not exist"}) + + query_lower = query.lower() + limit = max(1, min(int(max_results), 100)) + results: list[dict[str, Any]] = [] + candidates = [search_root] if search_root.is_file() else search_root.rglob(glob) + + for candidate in candidates: + if len(results) >= limit: + break + if not _is_search_candidate(candidate, glob): + continue + + relative_path = _relative(candidate, root) + if query_lower in candidate.name.lower(): + results.append( + { + "path": relative_path, + "match_type": "filename", + "line": None, + "preview": candidate.name, + } + ) + if len(results) >= limit: + break + + if not _is_probably_text(candidate): + continue + + _append_content_matches(results, candidate, relative_path, query_lower, limit) + + return json.dumps( + {"query": query, "root": _relative(search_root, root), "results": results} + ) + except Exception as exc: + return json.dumps({"error": str(exc)}) diff --git a/agentflow/prebuilt/tools/memory.py b/agentflow/prebuilt/tools/memory.py new file mode 100644 index 0000000..3c79826 --- /dev/null +++ b/agentflow/prebuilt/tools/memory.py @@ -0,0 +1,374 @@ +"""Model-facing memory tools for AgentFlow agents. + +These are the tools that are registered with the agent's ToolNode and exposed +to the LLM. Lower-level helpers (``MemoryIntegration``, preload-node factory, +read-mode constants) live in ``agentflow.storage.store.long_term_memory``. + +Public API +---------- +``memory_tool`` + Legacy LLM-callable tool used by the ``MemoryIntegration`` / manual graph + wiring path. Supports ``search``, ``store``, ``update``, ``delete``. + +``make_user_memory_tool(memory_config)`` + Factory that returns the ``user_memory_tool`` used by + ``Agent(..., memory=MemoryConfig(...))``. Supports ``search`` and + ``remember``. + +``make_agent_memory_tool(memory_config)`` + Factory that returns the read-only ``agent_memory_tool`` used by + ``Agent(..., memory=MemoryConfig(...))``. Supports ``search`` only. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Callable +from typing import Any, Literal + +from injectq import Inject + +from agentflow.storage.store.base_store import BaseStore +from agentflow.storage.store.long_term_memory import ( + _do_write, + _flush_pending_writes, + _format_search_results, + _strip_thread_id, + _validate_memory_type, +) +from agentflow.utils.background_task_manager import BackgroundTaskManager +from agentflow.utils.decorators import tool + + +logger = logging.getLogger("agentflow.prebuilt.tools.memory") + + +# --------------------------------------------------------------------------- +# Scope helpers +# --------------------------------------------------------------------------- +# Used by both the user/agent memory tool factories below AND by the preload +# path in AgentMemoryMixin._build_memory_prompts. +# --------------------------------------------------------------------------- + + +def _memory_scope_store( + memory_config: Any, + scope_config: Any, + injected_store: BaseStore | None, +) -> BaseStore | None: + return scope_config.store or memory_config.store or injected_store + + +def _memory_scope_limit(memory_config: Any, scope_config: Any, limit: int | None) -> int: + return limit or scope_config.limit or memory_config.limit + + +def _memory_scope_score_threshold(memory_config: Any, scope_config: Any) -> float | None: + if scope_config.score_threshold is not None: + return scope_config.score_threshold + return memory_config.score_threshold + + +def _memory_scope_config( + runtime_config: dict[str, Any] | None, + memory_config: Any, + scope_config: Any, + *, + scope: Literal["user", "agent"], +) -> dict[str, Any]: + cfg: dict[str, Any] = { + **(memory_config.config or {}), + **(scope_config.config or {}), + **(runtime_config or {}), + } + + if scope == "user": + user_id = getattr(scope_config, "user_id", None) + if user_id: + cfg["user_id"] = user_id + return cfg + + agent_id = getattr(scope_config, "agent_id", None) + app_id = getattr(scope_config, "app_id", None) + if agent_id: + cfg["agent_id"] = agent_id + # Existing Qdrant-backed stores use ``thread_id`` as the secondary + # scope field, so agent memory maps agent identity there intentionally. + cfg["thread_id"] = agent_id + if app_id: + cfg["app_id"] = app_id + return cfg + + +def _memory_tool_metadata(scope: Literal["user", "agent"]) -> dict[str, Any]: + return { + "source": f"{scope}_memory_tool", + "scope": scope, + } + + +async def _search_scope_memory( + *, + store: BaseStore, + config: dict[str, Any], + query: str, + memory_type: str | None, + category: str | None, + limit: int, + score_threshold: float | None, + task_manager: BackgroundTaskManager | None, +) -> str: + await _flush_pending_writes(task_manager) + results = await store.asearch( + config, + query, + memory_type=_validate_memory_type(memory_type or config.get("memory_type", "episodic")), + category=category or config.get("category", "general"), + limit=limit, + score_threshold=score_threshold, + ) + return json.dumps(_format_search_results(results)) + + +# --------------------------------------------------------------------------- +# memory_tool - legacy LLM-callable tool (MemoryIntegration / manual wiring) +# --------------------------------------------------------------------------- + + +@tool( + name="memory_tool", + description=( + "Search, store, update or delete long-term memories. " + "Use action='search' with a query to recall relevant memories. " + "Use action='store' with content and a short snake_case memory_key " + "(e.g. 'user_name', 'favorite_language') to save new memories. " + "The system uses memory_key to detect duplicates — if a memory with the " + "same key already exists it will be updated automatically. " + "Use action='delete' with memory_id to remove memories." + ), + tags=["memory", "long_term_memory"], +) +async def memory_tool( # noqa: PLR0911, PLR0913 + action: Literal["search", "store", "update", "delete"] = "search", + content: str = "", + memory_key: str = "", + memory_id: str = "", + query: str = "", + memory_type: str | None = None, + category: str | None = None, + metadata: dict[str, Any] | None = None, + limit: int = 5, + score_threshold: float | None = None, + write_mode: Literal["merge", "replace"] = "merge", + # Injectable params (excluded from LLM schema automatically) + config: dict[str, Any] | None = None, + store: BaseStore | None = Inject[BaseStore], + task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager], +) -> str: + """Search, store, update, or delete long-term memories.""" + if store is None: + return json.dumps({"error": "no memory store configured"}) + + cfg = config or {} + # Resolve memory_type and category from config if not explicitly provided. + resolved_memory_type = memory_type or cfg.get("memory_type", "episodic") + resolved_category = category or cfg.get("category", "general") + mem_type = _validate_memory_type(resolved_memory_type) + + # Inject memory_key into metadata so _do_write can find it. + if memory_key: + metadata = {**(metadata or {}), "memory_key": memory_key} + + # --- Validation --- + if action == "search" and not query: + return json.dumps({"error": "query is required for search"}) + if action == "store" and not content: + return json.dumps({"error": "content is required for store"}) + if action == "update" and not memory_id: + return json.dumps({"error": "memory_id is required for update"}) + if action == "update" and not content: + return json.dumps({"error": "content is required for update"}) + if action == "delete" and not memory_id: + return json.dumps({"error": "memory_id is required for delete"}) + + try: + # --- Read --- + if action == "search": + # Flush any in-flight background writes so the search sees the + # latest data (e.g. writes scheduled during a previous query). + await _flush_pending_writes(task_manager) + + # Search across ALL threads for the user — long-term memory + # is not scoped to a single conversation thread. + results = await store.asearch( + _strip_thread_id(cfg), + query, + memory_type=mem_type, + limit=limit, + score_threshold=score_threshold, + ) + return json.dumps(_format_search_results(results)) + + # --- Write (always async / background) --- + write_coro = _do_write( + store, + cfg, + action, + content, + memory_id, + mem_type, + resolved_category, + metadata, + write_mode, + ) + try: + task_manager.create_task( + write_coro, + name=f"memory_{action}_{memory_id or 'new'}", + ) + except Exception: + write_coro.close() + raise + return json.dumps({"status": "scheduled", "action": action}) + + except Exception as e: + logger.exception("memory_tool error (action=%s): %s", action, e) + return json.dumps({"error": str(e)}) + + +# --------------------------------------------------------------------------- +# Agent-level model-facing tools (Agent(memory=MemoryConfig(...)) path) +# --------------------------------------------------------------------------- + + +def make_user_memory_tool(memory_config: Any) -> Callable: + """Create the user-scoped model-facing memory tool for an Agent.""" + user_config = memory_config.user_memory + + @tool( + name="user_memory_tool", + description=( + "Search or remember user-scoped long-term memories. " + "Use action='search' with text to recall durable user facts. " + "Use action='remember' with text to save useful user facts or preferences. " + "The model does not provide memory identifiers." + ), + tags=["memory", "long_term_memory", "user_memory"], + ) + async def user_memory_tool( + action: Literal["search", "remember"] = "search", + text: str = "", + memory_type: str | None = None, + category: str | None = None, + limit: int | None = None, + config: dict[str, Any] | None = None, + store: BaseStore | None = Inject[BaseStore], + task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager], + ) -> str: + if user_config is None or not user_config.enabled: + return json.dumps({"error": "user memory is disabled"}) + resolved_store = _memory_scope_store(memory_config, user_config, store) + if resolved_store is None: + return json.dumps({"error": "no user memory store configured"}) + + cfg = _memory_scope_config(config, memory_config, user_config, scope="user") + resolved_memory_type = memory_type or user_config.memory_type + resolved_category = category or user_config.category + resolved_limit = _memory_scope_limit(memory_config, user_config, limit) + score_threshold = _memory_scope_score_threshold(memory_config, user_config) + + if not text: + return json.dumps({"error": "text is required"}) + + try: + if action == "search": + return await _search_scope_memory( + store=resolved_store, + config=_strip_thread_id(cfg), + query=text, + memory_type=resolved_memory_type, + category=resolved_category, + limit=resolved_limit, + score_threshold=score_threshold, + task_manager=task_manager, + ) + + metadata = _memory_tool_metadata("user") + write_coro = _do_write( + resolved_store, + cfg, + "store", + text, + "", + _validate_memory_type(resolved_memory_type), + resolved_category, + metadata, + "merge", + ) + try: + task_manager.create_task( + write_coro, + name="user_memory_remember_new", + ) + except Exception: + write_coro.close() + raise + return json.dumps({"status": "scheduled", "action": "remember"}) + except Exception as e: + logger.exception("user_memory_tool error (action=%s): %s", action, e) + return json.dumps({"error": str(e)}) + + user_memory_tool.__name__ = "user_memory_tool" + user_memory_tool.__qualname__ = "user_memory_tool" + return user_memory_tool + + +def make_agent_memory_tool(memory_config: Any) -> Callable: + """Create the read-only agent-scoped model-facing memory tool for an Agent.""" + agent_config = memory_config.agent_memory + + @tool( + name="agent_memory_tool", + description=( + "Search read-only agent/app-scoped long-term memories. " + "This tool cannot write, update, or delete memory." + ), + tags=["memory", "long_term_memory", "agent_memory"], + ) + async def agent_memory_tool( + query: str, + memory_type: str | None = None, + category: str | None = None, + limit: int | None = None, + config: dict[str, Any] | None = None, + store: BaseStore | None = Inject[BaseStore], + task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager], + ) -> str: + if agent_config is None or not agent_config.enabled: + return json.dumps({"error": "agent memory is disabled"}) + resolved_store = _memory_scope_store(memory_config, agent_config, store) + if resolved_store is None: + return json.dumps({"error": "no agent memory store configured"}) + if not query: + return json.dumps({"error": "query is required"}) + + cfg = _memory_scope_config(config, memory_config, agent_config, scope="agent") + try: + return await _search_scope_memory( + store=resolved_store, + config=cfg, + query=query, + memory_type=memory_type or agent_config.memory_type, + category=category or agent_config.category, + limit=_memory_scope_limit(memory_config, agent_config, limit), + score_threshold=_memory_scope_score_threshold(memory_config, agent_config), + task_manager=task_manager, + ) + except Exception as e: + logger.exception("agent_memory_tool error: %s", e) + return json.dumps({"error": str(e)}) + + agent_memory_tool.__name__ = "agent_memory_tool" + agent_memory_tool.__qualname__ = "agent_memory_tool" + return agent_memory_tool diff --git a/agentflow/prebuilt/tools/search.py b/agentflow/prebuilt/tools/search.py new file mode 100644 index 0000000..86dbf53 --- /dev/null +++ b/agentflow/prebuilt/tools/search.py @@ -0,0 +1,155 @@ +"""Google-backed search tools for AgentFlow agents.""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any + +from agentflow.utils.decorators import tool + + +_DEFAULT_MODEL = "gemini-2.5-flash" +_DEFAULT_MAX_CHARS = 20_000 + + +def _to_plain(value: Any) -> Any: + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, dict): + return {str(k): _to_plain(v) for k, v in value.items()} + if isinstance(value, list | tuple): + return [_to_plain(item) for item in value] + if hasattr(value, "model_dump"): + return _to_plain(value.model_dump()) + if hasattr(value, "to_json_dict"): + return _to_plain(value.to_json_dict()) + return str(value) + + +def _response_payload(response: Any, max_chars: int) -> dict[str, Any]: + text = getattr(response, "text", "") or "" + truncated = len(text) > max_chars + if truncated: + text = text[:max_chars] + + grounding_metadata = None + candidates = getattr(response, "candidates", None) or [] + if candidates: + grounding_metadata = _to_plain(getattr(candidates[0], "grounding_metadata", None)) + + return { + "content": text, + "grounding_metadata": grounding_metadata, + "truncated": truncated, + } + + +def _google_web_search_sync(query: str, model: str, max_chars: int) -> dict[str, Any]: + try: + from google import genai + from google.genai import types + except ImportError: + return { + "error": ( + "google-genai is required for google_web_search. " + "Install with: pip install 10xscale-agentflow[google-genai]" + ) + } + + client = genai.Client() + response = client.models.generate_content( + model=model, + contents=query, + config=types.GenerateContentConfig( + tools=[types.Tool(google_search=types.GoogleSearch())], + ), + ) + return _response_payload(response, max_chars) + + +def _vertex_ai_search_sync( + query: str, + datastore: str, + model: str, + max_chars: int, +) -> dict[str, Any]: + if not datastore: + return {"error": "datastore is required"} + try: + from google import genai + from google.genai import types + except ImportError: + return { + "error": ( + "google-genai is required for vertex_ai_search. " + "Install with: pip install 10xscale-agentflow[google-genai]" + ) + } + + client = genai.Client(http_options=types.HttpOptions(api_version="v1")) + response = client.models.generate_content( + model=model, + contents=query, + config=types.GenerateContentConfig( + tools=[ + types.Tool( + retrieval=types.Retrieval( + vertex_ai_search=types.VertexAISearch(datastore=datastore), + ) + ) + ], + ), + ) + return _response_payload(response, max_chars) + + +@tool( + name="google_web_search", + description=( + "Search the public web with Gemini Google Search grounding and return the grounded " + "answer plus grounding metadata." + ), + tags=["web", "search", "google"], + capabilities=["network_access"], +) +async def google_web_search( + query: str, + model: str = _DEFAULT_MODEL, + max_chars: int = _DEFAULT_MAX_CHARS, +) -> str: + """Search the public web with Gemini Google Search grounding.""" + if not query: + return json.dumps({"error": "query is required"}) + safe_max_chars = max(1, min(int(max_chars), _DEFAULT_MAX_CHARS)) + result = await asyncio.to_thread(_google_web_search_sync, query, model, safe_max_chars) + return json.dumps(result) + + +@tool( + name="vertex_ai_search", + description=( + "Search a configured Vertex AI Search datastore with Gemini grounding. " + "The datastore must be a full Vertex AI Search datastore resource path." + ), + tags=["search", "google", "vertex_ai"], + capabilities=["network_access"], +) +async def vertex_ai_search( + query: str, + datastore: str, + model: str = _DEFAULT_MODEL, + max_chars: int = _DEFAULT_MAX_CHARS, +) -> str: + """Search a Vertex AI Search datastore with Gemini grounding.""" + if not query: + return json.dumps({"error": "query is required"}) + safe_max_chars = max(1, min(int(max_chars), _DEFAULT_MAX_CHARS)) + result = await asyncio.to_thread( + _vertex_ai_search_sync, + query, + datastore, + model, + safe_max_chars, + ) + return json.dumps(result) diff --git a/agentflow/qa/__init__.py b/agentflow/qa/__init__.py index a480ad6..7218768 100644 --- a/agentflow/qa/__init__.py +++ b/agentflow/qa/__init__.py @@ -7,100 +7,89 @@ - ``agentflow.qa.evaluation`` — eval sets, criteria, runners, reporters, and results """ -from __future__ import annotations - -from importlib import import_module -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from . import evaluation, testing - - # Evaluation - from .evaluation import ( - AgentEvaluator, - BaseCriterion, - BatchSimulator, - Colors, - CompositeCriterion, - ConsoleReporter, - ContainsKeywordsCriterion, - ConversationScenario, - CriterionConfig, - CriterionResult, - EvalCase, - EvalCaseResult, - EvalConfig, - EvalFixtures, - EvalPlugin, - EvalPresets, - EvalReport, - EvalSet, - EvalSetBuilder, - EvalSummary, - EvalTestCase, - EvaluationRunner, - EventCollector, - ExactMatchCriterion, - ExecutionResult, - FactualAccuracyCriterion, - HallucinationCriterion, - HTMLReporter, - Invocation, - JSONReporter, - JUnitXMLReporter, - LLMCallerMixin, - LLMJudgeCriterion, - MatchType, - MessageContent, - NodeOrderMatchCriterion, - NodeResponseData, - PublisherCallback, - QuickEval, - ReporterConfig, - ReporterManager, - ReporterOutput, - ResponseMatchCriterion, - RougeMatchCriterion, - Rubric, - RubricBasedCriterion, - SafetyCriterion, - SessionInput, - SimulationGoalsCriterion, - SimulationResult, - StepType, - SyncCriterion, - ToolCall, - ToolNameMatchCriterion, - TrajectoryCollector, - TrajectoryMatchCriterion, - TrajectoryStep, - UserSimulator, - UserSimulatorConfig, - WeightedCriterion, - assert_criterion_passed, - assert_eval_passed, - create_eval_app, - create_simple_eval_set, - eval_test, - make_trajectory_callback, - parametrize_eval_cases, - print_report, - run_eval, - ) - - # Testing - from .testing import ( - InMemoryStore, - MockComposioAdapter, - MockLangChainAdapter, - MockMCPClient, - MockToolRegistry, - QuickTest, - TestAgent, - TestContext, - TestResult, - ) +from . import evaluation, testing +from .evaluation import ( + AgentEvaluator, + BaseCriterion, + BatchSimulator, + Colors, + CompositeCriterion, + ConsoleReporter, + ContainsKeywordsCriterion, + ConversationScenario, + CriterionConfig, + CriterionResult, + EvalCase, + EvalCaseResult, + EvalConfig, + EvalFixtures, + EvalPlugin, + EvalPresets, + EvalReport, + EvalSet, + EvalSetBuilder, + EvalSummary, + EvalTestCase, + EvaluationRunner, + EventCollector, + ExactMatchCriterion, + ExecutionResult, + FactualAccuracyCriterion, + HallucinationCriterion, + HTMLReporter, + Invocation, + JSONReporter, + JUnitXMLReporter, + LLMCallerMixin, + LLMJudgeCriterion, + MatchType, + MessageContent, + NodeOrderMatchCriterion, + NodeResponseData, + PublisherCallback, + QuickEval, + ReporterConfig, + ReporterManager, + ReporterOutput, + ResponseMatchCriterion, + RougeMatchCriterion, + Rubric, + RubricBasedCriterion, + SafetyCriterion, + SessionInput, + SimulationGoalsCriterion, + SimulationResult, + StepType, + SyncCriterion, + ToolCall, + ToolNameMatchCriterion, + TrajectoryCollector, + TrajectoryMatchCriterion, + TrajectoryStep, + UserSimulator, + UserSimulatorConfig, + WeightedCriterion, + assert_criterion_passed, + assert_eval_passed, + create_eval_app, + create_simple_eval_set, + eval_test, + make_trajectory_callback, + parametrize_eval_cases, + print_report, + run_eval, +) +from .testing import ( + InMemoryStore, + MockComposioAdapter, + MockLangChainAdapter, + MockMCPClient, + MockToolRegistry, + QuickTest, + TestAgent, + TestContext, + TestResult, +) __all__ = [ @@ -201,107 +190,3 @@ "run_eval", "testing", ] - -_LAZY_EXPORTS: dict[str, tuple[str, str | None]] = { - # Submodules - "evaluation": (".evaluation", None), - "testing": (".testing", None), - # Testing - "InMemoryStore": (".testing", "InMemoryStore"), - "MockComposioAdapter": (".testing", "MockComposioAdapter"), - "MockLangChainAdapter": (".testing", "MockLangChainAdapter"), - "MockMCPClient": (".testing", "MockMCPClient"), - "MockToolRegistry": (".testing", "MockToolRegistry"), - "QuickTest": (".testing", "QuickTest"), - "TestAgent": (".testing", "TestAgent"), - "TestContext": (".testing", "TestContext"), - "TestResult": (".testing", "TestResult"), - # Evaluation - "AgentEvaluator": (".evaluation", "AgentEvaluator"), - "BaseCriterion": (".evaluation", "BaseCriterion"), - "BatchSimulator": (".evaluation", "BatchSimulator"), - "Colors": (".evaluation", "Colors"), - "CompositeCriterion": (".evaluation", "CompositeCriterion"), - "ConsoleReporter": (".evaluation", "ConsoleReporter"), - "ContainsKeywordsCriterion": (".evaluation", "ContainsKeywordsCriterion"), - "ConversationScenario": (".evaluation", "ConversationScenario"), - "CriterionConfig": (".evaluation", "CriterionConfig"), - "CriterionResult": (".evaluation", "CriterionResult"), - "EvalCase": (".evaluation", "EvalCase"), - "EvalCaseResult": (".evaluation", "EvalCaseResult"), - "EvalConfig": (".evaluation", "EvalConfig"), - "EvalFixtures": (".evaluation", "EvalFixtures"), - "EvalPlugin": (".evaluation", "EvalPlugin"), - "EvalPresets": (".evaluation", "EvalPresets"), - "EvalReport": (".evaluation", "EvalReport"), - "EvalSet": (".evaluation", "EvalSet"), - "EvalSetBuilder": (".evaluation", "EvalSetBuilder"), - "EvalSummary": (".evaluation", "EvalSummary"), - "EvalTestCase": (".evaluation", "EvalTestCase"), - "EvaluationRunner": (".evaluation", "EvaluationRunner"), - "EventCollector": (".evaluation", "EventCollector"), - "ExactMatchCriterion": (".evaluation", "ExactMatchCriterion"), - "ExecutionResult": (".evaluation", "ExecutionResult"), - "FactualAccuracyCriterion": (".evaluation", "FactualAccuracyCriterion"), - "HallucinationCriterion": (".evaluation", "HallucinationCriterion"), - "HTMLReporter": (".evaluation", "HTMLReporter"), - "Invocation": (".evaluation", "Invocation"), - "JSONReporter": (".evaluation", "JSONReporter"), - "JUnitXMLReporter": (".evaluation", "JUnitXMLReporter"), - "LLMCallerMixin": (".evaluation", "LLMCallerMixin"), - "LLMJudgeCriterion": (".evaluation", "LLMJudgeCriterion"), - "MatchType": (".evaluation", "MatchType"), - "MessageContent": (".evaluation", "MessageContent"), - "NodeOrderMatchCriterion": (".evaluation", "NodeOrderMatchCriterion"), - "NodeResponseData": (".evaluation", "NodeResponseData"), - "PublisherCallback": (".evaluation", "PublisherCallback"), - "QuickEval": (".evaluation", "QuickEval"), - "ReporterConfig": (".evaluation", "ReporterConfig"), - "ReporterManager": (".evaluation", "ReporterManager"), - "ReporterOutput": (".evaluation", "ReporterOutput"), - "ResponseMatchCriterion": (".evaluation", "ResponseMatchCriterion"), - "RougeMatchCriterion": (".evaluation", "RougeMatchCriterion"), - "Rubric": (".evaluation", "Rubric"), - "RubricBasedCriterion": (".evaluation", "RubricBasedCriterion"), - "SafetyCriterion": (".evaluation", "SafetyCriterion"), - "SessionInput": (".evaluation", "SessionInput"), - "SimulationGoalsCriterion": (".evaluation", "SimulationGoalsCriterion"), - "SimulationResult": (".evaluation", "SimulationResult"), - "StepType": (".evaluation", "StepType"), - "SyncCriterion": (".evaluation", "SyncCriterion"), - "ToolCall": (".evaluation", "ToolCall"), - "ToolNameMatchCriterion": (".evaluation", "ToolNameMatchCriterion"), - "TrajectoryCollector": (".evaluation", "TrajectoryCollector"), - "TrajectoryMatchCriterion": (".evaluation", "TrajectoryMatchCriterion"), - "TrajectoryStep": (".evaluation", "TrajectoryStep"), - "UserSimulator": (".evaluation", "UserSimulator"), - "UserSimulatorConfig": (".evaluation", "UserSimulatorConfig"), - "WeightedCriterion": (".evaluation", "WeightedCriterion"), - "assert_criterion_passed": (".evaluation", "assert_criterion_passed"), - "assert_eval_passed": (".evaluation", "assert_eval_passed"), - "create_eval_app": (".evaluation", "create_eval_app"), - "create_simple_eval_set": (".evaluation", "create_simple_eval_set"), - "eval_test": (".evaluation", "eval_test"), - "make_trajectory_callback": (".evaluation", "make_trajectory_callback"), - "parametrize_eval_cases": (".evaluation", "parametrize_eval_cases"), - "print_report": (".evaluation", "print_report"), - "run_eval": (".evaluation", "run_eval"), -} - - -def __getattr__(name: str) -> Any: - """Lazily expose QA subpackages and all entry points.""" - try: - module_name, attribute_name = _LAZY_EXPORTS[name] - except KeyError as exc: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc - - module = import_module(module_name, __name__) - value = module if attribute_name is None else getattr(module, attribute_name) - globals()[name] = value - return value - - -def __dir__() -> list[str]: - """Return module attributes plus lazy exports for discovery.""" - return sorted(set(globals()) | set(__all__)) diff --git a/agentflow/qa/evaluation/config/__init__.py b/agentflow/qa/evaluation/config/__init__.py index 7c9a07f..aafaba8 100644 --- a/agentflow/qa/evaluation/config/__init__.py +++ b/agentflow/qa/evaluation/config/__init__.py @@ -5,7 +5,6 @@ Example: ```python - # All old imports still work unchanged from agentflow.evaluation.config import EvalConfig, CriterionConfig from agentflow.evaluation.config import EvalPresets, MatchType, Rubric diff --git a/agentflow/runtime/__init__.py b/agentflow/runtime/__init__.py index 471597f..be47caf 100644 --- a/agentflow/runtime/__init__.py +++ b/agentflow/runtime/__init__.py @@ -4,69 +4,60 @@ - ``agentflow.runtime.adapters`` - LLM response converters and third-party tool adapters - ``agentflow.runtime.publisher`` - event publishers (console, Redis, Kafka, RabbitMQ) -- ``agentflow.runtime.protocols`` - agent communication protocols (ACP, A2A) +- ``agentflow.runtime.protocols`` - agent communication protocol packages """ -from __future__ import annotations - -import importlib - - -_MODULE_EXPORTS = { - "adapters": ".adapters", - "protocols": ".protocols", - "publisher": ".publisher", -} - -_SYMBOL_EXPORTS = { - # Adapters: LLM - "BaseConverter": ".adapters.llm", - "ConverterType": ".adapters.llm", - "GoogleGenAIConverter": ".adapters.llm", - "OpenAIConverter": ".adapters.llm", - "OpenAIResponsesConverter": ".adapters.llm", - # Adapters: Tools - "ComposioAdapter": ".adapters.tools", - "LangChainAdapter": ".adapters.tools", - # Publisher - "BasePublisher": ".publisher", - "ConsolePublisher": ".publisher", - "ContentType": ".publisher", - "Event": ".publisher", - "EventModel": ".publisher", - "EventType": ".publisher", - "KafkaPublisher": ".publisher", - "RabbitMQPublisher": ".publisher", - "RedisPublisher": ".publisher", - "publish_event": ".publisher", - # Protocols: ACP - "ACPMessage": ".protocols.acp", - "ACPMessageType": ".protocols.acp", - "ACPProtocol": ".protocols.acp", - "MessageContent": ".protocols.acp", - "MessageContext": ".protocols.acp", -} - -__all__ = list(_MODULE_EXPORTS) + list(_SYMBOL_EXPORTS) - - -def __getattr__(name: str): - """Lazily load runtime exports so optional extras stay optional.""" - if name in _MODULE_EXPORTS: - module = importlib.import_module(_MODULE_EXPORTS[name], __name__) - globals()[name] = module - return module - - module_name = _SYMBOL_EXPORTS.get(name) - if module_name is None: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - module = importlib.import_module(module_name, __name__) - value = getattr(module, name) - globals()[name] = value - return value - - -def __dir__() -> list[str]: - """Return the package exports for interactive discovery.""" - return sorted(__all__) +from . import adapters, protocols, publisher +from .adapters.llm import ( + BaseConverter, + ConverterType, + GoogleGenAIConverter, + OpenAIConverter, + OpenAIResponsesConverter, +) +from .adapters.tools import ComposioAdapter, LangChainAdapter +from .publisher import ( + BasePublisher, + ConsolePublisher, + ContentType, + Event, + EventModel, + EventType, + KafkaPublisher, + RabbitMQPublisher, + RedisPublisher, + publish_event, +) + + +__all__ = [ + "AgentFlowExecutor", + # Adapters + "BaseConverter", + "BasePublisher", + "ComposioAdapter", + "ConsolePublisher", + "ContentType", + "ConverterType", + "Event", + "EventModel", + "EventType", + "GoogleGenAIConverter", + "KafkaPublisher", + "LangChainAdapter", + "OpenAIConverter", + "OpenAIResponsesConverter", + "RabbitMQPublisher", + "RedisPublisher", + "a2a", + "adapters", + "build_a2a_app", + "create_a2a_client_node", + "create_a2a_server", + "delegate_to_a2a_agent", + "make_agent_card", + # Protocols + "protocols", + "publish_event", + "publisher", +] diff --git a/agentflow/runtime/adapters/__init__.py b/agentflow/runtime/adapters/__init__.py index 36c28a0..d078968 100644 --- a/agentflow/runtime/adapters/__init__.py +++ b/agentflow/runtime/adapters/__init__.py @@ -6,3 +6,26 @@ Adapters expose registry-based discovery, function-calling schemas, and normalized execution for supported providers. """ + +from . import llm, tools +from .llm import ( + BaseConverter, + ConverterType, + GoogleGenAIConverter, + OpenAIConverter, + OpenAIResponsesConverter, +) +from .tools import ComposioAdapter, LangChainAdapter + + +__all__ = [ + "BaseConverter", + "ComposioAdapter", + "ConverterType", + "GoogleGenAIConverter", + "LangChainAdapter", + "OpenAIConverter", + "OpenAIResponsesConverter", + "llm", + "tools", +] diff --git a/agentflow/runtime/adapters/llm/__init__.py b/agentflow/runtime/adapters/llm/__init__.py index a79a94d..202ed09 100644 --- a/agentflow/runtime/adapters/llm/__init__.py +++ b/agentflow/runtime/adapters/llm/__init__.py @@ -1,14 +1,12 @@ """Integration adapters for optional third-party LLM SDKs. -This package exposes a small, stable surface for response converters without -eagerly importing every concrete implementation during package import. The -lazy behavior avoids import cycles with graph/runtime modules that reference -converter types during test collection. +This package exposes the concrete response converters used by Agentflow. """ -from __future__ import annotations - from .base_converter import BaseConverter, ConverterType +from .google_genai_converter import GoogleGenAIConverter +from .openai_converter import OpenAIConverter +from .openai_responses_converter import OpenAIResponsesConverter __all__ = [ @@ -18,19 +16,3 @@ "OpenAIConverter", "OpenAIResponsesConverter", ] - - -def __getattr__(name: str): - if name == "GoogleGenAIConverter": - from .google_genai_converter import GoogleGenAIConverter - - return GoogleGenAIConverter - if name == "OpenAIConverter": - from .openai_converter import OpenAIConverter - - return OpenAIConverter - if name == "OpenAIResponsesConverter": - from .openai_responses_converter import OpenAIResponsesConverter - - return OpenAIResponsesConverter - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/agentflow/runtime/protocols/__init__.py b/agentflow/runtime/protocols/__init__.py index 204fd26..c7914bb 100644 --- a/agentflow/runtime/protocols/__init__.py +++ b/agentflow/runtime/protocols/__init__.py @@ -1,39 +1,26 @@ """Agent communication protocols for Agentflow. -Protocols: -- ACP (Agent Communication Protocol) - standardized agent-to-agent messaging -- A2A - Google A2A SDK bridge (requires ``pip install 10xscale-agentflow[a2a_sdk]``) +Import protocol implementations from their concrete packages, such as +``agentflow.runtime.protocols.a2a``. """ -from __future__ import annotations - -import importlib - - -_SYMBOL_EXPORTS = { - "ACPMessage": ".acp", - "ACPMessageType": ".acp", - "ACPProtocol": ".acp", - "MessageContent": ".acp", - "MessageContext": ".acp", - "a2a": ".a2a", -} - -__all__ = list(_SYMBOL_EXPORTS) - - -def __getattr__(name: str): - """Lazily load protocol implementations.""" - module_name = _SYMBOL_EXPORTS.get(name) - if module_name is None: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - module = importlib.import_module(module_name, __name__) - value = module if name == "a2a" else getattr(module, name) - globals()[name] = value - return value - - -def __dir__() -> list[str]: - """Return the package exports for interactive discovery.""" - return sorted(__all__) +from . import a2a +from .a2a import ( + AgentFlowExecutor, + build_a2a_app, + create_a2a_client_node, + create_a2a_server, + delegate_to_a2a_agent, + make_agent_card, +) + + +__all__ = [ + "AgentFlowExecutor", + "a2a", + "build_a2a_app", + "create_a2a_client_node", + "create_a2a_server", + "delegate_to_a2a_agent", + "make_agent_card", +] diff --git a/agentflow/runtime/protocols/a2a/__init__.py b/agentflow/runtime/protocols/a2a/__init__.py index ef1cbb9..22697c2 100644 --- a/agentflow/runtime/protocols/a2a/__init__.py +++ b/agentflow/runtime/protocols/a2a/__init__.py @@ -21,59 +21,16 @@ from agentflow.runtime.protocols.a2a import delegate_to_a2a_agent """ -from __future__ import annotations - -import importlib - - -_SYMBOL_EXPORTS = { - "AgentFlowExecutor": ".executor", - "build_a2a_app": ".server", - "create_a2a_client_node": ".client", - "create_a2a_server": ".server", - "delegate_to_a2a_agent": ".client", - "make_agent_card": ".server", -} - -__all__ = list(_SYMBOL_EXPORTS) - - -def _raise_missing_a2a_dependency(exc: BaseException) -> None: - raise ImportError( - "agentflow.runtime.protocols.a2a requires the 'a2a-sdk' package. " - "Install it with: pip install 10xscale-agentflow[a2a_sdk]" - ) from exc - - -def _is_missing_a2a_dependency(exc: BaseException) -> bool: - if isinstance(exc, ModuleNotFoundError): - missing_name = exc.name or "" - return missing_name == "a2a" or missing_name.startswith("a2a.") - - if isinstance(exc, ImportError): - return "a2a" in str(exc) - - return False - - -def __getattr__(name: str): - """Lazily load A2A helpers so the SDK remains an optional extra.""" - module_name = _SYMBOL_EXPORTS.get(name) - if module_name is None: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - try: - module = importlib.import_module(module_name, __name__) - except (ImportError, ModuleNotFoundError) as exc: - if _is_missing_a2a_dependency(exc): - _raise_missing_a2a_dependency(exc) - raise - - value = getattr(module, name) - globals()[name] = value - return value - - -def __dir__() -> list[str]: - """Return the package exports for interactive discovery.""" - return sorted(__all__) +from .client import create_a2a_client_node, delegate_to_a2a_agent +from .executor import AgentFlowExecutor +from .server import build_a2a_app, create_a2a_server, make_agent_card + + +__all__ = [ + "AgentFlowExecutor", + "build_a2a_app", + "create_a2a_client_node", + "create_a2a_server", + "delegate_to_a2a_agent", + "make_agent_card", +] diff --git a/agentflow/runtime/protocols/a2a/executor.py b/agentflow/runtime/protocols/a2a/executor.py index dc43d7a..9e8b084 100644 --- a/agentflow/runtime/protocols/a2a/executor.py +++ b/agentflow/runtime/protocols/a2a/executor.py @@ -21,7 +21,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any from a2a.server.agent_execution import AgentExecutor from a2a.server.agent_execution.context import RequestContext @@ -29,7 +29,9 @@ from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import TaskState, TextPart -from agentflow.core.graph.compiled_graph import CompiledGraph + +if TYPE_CHECKING: + from agentflow.core.graph.compiled_graph import CompiledGraph from agentflow.core.state.message import Message as AFMessage from agentflow.core.state.stream_chunks import StreamEvent from agentflow.utils.constants import ResponseGranularity diff --git a/agentflow/runtime/publisher/__init__.py b/agentflow/runtime/publisher/__init__.py index 409d49a..18cc6ca 100644 --- a/agentflow/runtime/publisher/__init__.py +++ b/agentflow/runtime/publisher/__init__.py @@ -4,39 +4,24 @@ such as console, Redis, Kafka, and RabbitMQ. """ -from __future__ import annotations - -import importlib - - -_SYMBOL_EXPORTS = { - "BasePublisher": ".base_publisher", - "ConsolePublisher": ".console_publisher", - "ContentType": ".events", - "Event": ".events", - "EventModel": ".events", - "EventType": ".events", - "KafkaPublisher": ".kafka_publisher", - "RabbitMQPublisher": ".rabbitmq_publisher", - "RedisPublisher": ".redis_publisher", - "publish_event": ".publish", -} - -__all__ = list(_SYMBOL_EXPORTS) - - -def __getattr__(name: str): - """Lazily load publisher exports so optional dependencies stay optional.""" - module_name = _SYMBOL_EXPORTS.get(name) - if module_name is None: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - module = importlib.import_module(module_name, __name__) - value = getattr(module, name) - globals()[name] = value - return value - - -def __dir__() -> list[str]: - """Return the package exports for interactive discovery.""" - return sorted(__all__) +from .base_publisher import BasePublisher +from .console_publisher import ConsolePublisher +from .events import ContentType, Event, EventModel, EventType +from .kafka_publisher import KafkaPublisher +from .publish import publish_event +from .rabbitmq_publisher import RabbitMQPublisher +from .redis_publisher import RedisPublisher + + +__all__ = [ + "BasePublisher", + "ConsolePublisher", + "ContentType", + "Event", + "EventModel", + "EventType", + "KafkaPublisher", + "RabbitMQPublisher", + "RedisPublisher", + "publish_event", +] diff --git a/agentflow/storage/__init__.py b/agentflow/storage/__init__.py index 6980dc5..54fe30e 100644 --- a/agentflow/storage/__init__.py +++ b/agentflow/storage/__init__.py @@ -9,6 +9,9 @@ from __future__ import annotations +from importlib import import_module as _import_module +from typing import Any as _Any + # Import media first to avoid circular dependency: # storage → checkpointer → core.state → core.graph → utils → storage.checkpointer from . import checkpointer, media, store @@ -46,11 +49,13 @@ # --- Store (vector / long-term memory) --- from .store import ( DEFAULT_COLLECTION, + AgentMemoryConfig, BaseEmbedding, BaseStore, DistanceMetric, GoogleEmbedding, Mem0Store, + MemoryConfig, MemoryIntegration, MemoryRecord, MemorySearchResult, @@ -58,22 +63,42 @@ OpenAIEmbedding, QdrantStore, ReadMode, + UserMemoryConfig, create_cloud_qdrant_store, create_local_qdrant_store, create_mem0_store, create_mem0_store_with_qdrant, create_memory_preload_node, create_remote_qdrant_store, + get_agent_memory_system_prompt, get_memory_system_prompt, - memory_tool, ) +_MEMORY_TOOL_EXPORTS = { + "make_agent_memory_tool", + "make_user_memory_tool", + "memory_tool", +} + + +def __getattr__(name: str) -> _Any: + """Keep prebuilt memory tools lazy at the storage package boundary.""" + if name in _MEMORY_TOOL_EXPORTS: + memory_tools = _import_module("agentflow.prebuilt.tools.memory") + value = getattr(memory_tools, name) + globals()[name] = value + return value + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ # Store "DEFAULT_COLLECTION", # Media "GOOGLE_INLINE_THRESHOLD", + "AgentMemoryConfig", # Checkpointer "BaseCheckpointer", "BaseEmbedding", @@ -91,6 +116,7 @@ "MediaProcessor", "MediaRefResolver", "Mem0Store", + "MemoryConfig", "MemoryIntegration", "MemoryRecord", "MemorySearchResult", @@ -101,6 +127,7 @@ "ProviderMediaCache", "QdrantStore", "ReadMode", + "UserMemoryConfig", # Submodules "checkpointer", "create_cloud_qdrant_store", @@ -113,7 +140,10 @@ "create_remote_qdrant_store", "enforce_file_size", "ensure_media_offloaded", + "get_agent_memory_system_prompt", "get_memory_system_prompt", + "make_agent_memory_tool", + "make_user_memory_tool", "media", "memory_tool", "sanitize_filename", diff --git a/agentflow/storage/store/__init__.py b/agentflow/storage/store/__init__.py index 45c81e9..af076fe 100644 --- a/agentflow/storage/store/__init__.py +++ b/agentflow/storage/store/__init__.py @@ -1,17 +1,23 @@ +from __future__ import annotations + +from importlib import import_module as _import_module +from typing import Any as _Any + from .base_store import BaseStore from .embedding import BaseEmbedding, GoogleEmbedding, OpenAIEmbedding from .long_term_memory import ( MemoryIntegration, ReadMode, create_memory_preload_node, + get_agent_memory_system_prompt, get_memory_system_prompt, - memory_tool, ) from .mem0_store import ( Mem0Store, create_mem0_store, create_mem0_store_with_qdrant, ) +from .memory_config import AgentMemoryConfig, MemoryConfig, UserMemoryConfig from .qdrant_store import ( DEFAULT_COLLECTION, QdrantStore, @@ -22,6 +28,24 @@ from .store_schema import DistanceMetric, MemoryRecord, MemorySearchResult, MemoryType +_MEMORY_TOOL_EXPORTS = { + "make_agent_memory_tool", + "make_user_memory_tool", + "memory_tool", +} + + +def __getattr__(name: str) -> _Any: + """Resolve prebuilt memory tools only when callers ask for them.""" + if name in _MEMORY_TOOL_EXPORTS: + memory_tools = _import_module("agentflow.prebuilt.tools.memory") + value = getattr(memory_tools, name) + globals()[name] = value + return value + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "BaseEmbedding", "BaseStore", @@ -29,6 +53,7 @@ "DistanceMetric", "GoogleEmbedding", "Mem0Store", + "MemoryConfig", "MemoryIntegration", "MemoryRecord", "MemorySearchResult", @@ -36,12 +61,17 @@ "OpenAIEmbedding", "QdrantStore", "ReadMode", + "UserMemoryConfig", + "AgentMemoryConfig", "create_cloud_qdrant_store", "create_local_qdrant_store", "create_mem0_store", "create_mem0_store_with_qdrant", "create_memory_preload_node", "create_remote_qdrant_store", + "get_agent_memory_system_prompt", "get_memory_system_prompt", + "make_agent_memory_tool", + "make_user_memory_tool", "memory_tool", ] diff --git a/agentflow/storage/store/long_term_memory.py b/agentflow/storage/store/long_term_memory.py index f998c90..47ad2e1 100644 --- a/agentflow/storage/store/long_term_memory.py +++ b/agentflow/storage/store/long_term_memory.py @@ -56,19 +56,17 @@ from __future__ import annotations import asyncio -import json import logging from collections.abc import Callable -from enum import Enum +from enum import StrEnum from typing import TYPE_CHECKING, Any, Literal -from injectq import Inject, InjectQ +from injectq import InjectQ from agentflow.core.state import AgentState, Message from agentflow.storage.store.base_store import BaseStore from agentflow.storage.store.store_schema import MemorySearchResult, MemoryType from agentflow.utils.background_task_manager import BackgroundTaskManager -from agentflow.utils.decorators import tool if TYPE_CHECKING: @@ -80,7 +78,7 @@ _VALID_MEMORY_TYPES = {m.value for m in MemoryType} -class ReadMode(str, Enum): +class ReadMode(StrEnum): NO_RETRIEVAL = "no_retrieval" PRELOAD = "preload" POSTLOAD = "postload" @@ -304,105 +302,49 @@ async def _do_write( return {"error": f"unknown write action: {action}"} -# --------------------------------------------------------------------------- -# memory_tool - the LLM-callable tool -# --------------------------------------------------------------------------- - +def get_agent_memory_system_prompt(memory_config: Any) -> str: + """Build the system prompt fragment for ``Agent(memory=...)``.""" + lines: list[str] = ["[Long-term Memory]"] -@tool( - name="memory_tool", - description=( - "Search, store, update or delete long-term memories. " - "Use action='search' with a query to recall relevant memories. " - "Use action='store' with content and a short snake_case memory_key " - "(e.g. 'user_name', 'favorite_language') to save new memories. " - "The system uses memory_key to detect duplicates — if a memory with the " - "same key already exists it will be updated automatically. " - "Use action='delete' with memory_id to remove memories." - ), - tags=["memory", "long_term_memory"], -) -async def memory_tool( # noqa: PLR0911, PLR0913 - action: Literal["search", "store", "update", "delete"] = "search", - content: str = "", - memory_key: str = "", - memory_id: str = "", - query: str = "", - memory_type: str | None = None, - category: str | None = None, - metadata: dict[str, Any] | None = None, - limit: int = 5, - score_threshold: float | None = None, - write_mode: Literal["merge", "replace"] = "merge", - # Injectable params (excluded from LLM schema automatically) - config: dict[str, Any] | None = None, - store: BaseStore | None = Inject[BaseStore], - task_manager: BackgroundTaskManager = Inject[BackgroundTaskManager], -) -> str: - """Search, store, update, or delete long-term memories.""" - if store is None: - return json.dumps({"error": "no memory store configured"}) - - cfg = config or {} - # Resolve memory_type and category from config if not explicitly provided. - resolved_memory_type = memory_type or cfg.get("memory_type", "episodic") - resolved_category = category or cfg.get("category", "general") - mem_type = _validate_memory_type(resolved_memory_type) - - # Inject memory_key into metadata so _do_write can find it. - if memory_key: - metadata = {**(metadata or {}), "memory_key": memory_key} - - # --- Validation --- - if action == "search" and not query: - return json.dumps({"error": "query is required for search"}) - if action == "store" and not content: - return json.dumps({"error": "content is required for store"}) - if action == "update" and not memory_id: - return json.dumps({"error": "memory_id is required for update"}) - if action == "update" and not content: - return json.dumps({"error": "content is required for update"}) - if action == "delete" and not memory_id: - return json.dumps({"error": "memory_id is required for delete"}) + if memory_config.retrieval_mode == ReadMode.PRELOAD: + lines.extend( + [ + "Relevant long-term memories may be provided as system context before " + "the user message.", + "Use the provided memory context when it is relevant, but do not assume " + "it is exhaustive.", + ] + ) + return "\n".join(lines) + + if memory_config.retrieval_mode != ReadMode.POSTLOAD: + lines.append("Long-term memory retrieval tools are not available in this mode.") + return "\n".join(lines) + + if memory_config.user_memory and memory_config.user_memory.enabled: + lines.extend( + [ + "You have access to user_memory_tool for user-specific long-term memory.", + "- To recall durable user facts, call user_memory_tool with " + "action='search' and text.", + "- To save useful user facts or preferences, call user_memory_tool with " + "action='remember' and text.", + "- Do not invent memory identifiers; the memory layer manages identity " + "and deduplication.", + ] + ) - try: - # --- Read --- - if action == "search": - # Flush any in-flight background writes so the search sees the - # latest data (e.g. writes scheduled during a previous query). - await _flush_pending_writes(task_manager) - - # Search across ALL threads for the user — long-term memory - # is not scoped to a single conversation thread. - results = await store.asearch( - _strip_thread_id(cfg), - query, - memory_type=mem_type, - limit=limit, - score_threshold=score_threshold, - ) - return json.dumps(_format_search_results(results)) - - # --- Write (always async / background) --- - task_manager.create_task( - _do_write( - store, - cfg, - action, - content, - memory_id, - mem_type, - resolved_category, - metadata, - write_mode, - ), - name=f"memory_{action}_{memory_id or 'new'}", + if memory_config.agent_memory and memory_config.agent_memory.enabled: + lines.extend( + [ + "You have access to agent_memory_tool for read-only agent/app memory.", + "- To recall agent-level context, call agent_memory_tool with query.", + "- Agent memory is read-only in this agent flow; do not attempt to write, " + "update, or delete it.", + ] ) - return json.dumps({"status": "scheduled", "action": action}) - except Exception as e: - logger.exception("memory_tool error (action=%s): %s", action, e) - return json.dumps({"error": str(e)}) + return "\n".join(lines) # --------------------------------------------------------------------------- @@ -688,6 +630,10 @@ def tools(self) -> list[Callable]: All modes include ``memory_tool`` so the LLM can always decide to write. In *postload* mode the LLM also uses it for reads. """ + # Lazy import avoids a circular dependency: + # prebuilt.tools.memory → storage.store.long_term_memory + from agentflow.prebuilt.tools.memory import memory_tool + return [memory_tool] @property diff --git a/agentflow/storage/store/mem0_store.py b/agentflow/storage/store/mem0_store.py index 18fb5e5..9e0c246 100644 --- a/agentflow/storage/store/mem0_store.py +++ b/agentflow/storage/store/mem0_store.py @@ -21,9 +21,8 @@ avoid thread explosion. The store interprets the supplied ``config`` mapping passed to every method as: -``{"user_id": str | None, "thread_id": str | None, "app_id": str | None}``. -`thread_id` is stored into metadata under ``agent_id`` for backward compatibility -with earlier implementations where agent_id served a similar role. +``{"user_id": str | None, "app_id": str | None}``. Conversation-level scoping is +intentionally ignored so long-term memory retrieval and writes work across threads. Prerequisite: install mem0. ``` @@ -117,20 +116,15 @@ async def _get_client(self) -> AsyncMemory: # type: ignore # Internal helpers # --------------------------------------------------------------------- - def _extract_ids(self, config: dict[str, Any]) -> tuple[str, str | None, str | None]: - """Extract user_id, thread_id, app_id from per-call config with fallbacks.""" + def _extract_ids(self, config: dict[str, Any]) -> tuple[str, str | None]: + """Extract user_id and app_id from per-call config with fallbacks.""" user_id = config.get("user_id") - thread_id = config.get("thread_id") app_id = config.get("app_id") or self.app_id - # if user id and thread id are not provided, we cannot proceed if not user_id: raise ValueError("user_id must be provided in config") - if not thread_id: - raise ValueError("thread_id must be provided in config") - - return user_id, thread_id, app_id + return user_id, app_id def _create_result( self, @@ -164,7 +158,6 @@ def _create_result( memory_type=memory_type, metadata=metadata, user_id=user_id, - thread_id=metadata.get("run_id"), ) def _iter_results(self, response: Any) -> Iterable[dict[str, Any]]: @@ -196,7 +189,7 @@ async def astore( if not text.strip(): raise ValueError("Content cannot be empty") - user_id, thread_id, app_id = self._extract_ids(config) + user_id, app_id = self._extract_ids(config) mem_meta = { "memory_type": memory_type.value, @@ -212,12 +205,11 @@ async def astore( messages=[{"role": "user", "content": text}], user_id=user_id, agent_id=app_id, - run_id=thread_id, metadata=mem_meta, infer=infer, ) - logger.debug("Stored memory for user=%s thread=%s id=%s", user_id, thread_id, result) + logger.debug("Stored memory for user=%s app=%s id=%s", user_id, app_id, result) return result @@ -235,7 +227,7 @@ async def asearch( max_tokens: int = 4000, **kwargs: Any, ) -> list[MemorySearchResult]: - user_id, thread_id, app_id = self._extract_ids(config) + user_id, app_id = self._extract_ids(config) client = await self._get_client() result = await client.search( # type: ignore @@ -262,9 +254,9 @@ async def asearch( ] logger.debug( - "Searched memories for user=%s thread=%s query=%s found=%d", + "Searched memories for user=%s app=%s query=%s found=%d", user_id, - thread_id, + app_id, query, len(out), ) @@ -276,7 +268,7 @@ async def aget( memory_id: str, **kwargs: Any, ) -> MemorySearchResult | None: - user_id, _, _ = self._extract_ids(config) + user_id, _ = self._extract_ids(config) # If we stored mapping use that user id instead (authoritative) client = await self._get_client() @@ -292,7 +284,7 @@ async def aget_all( limit: int = 100, **kwargs: Any, ) -> list[MemorySearchResult]: - user_id, thread_id, app_id = self._extract_ids(config) + user_id, app_id = self._extract_ids(config) client = await self._get_client() result = await client.get_all( # type: ignore @@ -316,9 +308,9 @@ async def aget_all( ] logger.debug( - "Fetched all memories for user=%s thread=%s count=%d", + "Fetched all memories for user=%s app=%s count=%d", user_id, - thread_id, + app_id, len(out), ) return out @@ -356,7 +348,7 @@ async def adelete( memory_id: str, **kwargs: Any, ) -> Any: - user_id, _, _ = self._extract_ids(config) + user_id, _ = self._extract_ids(config) existing = await self.aget(config, memory_id) if not existing: logger.warning("Memory %s not found for deletion", memory_id) @@ -382,7 +374,7 @@ async def aforget_memory( **kwargs: Any, ) -> Any: # Delete all memories for a user - user_id, _, _ = self._extract_ids(config) + user_id, _ = self._extract_ids(config) client = await self._get_client() res = await client.delete_all(user_id=user_id) # type: ignore logger.debug("Forgot all memories for user %s", user_id) @@ -398,14 +390,12 @@ async def arelease(self) -> None: def create_mem0_store( config: dict[str, Any], user_id: str = "default_user", - thread_id: str | None = None, app_id: str = "agentflow_app", ) -> Mem0Store: """Factory for a basic Mem0 long-term store.""" return Mem0Store( config=config, default_user_id=user_id, - default_thread_id=thread_id, app_id=app_id, ) diff --git a/agentflow/storage/store/memory_config.py b/agentflow/storage/store/memory_config.py new file mode 100644 index 0000000..53a15e3 --- /dev/null +++ b/agentflow/storage/store/memory_config.py @@ -0,0 +1,132 @@ +"""Configuration models for agent-level long-term memory.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from agentflow.storage.store.base_store import BaseStore + +from .long_term_memory import ReadMode + + +class _MemoryScopeConfig(BaseModel): + """Shared settings for a memory scope.""" + + enabled: bool = True + store: BaseStore | None = None + memory_type: str = "episodic" + category: str = "general" + limit: int | None = None + score_threshold: float | None = None + config: dict[str, Any] = Field(default_factory=dict) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("memory_type", "category") + @classmethod + def _validate_non_empty(cls, value: str) -> str: + value = value.strip() + if not value: + raise ValueError("memory scope values must not be empty") + return value + + @field_validator("limit") + @classmethod + def _validate_limit(cls, value: int | None) -> int | None: + if value is not None and value < 1: + raise ValueError("limit must be greater than zero") + return value + + @field_validator("score_threshold") + @classmethod + def _validate_score_threshold(cls, value: float | None) -> float | None: + if value is not None and value < 0: + raise ValueError("score_threshold must be non-negative") + return value + + +class UserMemoryConfig(_MemoryScopeConfig): + """User-scoped memory that the model may search and write.""" + + user_id: str | None = None + + @field_validator("user_id") + @classmethod + def _validate_user_id(cls, value: str | None) -> str | None: + if value is None: + return value + value = value.strip() + if not value: + raise ValueError("user_id must not be empty") + return value + + +class AgentMemoryConfig(_MemoryScopeConfig): + """Agent/app-scoped memory that the model may only search.""" + + enabled: bool = False + agent_id: str | None = None + app_id: str | None = None + + @field_validator("agent_id", "app_id") + @classmethod + def _validate_scope_id(cls, value: str | None) -> str | None: + if value is None: + return value + value = value.strip() + if not value: + raise ValueError("agent/app ids must not be empty") + return value + + +class MemoryConfig(BaseModel): + """Primary public configuration object for ``Agent(..., memory=...)``.""" + + store: BaseStore | None = None + retrieval_mode: ReadMode | str = ReadMode.POSTLOAD + limit: int = 5 + score_threshold: float = 0.0 + max_tokens: int | None = None + inject_system_prompt: bool = True + config: dict[str, Any] = Field(default_factory=dict) + user_memory: UserMemoryConfig | None = Field(default_factory=UserMemoryConfig) + agent_memory: AgentMemoryConfig | None = Field(default_factory=AgentMemoryConfig) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("retrieval_mode") + @classmethod + def _validate_retrieval_mode(cls, value: ReadMode | str) -> ReadMode: + if isinstance(value, ReadMode): + return value + return ReadMode(value) + + @field_validator("limit") + @classmethod + def _validate_limit(cls, value: int) -> int: + if value < 1: + raise ValueError("limit must be greater than zero") + return value + + @field_validator("score_threshold") + @classmethod + def _validate_score_threshold(cls, value: float) -> float: + if value < 0: + raise ValueError("score_threshold must be non-negative") + return value + + def model_facing_tools(self) -> list[Any]: + """Return the tools this memory config should expose to an Agent.""" + if self.retrieval_mode != ReadMode.POSTLOAD: + return [] + + from agentflow.prebuilt.tools.memory import make_agent_memory_tool, make_user_memory_tool + + tools: list[Any] = [] + if self.user_memory and self.user_memory.enabled: + tools.append(make_user_memory_tool(self)) + if self.agent_memory and self.agent_memory.enabled: + tools.append(make_agent_memory_tool(self)) + return tools diff --git a/examples/a2a_sdk/currency_agent_cli/client.py b/examples/a2a_sdk/currency_agent_cli/client.py index 174b155..2c3e166 100644 --- a/examples/a2a_sdk/currency_agent_cli/client.py +++ b/examples/a2a_sdk/currency_agent_cli/client.py @@ -11,7 +11,8 @@ import asyncio import uuid -from agentflow.a2a_integration.client import delegate_to_a2a_agent +from agentflow.runtime.protocols import delegate_to_a2a_agent + SERVER_URL = "http://localhost:10000" diff --git a/examples/a2a_sdk/currency_agent_cli/executor.py b/examples/a2a_sdk/currency_agent_cli/executor.py index fdc47a0..3b4b2e9 100644 --- a/examples/a2a_sdk/currency_agent_cli/executor.py +++ b/examples/a2a_sdk/currency_agent_cli/executor.py @@ -17,10 +17,11 @@ from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import TaskState, TextPart -from agentflow.a2a_integration.executor import AgentFlowExecutor from agentflow.core.state import Message as AFMessage +from agentflow.runtime.protocols.a2a.executor import AgentFlowExecutor from agentflow.utils.constants import ResponseGranularity + logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # diff --git a/examples/agent-class/graph.py b/examples/agent-class/graph.py index 82e67dc..eb0d73e 100644 --- a/examples/agent-class/graph.py +++ b/examples/agent-class/graph.py @@ -35,7 +35,7 @@ def get_weather( "content": "You are a helpful assistant, Help user queries effectively.", } ], - tool_node_name="TOOL", + tool_node="TOOL", ), ) graph.add_node("TOOL", tool_node) diff --git a/examples/handoff/handoff_multi_agent.py b/examples/handoff/handoff_multi_agent.py index d61e96a..af4dd95 100644 --- a/examples/handoff/handoff_multi_agent.py +++ b/examples/handoff/handoff_multi_agent.py @@ -120,7 +120,7 @@ def write_document( """, }, ], - tool_node_name="COORDINATOR_TOOLS", + tool_node="COORDINATOR_TOOLS", trim_context=True, ) @@ -141,7 +141,7 @@ def write_document( """, }, ], - tool_node_name="RESEARCHER_TOOLS", + tool_node="RESEARCHER_TOOLS", trim_context=True, ) @@ -161,7 +161,7 @@ def write_document( """, }, ], - tool_node_name="WRITER_TOOLS", + tool_node="WRITER_TOOLS", trim_context=True, ) diff --git a/examples/react/react_sync.py b/examples/react/react_sync.py index 2d78258..c495cd2 100644 --- a/examples/react/react_sync.py +++ b/examples/react/react_sync.py @@ -66,9 +66,9 @@ def get_weather( }, {"role": "user", "content": "Today Date is 2024-06-15"}, ], - tool_node_name="TOOL", trim_context=True, reasoning_config=True, + tool_node=tool_node, ) diff --git a/examples/react/react_sync_validation.py b/examples/react/react_sync_validation.py index c8e45aa..013c782 100644 --- a/examples/react/react_sync_validation.py +++ b/examples/react/react_sync_validation.py @@ -141,7 +141,7 @@ def get_weather( }, {"role": "user", "content": "Today Date is 2024-06-15"}, ], - tool_node_name="TOOL", + tool_node="TOOL", trim_context=True, reasoning_config=True, ) diff --git a/examples/skills/graph.py b/examples/skills/graph.py index 3607c40..2547254 100644 --- a/examples/skills/graph.py +++ b/examples/skills/graph.py @@ -31,7 +31,7 @@ from dotenv import load_dotenv -from agentflow.core.graph import Agent, StateGraph +from agentflow.core.graph import Agent, StateGraph, ToolNode from agentflow.core.state import AgentState, Message from agentflow.core.state.message_context_manager import MessageContextManager from agentflow.graph.skills import SkillConfig @@ -89,7 +89,7 @@ def get_weather(location: str) -> str: ), } ], - tools=[get_weather], # ← Add custom tools here (alongside skills) + tool_node=ToolNode([get_weather]), # ← Add custom tools here (alongside skills) skills=SkillConfig( skills_dir=SKILLS_DIR, inject_trigger_table=True, # auto-appends skill trigger table to system prompt @@ -99,8 +99,10 @@ def get_weather(location: str) -> str: ) # --------------------------------------------------------------------------- -# Tool node — use the public get_tool_node() method, not agent._tool_node. -# When skills are enabled, this ToolNode contains both: +# Tool node — when skills are enabled, the ToolNode passed to Agent gets the +# set_skill tool injected into it automatically. Use get_tool_node() to get +# the final ToolNode (including set_skill) to register as a graph node. +# It contains both: # 1. set_skill (auto-added by skills system) # 2. get_weather (our custom tool passed to Agent) # --------------------------------------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 70ea647..0617af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -253,5 +253,6 @@ dev = [ "langchain-community>=0.3.29", "mem0ai>=0.1.117", "google-genai>=1.56.0", + "a2a_sdk", ] diff --git a/tests/graph/test_agent_integration.py b/tests/graph/test_agent_integration.py index b91bdc1..2ca48e9 100644 --- a/tests/graph/test_agent_integration.py +++ b/tests/graph/test_agent_integration.py @@ -27,15 +27,16 @@ def test_agent_can_be_added_to_graph(self): assert isinstance(graph.nodes["agent"].func, Agent) def test_agent_with_tools_can_be_added(self): - """Test that Agent with tools can be added to graph.""" + """Test that Agent with ToolNode can be added to graph.""" def test_tool(query: str) -> str: return f"Result for {query}" + tool_node = ToolNode([test_tool]) agent = Agent( model="gpt-4o-mini", system_prompt=[{"role": "system", "content": "You are a test assistant."}], - tools=[test_tool], + tool_node=tool_node, ) graph = StateGraph() diff --git a/tests/graph/test_agent_internal.py b/tests/graph/test_agent_internal.py index c06d512..6bcb82e 100644 --- a/tests/graph/test_agent_internal.py +++ b/tests/graph/test_agent_internal.py @@ -21,6 +21,13 @@ from agentflow.core.graph.agent import Agent from agentflow.core.graph.tool_node import ToolNode from agentflow.core.state import AgentState, Message +from agentflow.storage.store.base_store import BaseStore +from agentflow.storage.store.memory_config import ( + AgentMemoryConfig, + MemoryConfig, + UserMemoryConfig, +) +from agentflow.storage.store.store_schema import MemorySearchResult, MemoryType # ───────────────────────────────────────────────────────────────────────────── @@ -113,6 +120,32 @@ def _make_google_agent(model: str = "gemini-2.0-flash") -> Agent: return agent +class _FakeMemoryStore(BaseStore): + def __init__(self, results: list[MemorySearchResult]): + self.asearch_mock = AsyncMock(return_value=results) + + async def astore(self, *args, **kwargs): + return "stored" + + async def asearch(self, *args, **kwargs): + return await self.asearch_mock(*args, **kwargs) + + async def aget(self, *args, **kwargs): + return None + + async def aget_all(self, *args, **kwargs): + return [] + + async def aupdate(self, *args, **kwargs): + return None + + async def adelete(self, *args, **kwargs): + return None + + async def aforget_memory(self, *args, **kwargs): + return None + + # --------------------------------------------------------------------------- # Minimal fake google.genai.types objects used in conversion tests # --------------------------------------------------------------------------- @@ -960,20 +993,21 @@ async def test_unsupported_output_type_raises(self): class TestSetupTools: def test_none_returns_none(self): agent = _make_openai_agent() - agent.tools = None + agent.tool_node = None assert agent._setup_tools() is None def test_tool_node_instance_returned_as_is(self): agent = _make_openai_agent() tn = ToolNode([lambda x: x]) - agent.tools = tn + agent.tool_node = tn assert agent._setup_tools() is tn - def test_list_of_callables_converted_to_tool_node(self): + def test_str_sets_tool_node_name_returns_none(self): agent = _make_openai_agent() - agent.tools = [lambda x: x] + agent.tool_node = "TOOL" result = agent._setup_tools() - assert isinstance(result, ToolNode) + assert result is None + assert agent.tool_node_name == "TOOL" # ═════════════════════════════════════════════════════════════════════════════ @@ -1063,24 +1097,17 @@ def my_tool(x: str) -> str: """A test tool.""" return x - agent = _make_openai_agent(tools=[my_tool]) + agent = _make_openai_agent(tool_node=ToolNode([my_tool])) from injectq import InjectQ result = await agent._resolve_tools(InjectQ.get_instance()) assert isinstance(result, list) assert len(result) > 0 - async def test_named_node_not_found_returns_inline_tools(self): - def my_tool(x: str) -> str: - """A test tool.""" - return x - - agent = _make_openai_agent(tools=[my_tool], tool_node_name="nonexistent_node") + async def test_named_node_not_found_raises(self): + agent = _make_openai_agent(tool_node="nonexistent_node") from injectq import InjectQ - # Should fall back to inline tools without raising - result = await agent._resolve_tools(InjectQ.get_instance()) - assert isinstance(result, list) - # Inline tool is still returned even though named node is missing - assert len(result) > 0 + with pytest.raises(RuntimeError, match="ToolNode named 'nonexistent_node' was not found"): + await agent._resolve_tools(InjectQ.get_instance()) # ═════════════════════════════════════════════════════════════════════════════ @@ -1163,9 +1190,10 @@ def test_default_output_type_is_text(self): # ── misc attributes ──────────────────────────────────────────────────── - def test_tool_node_name_stored(self): - agent = _make_openai_agent(tool_node_name="my_tools") + def test_tool_node_str_sets_tool_node_name(self): + agent = _make_openai_agent(tool_node="my_tools") assert agent.tool_node_name == "my_tools" + assert agent._tool_node is None def test_extra_messages_stored(self): msg = Message.text_message("hi", role="user") @@ -1176,19 +1204,78 @@ def test_trim_context_stored(self): agent = _make_openai_agent(trim_context=True) assert agent.trim_context is True - def test_tools_list_creates_internal_tool_node(self): + def test_tool_node_instance_stored(self): + tn = ToolNode([lambda x: x]) + agent = _make_openai_agent(tool_node=tn) + assert agent._tool_node is tn + assert agent.tool_node_name is None + + def test_tool_node_none_gives_none_internal_tool_node(self): + agent = _make_openai_agent(tool_node=None) + assert agent._tool_node is None + assert agent.tool_node_name is None + + def test_memory_config_requires_existing_tool_node_for_postload(self): + memory = MemoryConfig(user_memory=UserMemoryConfig(user_id="u1")) + with pytest.raises(RuntimeError, match="Memory requires an existing ToolNode"): + _make_openai_agent(tool_node=None, memory=memory) + + def test_memory_config_adds_tools_to_existing_tool_node(self): + memory = MemoryConfig(user_memory=UserMemoryConfig(user_id="u1")) + tool_node = ToolNode([]) + agent = _make_openai_agent(tool_node=tool_node, memory=memory) + + assert agent._tool_node is tool_node + assert any("user_memory_tool" in p["content"] for p in agent.system_prompt) + assert "user_memory_tool" in agent._tool_node._funcs + assert agent.get_tool_node() is agent._tool_node + + def test_memory_config_preserves_existing_tools(self): def my_tool(x: str) -> str: - """tool.""" return x - agent = _make_openai_agent(tools=[my_tool]) - assert isinstance(agent._tool_node, ToolNode) + memory = MemoryConfig( + user_memory=UserMemoryConfig(), + agent_memory=AgentMemoryConfig(enabled=True, agent_id="agent-1"), + ) + agent = _make_openai_agent(tool_node=ToolNode([my_tool]), memory=memory) + + assert agent._tool_node is not None + assert set(agent._tool_node._funcs) == { + "my_tool", + "user_memory_tool", + "agent_memory_tool", + } + + @pytest.mark.asyncio + async def test_preload_memory_does_not_register_tools_and_builds_prompt(self): + store = _FakeMemoryStore( + [ + MemorySearchResult( + id="m1", + content="User prefers concise answers", + score=0.91, + memory_type=MemoryType.SEMANTIC, + ) + ] + ) + memory = MemoryConfig( + retrieval_mode="preload", + user_memory=UserMemoryConfig(store=store, user_id="u1", memory_type="semantic"), + ) + agent = _make_openai_agent(tool_node=None, memory=memory) - def test_tools_none_gives_none_internal_tool_node(self): - agent = _make_openai_agent(tools=None) assert agent._tool_node is None + assert not any("user_memory_tool" in p["content"] for p in agent.system_prompt) - def test_tools_tool_node_instance_reused(self): - tn = ToolNode([lambda x: x]) - agent = _make_openai_agent(tools=tn) - assert agent._tool_node is tn + prompts = await agent._build_memory_prompts( + AgentState(context=[Message.text_message("What do I like?", role="user")]), + {"user_id": "runtime-user", "thread_id": "runtime-thread"}, + ) + + assert len(prompts) == 1 + assert prompts[0]["role"] == "system" + assert "Long-term Memory Context" in prompts[0]["content"] + assert "User prefers concise answers" in prompts[0]["content"] + store.asearch_mock.assert_awaited_once() + assert store.asearch_mock.call_args.args[0] == {"user_id": "u1"} diff --git a/tests/prebuilt/test_tools.py b/tests/prebuilt/test_tools.py new file mode 100644 index 0000000..a72fe39 --- /dev/null +++ b/tests/prebuilt/test_tools.py @@ -0,0 +1,185 @@ +"""Tests for prebuilt AgentFlow tools.""" + +from __future__ import annotations + +import json +import sys +import types as pytypes +from types import SimpleNamespace + +import pytest + +from agentflow.core.graph.tool_node import ToolNode +from agentflow.prebuilt.tools import ( + fetch_url, + file_read, + file_search, + file_write, + google_web_search, + safe_calculator, + vertex_ai_search, +) +from agentflow.prebuilt.tools import fetch as fetch_module + + +def test_safe_calculator_evaluates_basic_math() -> None: + result = json.loads(safe_calculator("(2 + 3) * 4")) + + assert result == {"result": 20} + + +def test_safe_calculator_rejects_unsupported_expressions() -> None: + result = json.loads(safe_calculator("__import__('os').system('echo nope')")) + + assert "error" in result + + +def test_file_read_write_and_search_are_workspace_scoped(tmp_path) -> None: + config = {"file_tool_root": str(tmp_path)} + + write_result = json.loads( + file_write( + "notes/info.txt", + "hello agentflow\nsecond line", + create_dirs=True, + config=config, + ) + ) + assert write_result["status"] == "written" + + read_result = json.loads(file_read("notes/info.txt", start_line=1, end_line=1, config=config)) + assert read_result["content"] == "hello agentflow" + + search_result = json.loads(file_search("agentflow", path="notes", config=config)) + assert search_result["results"][0]["path"] == "notes/info.txt" + assert search_result["results"][0]["line"] == 1 + + blocked = json.loads(file_read("../outside.txt", config=config)) + assert "configured root" in blocked["error"] + + +@pytest.mark.asyncio +async def test_fetch_url_returns_normalized_html(monkeypatch) -> None: + class FakeResponse: + headers = {"content-type": "text/html; charset=utf-8"} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self, size: int) -> bytes: + return ( + b"

Hello

" + b"

AgentFlow

" + ) + + def getcode(self) -> int: + return 200 + + def geturl(self) -> str: + return "https://example.com/" + + monkeypatch.setattr(fetch_module, "_is_public_hostname", lambda hostname: True) + monkeypatch.setattr(fetch_module.request, "urlopen", lambda req, timeout: FakeResponse()) + + result = json.loads(await fetch_url("https://example.com")) + + assert result["status_code"] == 200 + assert "Hello" in result["content"] + assert "AgentFlow" in result["content"] + assert "skip" not in result["content"] + + +def _install_fake_google_genai(monkeypatch) -> None: + response = SimpleNamespace( + text="grounded answer", + candidates=[ + SimpleNamespace( + grounding_metadata=SimpleNamespace( + model_dump=lambda: {"sources": [{"uri": "https://example.com"}]} + ) + ) + ], + ) + + class FakeModels: + def generate_content(self, **kwargs): + return response + + class FakeClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.models = FakeModels() + + class SimpleType: + def __init__(self, **kwargs): + self.kwargs = kwargs + + google_mod = pytypes.ModuleType("google") + genai_mod = pytypes.ModuleType("google.genai") + genai_types_mod = pytypes.ModuleType("google.genai.types") + + genai_mod.Client = FakeClient + genai_types_mod.GenerateContentConfig = SimpleType + genai_types_mod.GoogleSearch = SimpleType + genai_types_mod.HttpOptions = SimpleType + genai_types_mod.Retrieval = SimpleType + genai_types_mod.Tool = SimpleType + genai_types_mod.VertexAISearch = SimpleType + genai_mod.types = genai_types_mod + google_mod.genai = genai_mod + + monkeypatch.setitem(sys.modules, "google", google_mod) + monkeypatch.setitem(sys.modules, "google.genai", genai_mod) + monkeypatch.setitem(sys.modules, "google.genai.types", genai_types_mod) + + +@pytest.mark.asyncio +async def test_google_web_search_uses_google_genai(monkeypatch) -> None: + _install_fake_google_genai(monkeypatch) + + result = json.loads(await google_web_search("latest agent frameworks")) + + assert result["content"] == "grounded answer" + assert result["grounding_metadata"]["sources"][0]["uri"] == "https://example.com" + + +@pytest.mark.asyncio +async def test_vertex_ai_search_requires_datastore() -> None: + result = json.loads(await vertex_ai_search("policies", datastore="")) + + assert result["error"] == "datastore is required" + + +@pytest.mark.asyncio +async def test_vertex_ai_search_uses_google_genai(monkeypatch) -> None: + _install_fake_google_genai(monkeypatch) + + result = json.loads( + await vertex_ai_search( + "policies", + datastore=( + "projects/demo/locations/global/collections/default_collection/" + "dataStores/policies" + ), + ) + ) + + assert result["content"] == "grounded answer" + + +@pytest.mark.asyncio +async def test_tool_node_hides_injected_file_config_from_schema() -> None: + node = ToolNode([file_read, file_write, file_search]) + + schemas = await node.all_tools() + params_by_name = { + item["function"]["name"]: item["function"]["parameters"]["properties"] + for item in schemas + } + + assert "config" not in params_by_name["file_read"] + assert "config" not in params_by_name["file_write"] + assert "config" not in params_by_name["file_search"] diff --git a/tests/store/test_long_term_memory.py b/tests/store/test_long_term_memory.py index 18a5552..568ab87 100644 --- a/tests/store/test_long_term_memory.py +++ b/tests/store/test_long_term_memory.py @@ -21,9 +21,19 @@ _validate_memory_type, create_memory_preload_node, get_memory_system_prompt, +) +from agentflow.prebuilt.tools.memory import ( + make_agent_memory_tool, + make_user_memory_tool, memory_tool, ) +from agentflow.storage.store.memory_config import ( + AgentMemoryConfig, + MemoryConfig, + UserMemoryConfig, +) from agentflow.storage.store.store_schema import MemorySearchResult, MemoryType +from agentflow.core.graph.tool_node import ToolNode # --------------------------------------------------------------------------- @@ -48,7 +58,13 @@ def mock_task_manager(): task = MagicMock(spec=asyncio.Task) task.done.return_value = False task.add_done_callback = MagicMock() - mgr.create_task = MagicMock(return_value=task) + + def _create_task(coro, *args, **kwargs): + if hasattr(coro, "close"): + coro.close() + return task + + mgr.create_task = MagicMock(side_effect=_create_task) mgr.get_task_count = MagicMock(return_value=0) mgr.wait_for_all = AsyncMock() return mgr @@ -399,6 +415,107 @@ async def test_search_exception_returns_error(self, mock_store, mock_task_manage assert "error" in result +# --------------------------------------------------------------------------- +# Agent-level memory tools +# --------------------------------------------------------------------------- + + +class TestAgentLevelMemoryTools: + @pytest.mark.asyncio + async def test_user_memory_tool_search_uses_cross_thread_config( + self, mock_store, mock_task_manager, sample_search_results, config + ): + memory = MemoryConfig(user_memory=UserMemoryConfig(user_id="u1")) + tool_fn = make_user_memory_tool(memory) + mock_store.asearch.return_value = sample_search_results + + result = json.loads( + await tool_fn( + action="search", + text="preferences", + store=mock_store, + task_manager=mock_task_manager, + config=config, + ) + ) + + assert len(result) == 2 + call_args = mock_store.asearch.call_args + assert call_args.args[0] == {"user_id": "u1"} + + @pytest.mark.asyncio + async def test_user_memory_tool_remember_schedules_write( + self, mock_store, mock_task_manager, config + ): + memory = MemoryConfig(user_memory=UserMemoryConfig(user_id="u1")) + tool_fn = make_user_memory_tool(memory) + + result = json.loads( + await tool_fn( + action="remember", + text="User prefers short answers", + store=mock_store, + task_manager=mock_task_manager, + config=config, + ) + ) + + assert result == {"status": "scheduled", "action": "remember"} + mock_task_manager.create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_agent_memory_tool_search_is_read_only( + self, mock_store, mock_task_manager, sample_search_results, config + ): + memory = MemoryConfig( + agent_memory=AgentMemoryConfig(enabled=True, agent_id="agent-1") + ) + tool_fn = make_agent_memory_tool(memory) + mock_store.asearch.return_value = sample_search_results + + result = json.loads( + await tool_fn( + query="release policy", + store=mock_store, + task_manager=mock_task_manager, + config=config, + ) + ) + + assert len(result) == 2 + call_args = mock_store.asearch.call_args + assert call_args.args[0]["thread_id"] == "agent-1" + mock_task_manager.create_task.assert_not_called() + + @pytest.mark.asyncio + async def test_default_agent_tool_schemas_hide_admin_fields(self): + memory = MemoryConfig( + user_memory=UserMemoryConfig(), + agent_memory=AgentMemoryConfig(enabled=True), + ) + tools = ToolNode(memory.model_facing_tools()) + schemas = await tools.all_tools() + + schema_by_name = {s["function"]["name"]: s for s in schemas} + user_props = schema_by_name["user_memory_tool"]["function"]["parameters"]["properties"] + agent_props = schema_by_name["agent_memory_tool"]["function"]["parameters"]["properties"] + + assert set(user_props) == {"action", "text", "memory_type", "category", "limit"} + assert set(agent_props) == {"query", "memory_type", "category", "limit"} + assert "memory_id" not in user_props + assert "memory_key" not in user_props + assert "action" not in agent_props + + def test_preload_does_not_expose_model_facing_tools(self): + memory = MemoryConfig( + retrieval_mode="preload", + user_memory=UserMemoryConfig(), + agent_memory=AgentMemoryConfig(enabled=True), + ) + + assert memory.model_facing_tools() == [] + + # --------------------------------------------------------------------------- # _strip_thread_id # --------------------------------------------------------------------------- diff --git a/tests/store/test_mem0_store_async.py b/tests/store/test_mem0_store_async.py index bf3eee7..88924f1 100644 --- a/tests/store/test_mem0_store_async.py +++ b/tests/store/test_mem0_store_async.py @@ -14,9 +14,19 @@ class MockAsyncMem0: def __init__(self): self.items = [] # list[dict] self._id = 1 + self.add_calls = [] - async def add(self, messages, user_id, agent_id=None, run_id=None, metadata=None, **kwargs): + async def add(self, messages, user_id, agent_id=None, metadata=None, **kwargs): """Mock AsyncMemory.add method""" + self.add_calls.append( + { + "messages": messages, + "user_id": user_id, + "agent_id": agent_id, + "metadata": metadata, + "kwargs": kwargs, + } + ) text = messages[0]["content"] if messages else "" metadata = metadata or {} mem0_id = f"m{self._id}" @@ -32,7 +42,6 @@ async def add(self, messages, user_id, agent_id=None, run_id=None, metadata=None }, "user_id": user_id, "score": 0.9, - "run_id": run_id, "agent_id": agent_id } self.items.append(rec) @@ -102,17 +111,27 @@ async def mock_from_config(config): @pytest.mark.asyncio async def test_store_and_search(store): - # Provide both user_id and thread_id as required by _extract_ids - config = {"user_id": "u1", "thread_id": "t1"} + config = {"user_id": "u1"} mem_id = await store.astore(config, "Alice likes tea", memory_type=MemoryType.SEMANTIC) assert mem_id results = await store.asearch(config, "likes") assert results and results[0].content.startswith("Alice") + assert "run_id" not in store._client.add_calls[-1]["kwargs"] @pytest.mark.asyncio -async def test_get_update_delete(store): +async def test_store_ignores_thread_id_in_config(store): config = {"user_id": "u1", "thread_id": "t1"} + mem_id = await store.astore(config, "Alice likes coffee", memory_type=MemoryType.SEMANTIC) + assert mem_id + results = await store.asearch(config, "coffee") + assert results and results[0].content.startswith("Alice") + assert "run_id" not in store._client.add_calls[-1]["kwargs"] + + +@pytest.mark.asyncio +async def test_get_update_delete(store): + config = {"user_id": "u1"} mem_id = await store.astore(config, "Berlin is in Germany") got = await store.aget(config, mem_id["results"][0]["id"]) # Use actual returned ID assert got and got.content.startswith("Berlin") @@ -128,7 +147,7 @@ async def test_get_update_delete(store): @pytest.mark.asyncio async def test_batch_and_forget(store): - config = {"user_id": "u1", "thread_id": "t1"} + config = {"user_id": "u1"} # Note: batch_store is not implemented in the current Mem0Store, so let's test individual stores await store.astore(config, "A") await store.astore(config, "B") diff --git a/tests/test_skills.py b/tests/test_skills.py index 2ac822f..9c70f2b 100644 --- a/tests/test_skills.py +++ b/tests/test_skills.py @@ -907,19 +907,29 @@ def test_setup_skills_invalid_type_raises(self): with pytest.raises(TypeError, match="Expected SkillConfig"): mixin._setup_skills("not-a-config") - def test_setup_skills_creates_registry(self, tmp_path: Path): + def test_setup_skills_requires_existing_tool_node(self, tmp_path: Path): from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin _make_skill_dir(tmp_path, "alpha") - _make_skill_dir(tmp_path, "beta") mixin = AgentSkillsMixin() mixin._tool_node = None + with pytest.raises(RuntimeError, match="Skills require an existing ToolNode"): + mixin._setup_skills(SkillConfig(skills_dir=str(tmp_path))) + + def test_setup_skills_creates_registry_with_existing_tool_node(self, tmp_path: Path): + from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin + from agentflow.core.graph.tool_node import ToolNode + + _make_skill_dir(tmp_path, "alpha") + _make_skill_dir(tmp_path, "beta") + + mixin = AgentSkillsMixin() + mixin._tool_node = ToolNode([]) mixin._setup_skills(SkillConfig(skills_dir=str(tmp_path))) assert mixin._skills_registry is not None assert len(mixin._skills_registry) == 2 - # Tool node should have been created with set_skill tool assert mixin._tool_node is not None def test_setup_skills_adds_tool_to_existing_toolnode(self, tmp_path: Path): @@ -952,11 +962,12 @@ def test_build_skill_prompts_no_skills(self): def test_build_skill_prompts_appends_trigger_table(self, tmp_path: Path): from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin + from agentflow.core.graph.tool_node import ToolNode _make_skill_dir(tmp_path, "review", triggers=["review code"]) mixin = AgentSkillsMixin() - mixin._tool_node = None + mixin._tool_node = ToolNode([]) mixin._setup_skills(SkillConfig(skills_dir=str(tmp_path), inject_trigger_table=True)) base = [{"role": "system", "content": "Be helpful"}] @@ -969,11 +980,12 @@ def test_build_skill_prompts_appends_trigger_table(self, tmp_path: Path): def test_build_skill_prompts_no_trigger_table_when_disabled(self, tmp_path: Path): from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin + from agentflow.core.graph.tool_node import ToolNode _make_skill_dir(tmp_path, "review") mixin = AgentSkillsMixin() - mixin._tool_node = None + mixin._tool_node = ToolNode([]) mixin._setup_skills(SkillConfig(skills_dir=str(tmp_path), inject_trigger_table=False)) base = [{"role": "system", "content": "Be helpful"}] @@ -982,11 +994,12 @@ def test_build_skill_prompts_no_trigger_table_when_disabled(self, tmp_path: Path def test_build_skill_prompts_does_not_mutate_original(self, tmp_path: Path): from agentflow.core.graph.agent_internal.skills import AgentSkillsMixin + from agentflow.core.graph.tool_node import ToolNode _make_skill_dir(tmp_path, "review") mixin = AgentSkillsMixin() - mixin._tool_node = None + mixin._tool_node = ToolNode([]) mixin._setup_skills(SkillConfig(skills_dir=str(tmp_path), inject_trigger_table=True)) base = [{"role": "system", "content": "Be helpful"}] diff --git a/uv.lock b/uv.lock index 07d85c3..c187e6b 100644 --- a/uv.lock +++ b/uv.lock @@ -8,7 +8,7 @@ resolution-markers = [ [[package]] name = "10xscale-agentflow" -version = "0.6.8" +version = "0.7.0" source = { editable = "." } dependencies = [ { name = "injectq" },