diff --git a/README.md b/README.md
index 98c68ce..c15d84c 100644
--- a/README.md
+++ b/README.md
@@ -6,72 +6,75 @@ A lightweight multimodal RAG (Retrieval-Augmented Generation) library that uses
 
 ## 🌟 Features
 
-- **Vision-First Approach**: Documents processed as images using PyMuPDF, preserving visual information and formatting
-- **No Vector Database Required**: Eliminates the complexity of embeddings and vector storage
-- **Adaptive RAG Agent**: Single intelligent agent that dynamically plans tasks and selects relevant pages
-- **Multi-Provider Support**: Works with OpenAI GPT-4V, Anthropic Claude, and OpenRouter
-- **Modern CLI Interface**: Beautiful terminal UI built with Textual
-- **Conversation Aware**: Maintains context across multiple queries
-- **Pluggable Storage**: Local filesystem or in-memory storage backends
+- **Vision-First Approach**: Documents processed as images using PyMuPDF, preserving visual information and formatting.
+- **No Vector Database Required**: Eliminates the complexity of embeddings and vector storage.
+- **Adaptive RAG Agent**: An intelligent agent that dynamically plans and executes tasks to answer your queries.
+- **Multi-Provider Support**: Works with OpenAI GPT-4V, Anthropic Claude, and any model supported by OpenRouter.
+- **Modern CLI Interface**: A beautiful and intuitive terminal UI built with Textual.
+- **Conversation Aware**: Maintains context across multiple queries for a natural chat experience.
+- **Pluggable Storage**: Supports local filesystem and in-memory storage backends.
 
 ## 🚀 Quick Start
 
 ### Installation
 
 ```bash
-# use uv (recommended)
+# Using uv (recommended)
 uv pip install docpixie
 
-# or pip
+# Using pip
 pip install docpixie
 ```
 
-Try the CLI:
+Then, launch the CLI:
 
 ```bash
 docpixie
 ```
 
-### Basic Usage
+### Basic Usage (as a library)
 
 ```python
 import asyncio
 from docpixie import DocPixie
 
 async def main():
-    # Initialize with your API key
+    # Initialize DocPixie, which will use environment variables for API keys.
+    # For example, set OPENROUTER_API_KEY for the default OpenRouter provider.
     docpixie = DocPixie()
 
-    # Add a document
+    # Add a document to the system.
     document = await docpixie.add_document("path/to/your/document.pdf")
     print(f"Added document: {document.name}")
 
-    # Query the document
+    # Query the document with a question.
     result = await docpixie.query("What are the key findings?")
     print(f"Answer: {result.answer}")
     print(f"Pages used: {result.page_numbers}")
 
-# Run the example
+# Run the asynchronous main function.
 asyncio.run(main())
 ```
 
 ### Using the CLI
 
-Start the interactive terminal interface:
+Start the interactive terminal interface with a single command:
 
 ```bash
 docpixie
 ```
 
-The CLI provides:
+The CLI provides a rich user experience with:
 
 - Interactive document chat
-- Document management
+- Document management (indexing, deletion)
 - Conversation history
 - Model configuration
-- Command palette with shortcuts
+- A command palette with shortcuts for all major actions
 
 ## 🛠️ Configuration
 
-DocPixie uses environment variables for API key configuration:
+DocPixie can be configured via environment variables or directly in code.
+
+### Environment Variables
 
 ```bash
 # For OpenAI (default)
@@ -80,95 +83,94 @@ export OPENAI_API_KEY="your-openai-key"
 # For Anthropic Claude
 export ANTHROPIC_API_KEY="your-anthropic-key"
 
-# For OpenRouter (supports many models)
+# For OpenRouter (recommended for access to a wide range of models)
 export OPENROUTER_API_KEY="your-openrouter-key"
 ```
 
-You can also specify the provider:
+### In Code
+
+You can also specify the provider and models directly when initializing `DocPixie`:
 
 ```python
 from docpixie import DocPixie, DocPixieConfig
 
+# Example configuration for Anthropic's Claude 3 Opus
 config = DocPixieConfig(
-    provider="anthropic",  # or "openai", "openrouter"
+    provider="anthropic",
     model="claude-3-opus-20240229",
     vision_model="claude-3-opus-20240229"
 )
 
+# Initialize DocPixie with the custom configuration
 docpixie = DocPixie(config=config)
 ```
 
 ## 📚 Supported File Types
 
-- **PDF files** (.pdf) - Full multipage support
-- More file types coming soon
+- **PDF files** (`.pdf`): Full multipage support.
+- **Image files** (`.jpg`, `.jpeg`, `.png`, `.webp`): Each image is treated as a single-page document.
+- More file types are coming soon!
 
 ## 🏗️ Architecture
 
-DocPixie uses a clean, modular architecture:
+DocPixie is built on a clean, modular architecture:
 
 ```
 📁 Core Components
-├── 🧠 Adaptive RAG Agent - Dynamic task planning and execution
-├── 👁️ Vision Processing - Document-to-image conversion via PyMuPDF
-├── 🔌 Provider System - Unified interface for AI providers
-├── 💾 Storage Backends - Local filesystem or in-memory storage
-└── 🖥️ CLI Interface - Modern terminal UI with Textual
+├── 🧠 Adaptive RAG Agent: Dynamically plans and executes tasks.
+├── 👁️ Vision Processing: Converts documents to images using PyMuPDF.
+├── 🔌 Provider System: A unified interface for different AI providers.
+├── 💾 Storage Backends: Pluggable storage for local or in-memory data.
+└── 🖥️ CLI Interface: A modern terminal UI powered by Textual.
 
 📁 Processing Flow
-1. Document → Images (PyMuPDF)
-2. Vision-based summarization
-3. Adaptive query processing
-4. Intelligent page selection
-5. Response synthesis
+1. Document → Images (via PyMuPDF for PDFs)
+2. Vision-based summarization of the document.
+3. Adaptive query processing by the RAG agent.
+4. Intelligent page selection using vision models.
+5. Synthesis of the final response from task results.
 ```
 
-### Key Design Principles
-
-- **Provider-Agnostic**: Generic model configuration works across all providers
-- **Image-Based Processing**: All documents converted to images, preserving visual context
-- **Business Logic Separation**: Raw API operations separate from workflow logic
-- **Adaptive Intelligence**: Single agent mode that dynamically adjusts based on findings
-
 ## 🎯 Use Cases
 
-- **Research & Analysis**: Query academic papers, reports, and research documents
-- **Document Q&A**: Interactive questioning of PDFs, contracts, and manuals
-- **Content Discovery**: Find specific information across large document collections
-- **Visual Document Processing**: Handle documents with charts, diagrams, and complex layouts
+- **Research & Analysis**: Query academic papers, reports, and research documents.
+- **Document Q&A**: Interactively question PDFs, contracts, and manuals.
+- **Content Discovery**: Find specific information across large document collections.
+- **Visual Document Processing**: Handle documents with charts, diagrams, and complex layouts that traditional text-based RAG systems struggle with.
 
 ## 🌍 Environment Variables
 
 | Variable | Description | Default |
 |----------|-------------|---------|
-| `OPENAI_API_KEY` | OpenAI API key | None |
-| `ANTHROPIC_API_KEY` | Anthropic API key | None |
-| `OPENROUTER_API_KEY` | OpenRouter API key | None |
-| `DOCPIXIE_PROVIDER` | AI provider | `openai` |
-| `DOCPIXIE_STORAGE_PATH` | Storage directory | `./docpixie_data` |
-| `DOCPIXIE_JPEG_QUALITY` | Image quality (1-100) | `90` |
+| `OPENAI_API_KEY` | Your OpenAI API key. | None |
+| `ANTHROPIC_API_KEY` | Your Anthropic API key. | None |
+| `OPENROUTER_API_KEY` | Your OpenRouter API key. | None |
+| `DOCPIXIE_PROVIDER` | The AI provider to use. | `openai` |
+| `DOCPIXIE_STORAGE_PATH` | The directory for local storage. | `./docpixie_data` |
+| `DOCPIXIE_JPEG_QUALITY` | The image quality for JPEG conversion (1-100). | `90` |
 
 ## 📖 Documentation
 
-- [Getting Started Guide](docs/getting-started.md) - Detailed examples and tutorials
-- [CLI Tool Guide](docs/cli-tool.md) - Complete CLI documentation
+For more detailed information, please refer to the docstrings within the source code.
 
 ## 🤝 Contributing
 
-1. Fork the repository
-2. Create a feature branch (`git checkout -b feature/amazing-feature`)
-3. Commit your changes (`git commit -m 'Add amazing feature'`)
-4. Push to the branch (`git push origin feature/amazing-feature`)
-5. Open a Pull Request
+We welcome contributions! Please follow these steps:
+
+1. Fork the repository.
+2. Create a feature branch (`git checkout -b feature/your-amazing-feature`).
+3. Commit your changes (`git commit -m 'Add your amazing feature'`).
+4. Push to the branch (`git push origin feature/your-amazing-feature`).
+5. Open a Pull Request.
 
 ## 📄 License
 
-This project is licensed under the MIT License - see the LICENSE file for details.
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
 
 ## 🙏 Acknowledgments
 
-- Built with [PyMuPDF](https://pymupdf.readthedocs.io/) for PDF processing
-- CLI powered by [Textual](https://textual.textualize.io/)
-- Supports OpenAI, Anthropic, and OpenRouter APIs
+- Built with [PyMuPDF](https://pymupdf.readthedocs.io/) for high-performance PDF processing.
+- The beautiful CLI is powered by [Textual](https://textual.textualize.io/).
+- Supports APIs from [OpenAI](https://openai.com/), [Anthropic](https://www.anthropic.com/), and [OpenRouter](https://openrouter.ai/).
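Reviewer note on the environment-variable table in the README changes: the documented defaults (`DOCPIXIE_PROVIDER` = `openai`, `DOCPIXIE_STORAGE_PATH` = `./docpixie_data`, `DOCPIXIE_JPEG_QUALITY` = `90`) could be resolved roughly as sketched below. This is only an illustration of the table's semantics; `Settings`, `API_KEY_VARS`, and `load_settings` are hypothetical names and are not taken from the DocPixie source.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class Settings:
    """Hypothetical container for the values listed in the README table."""
    provider: str
    storage_path: str
    jpeg_quality: int
    api_key: Optional[str]


# Assumed mapping from provider name to the API-key variable in the table.
API_KEY_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}


def load_settings(env: Mapping[str, str] = os.environ) -> Settings:
    """Resolve settings from an environment mapping, using the table's defaults."""
    provider = env.get("DOCPIXIE_PROVIDER", "openai")
    if provider not in API_KEY_VARS:
        raise ValueError(f"Unknown provider: {provider!r}")

    quality = int(env.get("DOCPIXIE_JPEG_QUALITY", "90"))
    if not 1 <= quality <= 100:
        raise ValueError("DOCPIXIE_JPEG_QUALITY must be between 1 and 100")

    return Settings(
        provider=provider,
        storage_path=env.get("DOCPIXIE_STORAGE_PATH", "./docpixie_data"),
        jpeg_quality=quality,
        # The key is None when the matching variable is unset, as in the table.
        api_key=env.get(API_KEY_VARS[provider]),
    )
```

Passing a plain dictionary instead of `os.environ` keeps the sketch easy to exercise without touching the real process environment.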
 
 ---
diff --git a/docpixie/ai/agent.py b/docpixie/ai/agent.py
index 7076812..beafef2 100644
--- a/docpixie/ai/agent.py
+++ b/docpixie/ai/agent.py
@@ -1,6 +1,9 @@
 """
-DocPixie Adaptive RAG Agent
-Main orchestrator for vision-based document analysis with adaptive task planning
+The DocPixie Adaptive RAG Agent.
+
+This module contains the main orchestrator for the vision-based document
+analysis pipeline. The `PixieRAGAgent` coordinates various sub-components to
+process user queries, plan and execute tasks, and synthesize responses.
 """
 
 import time
@@ -31,13 +34,29 @@
 
 class PixieRAGAgent:
     """
-    Adaptive RAG agent with vision-based page selection and dynamic task planning
+    An adaptive RAG agent with vision-based page selection and dynamic task planning.
+
+    This agent is the core of the DocPixie system, orchestrating the entire
+    query processing pipeline. It uses a vision-first approach to analyze
+    documents, dynamically plans a series of tasks to answer a query, and can
+    adapt its plan based on new findings.
 
     Key features:
-    - Vision-first page selection (analyzes actual page images)
-    - Adaptive task planning (can modify plan based on findings)
-    - Single-mode operation (no Flash/Pro distinction)
-    - Conversation-aware query processing
+    - Vision-first page selection: Analyzes actual page images.
+    - Adaptive task planning: Can modify its plan based on findings.
+    - Single-mode operation: No distinction between "Flash" and "Pro" modes.
+    - Conversation-aware: Can process queries within a conversational context.
+
+    Attributes:
+        provider: The AI provider for making API calls.
+        storage: The storage backend for accessing documents.
+        config: The DocPixie configuration object.
+        context_processor: Component for processing conversation history.
+        query_reformulator: Component for reformulating queries.
+        query_classifier: Component for classifying queries.
+        task_planner: Component for creating and updating task plans.
+        page_selector: Component for selecting relevant pages.
+        synthesizer: Component for synthesizing the final response.
     """
 
     def __init__(
@@ -46,11 +65,18 @@ def __init__(
         storage: BaseStorage,
         config: DocPixieConfig
     ):
+        """
+        Initializes the PixieRAGAgent.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+            storage: An instance of a `BaseStorage` subclass.
+            config: The DocPixie configuration object.
+        """
         self.provider = provider
         self.storage = storage
         self.config = config
 
-        # Initialize components
         self.context_processor = ContextProcessor(provider, config)
         self.query_reformulator = QueryReformulator(provider)
         self.query_classifier = QueryClassifier(provider)
@@ -61,7 +87,7 @@ def __init__(
         logger.info("Initialized DocPixie RAG Agent")
 
     def _accumulate_cost(self, total_cost: float) -> float:
-        """Accumulate cost from last API call if available"""
+        """Accumulates the cost from the last API call, if available."""
         if hasattr(self.provider, 'get_last_cost'):
             last_cost = self.provider.get_last_cost()
             if last_cost is not None:
@@ -75,81 +101,63 @@ async def process_query(
         task_update_callback: Optional[Any] = None
     ) -> AgentQueryResult:
         """
-        Process a user query with adaptive task planning and execution
+        Processes a user query through the adaptive RAG pipeline.
 
         Args:
-            query: User's question
-            conversation_history: Previous conversation context
+            query: The user's question.
+            conversation_history: The history of the current conversation.
+            task_update_callback: An optional callback function to receive
+                updates on the agent's progress.
 
         Returns:
-            AgentQueryResult with comprehensive response and metadata
+            An `AgentQueryResult` containing the response and metadata.
         """
         start_time = time.time()
-        total_cost = 0.0  # Track total cost for this query
+        total_cost = 0.0
 
         try:
             logger.info(f"Processing query: {query[:100]}...")
 
-            # Step 1: Context Processing (conversation summarization if needed)
             processed_context = ""
-            display_messages = conversation_history or []
-
             if conversation_history:
-                processed_context, display_messages = await self.context_processor.process_conversation_context(
+                processed_context, _ = await self.context_processor.process_conversation_context(
                     conversation_history, query
                 )
                 total_cost = self._accumulate_cost(total_cost)
                 logger.info("Processed conversation context")
 
-            # Step 2: Query Reformulation (if conversation context exists)
-            reformulated_query = query
-            if conversation_history:
-                reformulated_query = await self.query_reformulator.reformulate_with_context(
-                    query, processed_context
-                )
+            reformulated_query = await self.query_reformulator.reformulate_with_context(
+                query, processed_context
+            ) if conversation_history else query
+            if query != reformulated_query:
                 logger.info(f"Reformulated query: '{query}' → '{reformulated_query}'")
 
-            # Step 3: Query Classification
             classification = await self.query_classifier.classify_query(reformulated_query)
             total_cost = self._accumulate_cost(total_cost)
             logger.info(f"Query classification: {classification['reasoning']}")
 
-            # If query doesn't need documents, return direct answer
             if not classification["needs_documents"]:
                 return self._create_direct_answer_result(query, classification["reasoning"], total_cost)
 
-            # Step 4: Get all available documents and pages
             documents = await self.storage.get_all_documents()
-
             if not documents:
                 logger.warning("No documents available for analysis")
                 return self._create_no_documents_result(query)
 
             logger.info(f"Found {len(documents)} documents")
 
-            # Step 5: Task Planning + Document Selection (merged)
             task_plan = await self.task_planner.create_initial_plan(reformulated_query, documents)
-
-            # Report initial task plan
             if task_update_callback:
                 await task_update_callback('plan_created', task_plan)
 
-            # Step 6: Execute tasks adaptively
             task_results = await self._execute_adaptive_plan(
                 task_plan, reformulated_query, documents, conversation_history, task_update_callback
             )
-
-            # Accumulate any costs from task execution
             total_cost = self._accumulate_cost(total_cost)
 
-            # Step 7: Synthesize final response
             final_answer = await self.synthesizer.synthesize_response(reformulated_query, task_results)
-
-            # Step 8: Build final result
             processing_time = time.time() - start_time
 
-            all_selected_pages = []
-            for result in task_results:
-                all_selected_pages.extend(result.selected_pages)
+            all_selected_pages = [page for result in task_results for page in result.selected_pages]
 
             result = AgentQueryResult(
                 query=query,
@@ -158,16 +166,14 @@ async def process_query(
                 task_results=task_results,
                 total_iterations=task_plan.current_iteration,
                 processing_time_seconds=processing_time,
-                total_cost=total_cost  # Always include cost, even if 0
+                total_cost=total_cost
             )
-
             logger.info(f"Query processed successfully in {processing_time:.2f}s")
             return result
 
         except Exception as e:
             logger.error(f"Failed to process query: {e}")
-            processing_time = time.time() - start_time
-            return self._create_error_result(query, str(e), processing_time)
+            return self._create_error_result(query, str(e), time.time() - start_time)
 
     async def _execute_adaptive_plan(
         self,
@@ -177,53 +183,38 @@ async def _execute_adaptive_plan(
         conversation_history: Optional[List[ConversationMessage]] = None,
         task_update_callback: Optional[Any] = None
     ) -> List[TaskResult]:
-        """Execute task plan with adaptive replanning"""
+        """Executes a task plan with adaptive replanning."""
         task_results = []
         iteration = 0
-
-        while (task_plan.has_pending_tasks() and
-               iteration < self.config.max_agent_iterations):
-
+        while (task_plan.has_pending_tasks() and iteration < self.config.max_agent_iterations):
             iteration += 1
             logger.info(f"Agent iteration {iteration}")
 
-            # Get next task to execute
             current_task = task_plan.get_next_pending_task()
             if not current_task:
                 break
 
             logger.info(f"Executing task: {current_task.name}")
             current_task.status = TaskStatus.IN_PROGRESS
-
-            # Report task starting
             if task_update_callback:
                 await task_update_callback('task_started', {'task': current_task, 'plan': task_plan})
 
-            # Execute the task
             task_result = await self._execute_single_task(
                 current_task, documents, original_query, conversation_history, task_update_callback
             )
 
-            # Mark task completed
             current_task.status = TaskStatus.COMPLETED
             task_results.append(task_result)
-
-            logger.info(f"Task completed: {current_task.name} "
-                        f"(analyzed {task_result.pages_analyzed} pages)")
-
-            # Report task completion
+            logger.info(f"Task completed: {current_task.name} (analyzed {task_result.pages_analyzed} pages)")
             if task_update_callback:
                 await task_update_callback('task_completed', {'task': current_task, 'result': task_result, 'plan': task_plan})
 
-            # Update plan adaptively if there are still pending tasks
             if task_plan.has_pending_tasks():
                 logger.info("Checking if task plan needs updating...")
                 old_task_count = len(task_plan.tasks)
                 task_plan = await self.task_planner.update_plan(
                     task_plan, task_result, original_query, documents
                 )
-
-                # Report plan update if it changed
                 if task_update_callback and len(task_plan.tasks) != old_task_count:
                     await task_update_callback('plan_updated', task_plan)
 
@@ -239,59 +230,31 @@ async def _execute_single_task(
         conversation_history: Optional[List[ConversationMessage]] = None,
         task_update_callback: Optional[Any] = None
     ) -> TaskResult:
-        """Execute a single task: document filtering + page selection + analysis"""
+        """Executes a single task, including page selection and analysis."""
         try:
-            # Step 1: Filter pages to only the task's assigned document
-            task_pages = []
-            if task.document:
-                # Find the document assigned to this task
-                task_doc = next((doc for doc in documents if doc.id == task.document), None)
-                if task_doc:
-                    task_pages = task_doc.pages
-                    logger.info(f"Task {task.name} assigned to document: {task_doc.name} ({len(task_pages)} pages)")
-                else:
-                    logger.warning(f"Task {task.name} assigned to document {task.document} but document not found")
+            task_doc = next((doc for doc in documents if doc.id == task.document), None)
+            if task_doc:
+                task_pages = task_doc.pages
+                logger.info(f"Task {task.name} assigned to document: {task_doc.name} ({len(task_pages)} pages)")
             else:
-                # No specific document assigned - use all pages (fallback)
-                task_pages = []
-                for doc in documents:
-                    task_pages.extend(doc.pages)
+                task_pages = [page for doc in documents for page in doc.pages]
                 logger.warning(f"Task {task.name} has no document assignment, using all pages")
 
-            # Step 2: Select relevant pages for this task
             selected_pages = await self.page_selector.select_pages_for_task(
-                query=task.name,
-                query_description=task.description,
-                task_pages=task_pages
+                query=task.name, query_description=task.description, task_pages=task_pages
             )
-
             logger.info(f"Selected {len(selected_pages)} pages for task: {task.name}")
-
-            # Report page selection
             if task_update_callback:
                 page_numbers = [p.page_number for p in selected_pages]
                 await task_update_callback('pages_selected', {'task': task, 'page_numbers': page_numbers})
 
-            # Step 3: Analyze selected pages to complete the task
             analysis = await self._analyze_pages_for_task(
                 task, selected_pages, original_query, conversation_history
             )
-
-            # Step 4: Create task result
-            return TaskResult(
-                task=task,
-                selected_pages=selected_pages,
-                analysis=analysis
-            )
-
+            return TaskResult(task=task, selected_pages=selected_pages, analysis=analysis)
         except Exception as e:
             logger.error(f"Failed to execute task {task.name}: {e}")
-            # Return result with error message
-            return TaskResult(
-                task=task,
-                selected_pages=[],
-                analysis=f"Task execution failed: {e}"
-            )
+            return TaskResult(task=task, selected_pages=[], analysis=f"Task execution failed: {e}")
 
     async def _analyze_pages_for_task(
         self,
@@ -300,120 +263,69 @@ async def _analyze_pages_for_task(
         original_query: str,
         conversation_history: Optional[List[ConversationMessage]] = None
     ) -> str:
-        """Analyze selected pages to complete a specific task"""
+        """Analyzes a set of pages to complete a specific task."""
         if not pages:
             return f"No relevant pages found for task: {task.name}"
-
         try:
-            # Build memory summary from conversation if available
             memory_summary = self._build_memory_summary(conversation_history)
-
-            # Create task processing prompt
             prompt = TASK_PROCESSING_PROMPT.format(
                 task_description=task.description,
-                search_queries=task.description,  # Use task description as query
+                search_queries=task.description,
                 memory_summary=memory_summary
             )
-
-            # Build multimodal message with selected page images
+            user_content = [{"type": "text", "text": prompt}]
+            for i, page in enumerate(pages, 1):
+                user_content.extend([
+                    {"type": "image_path", "image_path": page.image_path, "detail": "high"},
+                    {"type": "text", "text": f"[Page {i} from document]"}
+                ])
             messages = [
                 {"role": "system", "content": SYSTEM_DOCPIXIE},
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": prompt
-                        }
-                    ]
-                }
+                {"role": "user", "content": user_content}
             ]
-
-            # Add page images to message
-            for i, page in enumerate(pages, 1):
-                messages[1]["content"].extend([
-                    {
-                        "type": "image_path",
-                        "image_path": page.image_path,
-                        "detail": "high"  # Use high detail for task analysis
-                    },
-                    {
-                        "type": "text",
-                        "text": f"[Page {i} from document]"
-                    }
-                ])
-
-            # Process with vision model
             result = await self.provider.process_multimodal_messages(
-                messages=messages,
-                max_tokens=600,
-                temperature=0.3
+                messages=messages, max_tokens=600, temperature=0.3
             )
-
             return result.strip()
-
         except Exception as e:
             logger.error(f"Failed to analyze pages for task {task.name}: {e}")
             return f"Page analysis failed for task {task.name}: {e}"
 
-    def _build_memory_summary(
-        self,
-        conversation_history: Optional[List[ConversationMessage]]
-    ) -> str:
-        """Build conversation memory summary for context"""
-        if not conversation_history or len(conversation_history) == 0:
+    def _build_memory_summary(self, conversation_history: Optional[List[ConversationMessage]]) -> str:
+        """Builds a summary of the recent conversation history."""
+        if not conversation_history:
             return "CONVERSATION CONTEXT: This is the first query in the conversation."
-
-        # Get last few messages for context
-        recent_messages = conversation_history[-4:] if len(conversation_history) > 4 else conversation_history
-
+        recent_messages = conversation_history[-4:]
         context_parts = ["CONVERSATION CONTEXT:"]
         for msg in recent_messages:
             role = "User" if msg.role == "user" else "Assistant"
-            content = msg.content[:100] + "..." if len(msg.content) > 100 else msg.content
+            content = (msg.content[:100] + "...") if len(msg.content) > 100 else msg.content
             context_parts.append(f"- {role}: {content}")
-
         return "\n".join(context_parts)
 
     def _create_no_documents_result(self, query: str) -> AgentQueryResult:
-        """Create result when no documents are available"""
+        """Creates a result for when no documents are available."""
         return AgentQueryResult(
             query=query,
             answer="I don't have any documents to analyze. Please upload some documents first.",
-            selected_pages=[],
-            task_results=[],
-            total_iterations=0,
-            processing_time_seconds=0.0,
-            total_cost=0.0  # Always include cost, even if 0
+            selected_pages=[], task_results=[], total_cost=0.0
         )
 
-    def _create_error_result(
-        self,
-        query: str,
-        error_message: str,
-        processing_time: float
-    ) -> AgentQueryResult:
-        """Create result when processing fails"""
+    def _create_error_result(self, query: str, error_message: str, processing_time: float) -> AgentQueryResult:
+        """Creates a result for when an error occurs."""
         return AgentQueryResult(
             query=query,
             answer=f"I encountered an error while processing your query: {error_message}",
-            selected_pages=[],
-            task_results=[],
-            total_iterations=0,
-            processing_time_seconds=processing_time,
-            total_cost=0.0  # Always include cost, even if 0
+            selected_pages=[], task_results=[],
+            processing_time_seconds=processing_time, total_cost=0.0
         )
 
     def _create_direct_answer_result(self, query: str, reasoning: str, total_cost: float = 0.0) -> AgentQueryResult:
-        """Create result when query doesn't need document analysis"""
+        """Creates a result for queries that don't require document analysis."""
         return AgentQueryResult(
             query=query,
             answer=f"This query doesn't require document analysis. {reasoning}",
-            selected_pages=[],
-            task_results=[],
-            total_iterations=0,
-            processing_time_seconds=0.0,
-            total_cost=total_cost  # Always include cost, even if 0
+            selected_pages=[], task_results=[], total_cost=total_cost
         )
 
     async def process_conversation_query(
@@ -422,13 +334,12 @@ async def process_conversation_query(
         conversation_history: List[ConversationMessage]
     ) -> AgentQueryResult:
         """
-        Process a query in conversation context
-        This is a convenience method that handles conversation-aware processing
+        A convenience method to process a query within a conversation.
         """
         return await self.process_query(query, conversation_history)
 
     def get_agent_stats(self) -> Dict[str, Any]:
-        """Get agent configuration and statistics"""
+        """Returns statistics and configuration information about the agent."""
         return {
             "provider": self.provider.__class__.__name__,
             "storage": self.storage.__class__.__name__,
diff --git a/docpixie/ai/context_processor.py b/docpixie/ai/context_processor.py
index b95d533..e0ec581 100644
--- a/docpixie/ai/context_processor.py
+++ b/docpixie/ai/context_processor.py
@@ -1,5 +1,9 @@
 """
-Context Processor - Handles conversation history summarization and context building
+Context Processor for handling conversation history.
+
+This module provides the `ContextProcessor` class, which is responsible for
+managing and summarizing conversation history to create an optimized context
+for the RAG agent.
 """
 
 import logging
@@ -16,15 +20,31 @@
 
 class ContextProcessor:
     """
-    Processes conversation history to create optimized context for RAG
-
-    When conversation exceeds max_turns:
-    - Summarizes first turns_to_summarize turns
-    - Includes last turns_to_keep_full turns in full
-    - Creates condensed context for query reformulation
+    Processes conversation history to create an optimized context for the RAG agent.
+
+    When a conversation exceeds a certain number of turns, this class summarizes
+    the earlier parts of the conversation while keeping the most recent turns
+    in full. This helps to maintain context without exceeding the token limits
+    of the language models.
+
+    Attributes:
+        provider: The AI provider for making API calls.
+        max_turns_before_summary: The number of turns after which to start
+            summarizing the conversation.
+        turns_to_summarize: The number of initial turns to include in the
+            summary.
+        turns_to_keep_full: The number of recent turns to keep in their
+            entirety.
     """
 
     def __init__(self, provider: BaseProvider, config: DocPixieConfig):
+        """
+        Initializes the ContextProcessor.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+            config: The DocPixie configuration object.
+        """
         self.provider = provider
         self.max_turns_before_summary = config.max_conversation_turns
         self.turns_to_summarize = config.turns_to_summarize
@@ -36,58 +56,44 @@ async def process_conversation_context(
         current_query: str
     ) -> Tuple[str, List[ConversationMessage]]:
         """
-        Process conversation history and return optimized context
+        Processes the conversation history and returns an optimized context.
+
+        If the conversation is short, the full history is returned. If it is
+        long, a summary of the early conversation is generated and combined with
+        the recent messages.
 
         Args:
-            messages: List of conversation messages
-            current_query: The current user query
+            messages: A list of `ConversationMessage` objects.
+            current_query: The current user query.
 
         Returns:
-            Tuple of (processed_context_string, messages_for_display)
+            A tuple containing:
+            - A string with the processed context for the RAG agent.
+            - A list of `ConversationMessage` objects for display purposes.
 
         Raises:
-            ContextProcessingError: If context processing fails
+            ContextProcessingError: If the context processing fails.
         """
         try:
-            # Calculate number of turns (1 turn = 1 user message + 1 assistant message)
-            turns = self._count_turns(messages)
-
-            if turns <= self.max_turns_before_summary:
-                # No summarization needed
-                context = self._format_messages_as_context(messages)
-                return context, messages
+            if self._count_turns(messages) <= self.max_turns_before_summary:
+                return self._format_messages_as_context(messages), messages
 
-            logger.info(f"Conversation has {turns} turns, applying context summarization")
+            logger.info(f"Conversation has {self._count_turns(messages)} turns, applying context summarization")
 
-            # Split messages for summarization
             messages_to_summarize, messages_to_keep = self._split_messages_for_summary(messages)
-
-            # Summarize the first part
             summary = await self._summarize_conversation_chunk(messages_to_summarize)
 
-            # Build final context
-            context_parts = []
-
-            # Add summary
-            context_parts.append(f"Previous Conversation Summary:\n{summary}\n")
-
-            # Add recent messages in full
+            context_parts = [f"Previous Conversation Summary:\n{summary}\n"]
             if messages_to_keep:
-                context_parts.append("Recent Conversation:")
-                context_parts.append(self._format_messages_as_context(messages_to_keep))
-
-            # Add current query
+                context_parts.extend(["Recent Conversation:", self._format_messages_as_context(messages_to_keep)])
             context_parts.append(f"\nCurrent Query: {current_query}")
-
             final_context = "\n".join(context_parts)
 
-            # Create display messages (summary + recent)
             summary_message = ConversationMessage(
                 role="system",
                 content=f"[Conversation Summary of First {self.turns_to_summarize} Turns]\n{summary}"
             )
             display_messages = [summary_message] + messages_to_keep
-
             return final_context, display_messages
 
         except Exception as e:
@@ -95,69 +101,50 @@ async def process_conversation_context(
             raise ContextProcessingError(f"Failed to process conversation context: {e}")
 
     def _count_turns(self, messages: List[ConversationMessage]) -> int:
-        """Count conversation turns (user messages only)"""
-        user_messages = sum(1 for msg in messages if msg.role == "user")
-        return user_messages
+        """Counts the number of turns in a conversation (one user message per turn)."""
+        return sum(1 for msg in messages if msg.role == "user")
 
     def _split_messages_for_summary(
         self,
         messages: List[ConversationMessage]
     ) -> Tuple[List[ConversationMessage], List[ConversationMessage]]:
-        """Split messages into parts to summarize and keep"""
-        # Find the split point based on turns
+        """Splits messages into a part to be summarized and a part to be kept."""
         turn_count = 0
         split_index = 0
-
-        for i in range(0, len(messages), 2):  # Process in pairs
+        for i in range(0, len(messages), 2):
             if i + 1 < len(messages) and messages[i].role == "user":
                 turn_count += 1
                 if turn_count == self.turns_to_summarize:
-                    split_index = i + 2  # Include the assistant response
+                    split_index = i + 2
                     break
 
         messages_to_summarize = messages[:split_index]
         messages_to_keep = messages[split_index:]
 
-        # Ensure we keep at most the last N turns
         if self.turns_to_keep_full > 0:
-            max_messages_to_keep = self.turns_to_keep_full * 2  # Each turn has 2 messages
+            max_messages_to_keep = self.turns_to_keep_full * 2
             if len(messages_to_keep) > max_messages_to_keep:
                 messages_to_keep = messages_to_keep[-max_messages_to_keep:]
 
         return messages_to_summarize, messages_to_keep
 
     def _format_messages_as_context(self, messages: List[ConversationMessage]) -> str:
-        """Format messages as readable context"""
-        formatted_parts = []
-
-        for msg in messages:
-            role = "User" if msg.role == "user" else "Assistant"
-            formatted_parts.append(f"{role}: {msg.content}")
-
-        return "\n\n".join(formatted_parts)
+        """Formats a list of messages into a readable string."""
+        return "\n\n".join(f"{'User' if msg.role == 'user' else 'Assistant'}: {msg.content}" for msg in messages)
 
     async def _summarize_conversation_chunk(self, messages: List[ConversationMessage]) -> str:
-        """Summarize a chunk of conversation"""
+        """Summarizes a chunk of conversation using the language model."""
         try:
             conversation_text = self._format_messages_as_context(messages)
-
-            prompt = CONVERSATION_SUMMARIZATION_PROMPT.format(
-                conversation_text=conversation_text
-            )
-
+            prompt = CONVERSATION_SUMMARIZATION_PROMPT.format(conversation_text=conversation_text)
             messages_for_api = [
                 {"role": "system", "content": "You are a helpful assistant that creates concise conversation summaries."},
                 {"role": "user", "content": prompt}
             ]
-
             summary = await self.provider.process_text_messages(
-                messages=messages_for_api,
-                max_tokens=500,
-                temperature=0.3
+                messages=messages_for_api, max_tokens=500, temperature=0.3
             )
-
             return summary.strip()
-
         except Exception as e:
             logger.error(f"Conversation summarization failed: {e}")
             raise ContextProcessingError(f"Failed to summarize conversation: {e}")
diff --git a/docpixie/ai/page_selector.py b/docpixie/ai/page_selector.py
index 1973f5e..b5fbebe 100644
--- a/docpixie/ai/page_selector.py
+++ b/docpixie/ai/page_selector.py
@@ -1,6 +1,10 @@
 """
-Vision-based page selector for DocPixie RAG Agent
-Selects relevant pages by analyzing page images directly with vision models
+Vision-based page selector for the DocPixie RAG Agent.
+
+This module provides the `VisionPageSelector`, which is a core component of
+the DocPixie system. It uses a vision-capable language model to analyze the
+images of document pages directly, selecting the most relevant ones for a
+given task.
 """
 
 import json
@@ -19,11 +23,26 @@
 
 class VisionPageSelector:
     """
-    Selects relevant document pages using vision model analysis
-    Key feature: Analyzes actual page IMAGES, not text summaries
+    Selects relevant document pages by analyzing their images with a vision model.
+
+    This class implements a vision-first approach to page selection, which is a
+    key feature of the DocPixie agent. Instead of relying on text summaries, it
+    sends the actual page images to a multimodal model to determine their
+    relevance to a given task.
+
+    Attributes:
+        provider: The AI provider for making API calls.
+        config: The DocPixie configuration object.
     """
     def __init__(self, provider: BaseProvider, config: DocPixieConfig):
+        """
+        Initializes the VisionPageSelector.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+            config: The DocPixie configuration object.
+        """
         self.provider = provider
         self.config = config
@@ -34,41 +53,31 @@ async def select_pages_for_task(
         task_pages: List[Page]
     ) -> List[Page]:
         """
-        Select most relevant pages by analyzing page IMAGES with vision model
+        Selects the most relevant pages for a task by analyzing their images.
         Args:
-            query: The question/task to find pages for
-            task_pages: Pages from the task's assigned document
+            query: The question or task to find relevant pages for.
+            query_description: A more detailed description of the task.
+            task_pages: The list of pages to select from.
         Returns:
-            List of selected pages, ordered by relevance
+            A list of the selected `Page` objects, ordered by relevance.
         Raises:
-            PageSelectionError: If page selection fails
+            PageSelectionError: If the page selection process fails.
         """
         if not task_pages:
             logger.warning("No pages provided for selection")
             return []
-
         try:
             logger.info(f"Selecting most relevant pages from {len(task_pages)} task pages")
-
-            # Build vision-based selection message
             messages = self._build_vision_selection_messages(query, query_description, task_pages)
-
-            # Use vision model to analyze page images and select best ones
             result = await self.provider.process_multimodal_messages(
-                messages=messages,
-                max_tokens=200,
-                temperature=0.1  # Low temperature for consistent selection
+                messages=messages, max_tokens=200, temperature=0.1
             )
-
-            # Parse selection result
             selected_pages = self._parse_page_selection(result, task_pages)
-
             logger.info(f"Successfully selected {len(selected_pages)} pages")
             return selected_pages
-
         except Exception as e:
             logger.error(f"Vision page selection failed: {e}")
             raise PageSelectionError(f"Failed to select pages for task: {e}")
@@ -80,75 +89,55 @@ def _build_vision_selection_messages(
         all_pages: List[Page]
     ) -> List[Dict[str, Any]]:
         """
-        Build multimodal message with all page images for vision analysis
-        This is the key method that makes our system vision-first
+        Builds the multimodal message payload for the vision model.
+
+        This method constructs a list of messages that includes the system
+        prompt, all the page images, and the user's query, formatted for a
+        multimodal API call.
         """
-        messages = [
-            {
-                "role": "system",
-                "content": SYSTEM_PAGE_SELECTOR
-            }
-        ]
         user_content = []
-        # Add ALL page images to the message for vision analysis
         for i, page in enumerate(all_pages, 1):
             user_content.extend([
-                {
-                    "type": "image_path",
-                    "image_path": page.image_path,
-                    "detail": self.config.vision_detail
-                },
-                {
-                    "type": "text",
-                    "text": f"[Page {i}]"
-                }
+                {"type": "image_path", "image_path": page.image_path, "detail": self.config.vision_detail},
+                {"type": "text", "text": f"[Page {i}]"}
             ])
+        user_content.append({
+            "type": "text",
+            "text": VISION_PAGE_SELECTION_PROMPT.format(query=query, query_description=query_description)
+        })
+        return [
+            {"role": "system", "content": SYSTEM_PAGE_SELECTOR},
+            {"role": "user", "content": user_content}
+        ]
-        user_content.append(
-            {
-                "type": "text",
-                "text": VISION_PAGE_SELECTION_PROMPT.format(query=query, query_description=query_description)
-            }
-        )
+    def _parse_page_selection(self, result: str, all_pages: List[Page]) -> List[Page]:
+        """
+        Parses the page selection response from the vision model.
-        messages.append(
-            {
-                "role": "user",
-                "content": user_content
-            }
-        )
+        Args:
+            result: The JSON string response from the model.
+            all_pages: The original list of pages that were analyzed.
-        return messages
+        Returns:
+            A list of the selected `Page` objects.
-    def _parse_page_selection(
-        self,
-        result: str,
-        all_pages: List[Page]
-    ) -> List[Page]:
-        """
-        Parse the vision model's page selection response
+        Raises:
+            PageSelectionError: If the response is not valid JSON or if no
+                valid pages are selected.
         """
         try:
-            # Parse JSON response
             selection_data = json.loads(sanitize_llm_json(result))
             selected_indices = selection_data.get("selected_pages", [])
-
-            selected_pages = []
-            for idx in selected_indices:
-                if isinstance(idx, int) and 1 <= idx <= len(all_pages):
-                    page = all_pages[idx - 1]
-                    selected_pages.append(page)
-                    logger.debug(f"Selected page {idx}: {page.image_path}")
-
-            # If no valid pages were selected, return empty list and raise error
+            selected_pages = [
+                all_pages[idx - 1]
+                for idx in selected_indices
+                if isinstance(idx, int) and 1 <= idx <= len(all_pages)
+            ]
             if not selected_pages:
                 logger.error("No valid pages selected by vision model")
                 raise PageSelectionError("Vision model failed to select any valid pages")
-
             return selected_pages
-
         except (json.JSONDecodeError, KeyError, TypeError) as e:
             logger.error(f"Failed to parse page selection JSON: {e}")
             logger.debug(f"Raw vision model response: {result}")
-
             raise PageSelectionError(f"Failed to parse vision model page selection response: {e}, raw response: \n{result}")
diff --git a/docpixie/ai/query_classifier.py b/docpixie/ai/query_classifier.py
index 4f7a55f..10a52ed 100644
--- a/docpixie/ai/query_classifier.py
+++ b/docpixie/ai/query_classifier.py
@@ -1,5 +1,9 @@
 """
-Query Classifier - Determines if queries need document retrieval
+Query Classifier for determining if a query requires document retrieval.
+
+This module provides the `QueryClassifier` class, which analyzes a user's
+query to decide whether it can be answered directly or if it requires
+access to the document store.
 """
 import json
@@ -15,64 +19,62 @@
 class QueryClassifier:
     """
-    Classifies queries to determine processing strategy
+    Classifies user queries to determine the appropriate processing strategy.
+
+    The primary function of this class is to determine if a query necessitates
+    document retrieval and analysis. This allows the RAG agent to bypass the
+    retrieval and generation steps for simple, conversational queries.
-    Key classification:
-    - needs_documents: Whether query requires document retrieval
+    Attributes:
+        provider: The AI provider for making API calls.
     """
     def __init__(self, provider: BaseProvider):
+        """
+        Initializes the QueryClassifier.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+        """
         self.provider = provider
     async def classify_query(self, query: str) -> dict:
         """
-        Classify a query to determine processing approach
+        Classifies a query to determine if it requires document retrieval.
         Args:
-            query: The user's query (potentially reformulated)
+            query: The user's query, which may have been reformulated.
         Returns:
-            Dict with classification results:
-            {
-                "reasoning": "explanation",
-                "needs_documents": bool
-            }
+            A dictionary containing the classification results, with the keys
+            "reasoning" (a string explaining the decision) and "needs_documents"
+            (a boolean).
         Raises:
-            QueryClassificationError: If classification fails
+            QueryClassificationError: If the classification process fails or the
+                response from the language model is invalid.
         """
-        result = None
-
         try:
-            # Build classification prompt
             prompt = QUERY_CLASSIFICATION_PROMPT.format(query=query)
-
             messages_for_api = [
                 {"role": "system", "content": SYSTEM_QUERY_CLASSIFIER},
                 {"role": "user", "content": prompt}
             ]
-
             response = await self.provider.process_text_messages(
-                messages=messages_for_api,
-                max_tokens=1024,
-                temperature=0.1
+                messages=messages_for_api, max_tokens=1024, temperature=0.1
             )
-            # Parse JSON response
             try:
                 result = json.loads(sanitize_llm_json(response))
-
-                # Validate required fields
                 if "reasoning" not in result or "needs_documents" not in result:
                     raise QueryClassificationError(
                         f"Missing required fields in classification response: {result}"
                     )
-
-                logger.info(f"Query classified: needs_documents={result['needs_documents']}, "
-                            f"reasoning='{result['reasoning']}'")
-
+                logger.info(
+                    f"Query classified: needs_documents={result['needs_documents']}, "
+                    f"reasoning='{result['reasoning']}'"
+                )
                 return result
-
             except json.JSONDecodeError as e:
                 logger.error(f"Failed to parse classification JSON: {response}")
                 raise QueryClassificationError(f"Invalid JSON response from classification: {e}")
diff --git a/docpixie/ai/query_reformulator.py b/docpixie/ai/query_reformulator.py
index 4c80cdd..81ccd45 100644
--- a/docpixie/ai/query_reformulator.py
+++ b/docpixie/ai/query_reformulator.py
@@ -1,5 +1,9 @@
 """
-Query Reformulator - Creates optimized search queries from conversation context
+Query Reformulator for creating optimized search queries.
+
+This module provides the `QueryReformulator` class, which takes a user's
+query and the conversation context to create a new, standalone query that is
+optimized for document retrieval.
 """
 import json
@@ -15,16 +19,30 @@
 class QueryReformulator:
     """
-    Reformulates queries by resolving references for better search
+    Reformulates user queries to be more effective for document retrieval.
+
+    This class resolves ambiguities and references in a user's query by
+    incorporating the conversation context. For example, it can resolve
+    pronouns like "it" or "that" to their specific referents from earlier in
+    the conversation.
+
+    The key objectives of the reformulation are:
+    - Resolve pronouns and other references.
+    - Keep the query concise and focused on the current user intent.
+    - Avoid combining multiple questions into one.
+    - Maintain a query length that is optimal for search.
-    Focuses on:
-    - Resolving pronouns and references (e.g., "it", "this", "that")
-    - Keeping queries concise and focused on current intent
-    - NOT combining multiple questions or intents
-    - Maintaining optimal length for search
+    Attributes:
+        provider: The AI provider for making API calls.
     """
     def __init__(self, provider: BaseProvider):
+        """
+        Initializes the QueryReformulator.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+        """
         self.provider = provider
     async def reformulate_with_context(
@@ -33,46 +51,39 @@ async def reformulate_with_context(
         conversation_context: str
     ) -> str:
         """
-        Reformulate query by resolving references while keeping it concise
+        Reformulates a query using the conversation context.
         Args:
-            current_query: The current user query
-            conversation_context: Processed context from ContextProcessor
+            current_query: The user's most recent query.
+            conversation_context: A string representing the processed context
+                from the `ContextProcessor`.
         Returns:
-            Reformulated query with resolved references
+            The reformulated, standalone query.
         Raises:
-            QueryReformulationError: If reformulation fails
+            QueryReformulationError: If the reformulation process fails or the
+                response from the language model is invalid.
         """
         try:
-            # Build prompt using existing template
             prompt = QUERY_REFORMULATION_PROMPT.format(
                 conversation_context=conversation_context,
-                recent_topics="",  # Let AI extract topics from context
+                recent_topics="",  # Let the AI extract topics from the context
                 current_query=current_query
             )
-
             messages_for_api = [
                 {"role": "system", "content": SYSTEM_QUERY_REFORMULATOR},
                 {"role": "user", "content": prompt}
             ]
-
             response = await self.provider.process_text_messages(
-                messages=messages_for_api,
-                max_tokens=8192,
-                temperature=0.2
+                messages=messages_for_api, max_tokens=8192, temperature=0.2
             )
-            # Parse JSON response
-            result = None
             try:
                 result = json.loads(sanitize_llm_json(response))
                 reformulated = result.get("reformulated_query", current_query)
-
                 logger.info(f"Query reformulation: '{current_query}' → '{reformulated}'")
                 return reformulated
-
             except json.JSONDecodeError as e:
                 logger.error(f"Failed to parse reformulation JSON: {response}")
                 raise QueryReformulationError(f"Invalid JSON response from reformulation: {e}")
diff --git a/docpixie/ai/synthesizer.py b/docpixie/ai/synthesizer.py
index ef41f62..7e40876 100644
--- a/docpixie/ai/synthesizer.py
+++ b/docpixie/ai/synthesizer.py
@@ -1,6 +1,9 @@
 """
-Response synthesizer for DocPixie RAG Agent
-Combines multiple task results into coherent final answers
+Response synthesizer for the DocPixie RAG Agent.
+
+This module provides the `ResponseSynthesizer` class, which is responsible
+for combining the results of multiple agent tasks into a single, coherent,
+and comprehensive final answer.
 """
 import logging
@@ -15,11 +18,23 @@
 class ResponseSynthesizer:
     """
-    Synthesizes multiple task results into a comprehensive final response
-    Key feature: Combines findings from different tasks into coherent narrative
+    Synthesizes multiple task results into a comprehensive final response.
+
+    This class takes the findings from various agent tasks, each of which may
+    have focused on a different aspect of the user's query, and weaves them
+    together into a single, well-structured answer.
+
+    Attributes:
+        provider: The AI provider for making API calls.
     """
     def __init__(self, provider: BaseProvider):
+        """
+        Initializes the ResponseSynthesizer.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+        """
         self.provider = provider
     async def synthesize_response(
@@ -28,14 +43,15 @@ async def synthesize_response(
         task_results: List[TaskResult]
     ) -> str:
         """
-        Synthesize multiple task results into a final comprehensive response
+        Synthesizes a final response from a list of task results.
         Args:
-            original_query: The user's original question
-            task_results: List of completed task results to combine
+            original_query: The user's original, unmodified query.
+            task_results: A list of `TaskResult` objects from the agent's
+                executed plan.
         Returns:
-            Synthesized response that addresses the original query
+            A string containing the synthesized final answer.
         """
         if not task_results:
             logger.warning("No task results provided for synthesis")
@@ -43,123 +59,72 @@ async def synthesize_response(
         try:
             logger.info(f"Synthesizing response from {len(task_results)} task results")
-
-            # Build results text from all task findings
             results_text = self._build_results_text(task_results)
-
-            # Generate synthesis prompt
-            prompt = SYNTHESIS_PROMPT.format(
-                original_query=original_query,
-                results_text=results_text
-            )
-
+            prompt = SYNTHESIS_PROMPT.format(original_query=original_query, results_text=results_text)
             messages = [
                 {"role": "system", "content": SYSTEM_SYNTHESIS},
                 {"role": "user", "content": prompt}
             ]
-
-            # Get synthesized response
             result = await self.provider.process_text_messages(
-                messages=messages,
-                max_tokens=2048,  # Longer response for synthesis
-                temperature=0.2  # Low temperature for consistent synthesis
+                messages=messages, max_tokens=2048, temperature=0.2
             )
-
             logger.info("Successfully synthesized final response")
             return result.strip()
-
         except Exception as e:
             logger.error(f"Failed to synthesize response: {e}")
-            # Fallback: return basic combination of results
             return self._create_fallback_response(original_query, task_results)
     def _build_results_text(self, task_results: List[TaskResult]) -> str:
-        """Build formatted text from all task results"""
-        results_sections = []
-
-        for i, result in enumerate(task_results, 1):
-            section = f"""TASK {i}: {result.task.name}
-Description: {result.task.description}
-Analysis: {result.analysis}
-
----"""
-            results_sections.append(section)
-
-        return "\n".join(results_sections)
-
-    def _create_fallback_response(
-        self,
-        original_query: str,
-        task_results: List[TaskResult]
-    ) -> str:
-        """Create a simple fallback response if synthesis fails"""
+        """Formats the results of all tasks into a single string."""
+        return "\n".join([
+            f"TASK {i}: {result.task.name}\nDescription: {result.task.description}\nAnalysis: {result.analysis}\n\n---"
+            for i, result in enumerate(task_results, 1)
+        ])
+
+    def _create_fallback_response(self, original_query: str, task_results: List[TaskResult]) -> str:
+        """Creates a simple, direct response if the synthesis process fails."""
         logger.warning("Using fallback response synthesis")
-
         response_parts = [
             f"Based on my analysis of the documents, here's what I found regarding your query: {original_query}\n"
         ]
-
-        for i, result in enumerate(task_results, 1):
-            response_parts.append(f"**{result.task.name}:**")
-            response_parts.append(result.analysis)
-
-            if i < len(task_results):
-                response_parts.append("")  # Add blank line between results
-
+        for result in task_results:
+            response_parts.extend([f"**{result.task.name}:**", result.analysis, ""])
         return "\n".join(response_parts)
-    async def synthesize_single_result(
-        self,
-        original_query: str,
-        task_result: TaskResult
-    ) -> str:
+    async def synthesize_single_result(self, original_query: str, task_result: TaskResult) -> str:
         """
-        Handle synthesis for single task result (simpler case)
+        Handles the synthesis for a single task result.
         Args:
-            original_query: The user's original question
-            task_result: Single task result to present
+            original_query: The user's original query.
+            task_result: The single `TaskResult` to present.
         Returns:
-            Formatted response for single task
+            A formatted response for the single task.
         """
         try:
-            # For single results, we can often just clean up the analysis
-            # But still use synthesis prompt for consistency
             return await self.synthesize_response(original_query, [task_result])
-
         except Exception as e:
             logger.error(f"Failed to synthesize single result: {e}")
-
-            # Simple fallback for single result
-            response = f"Based on my analysis, here's what I found regarding your query:\n\n"
-            response += f"**{task_result.task.name}**\n{task_result.analysis}"
-
-            return response
+            return (
+                "Based on my analysis, here's what I found regarding your query:\n\n"
+                f"**{task_result.task.name}**\n{task_result.analysis}"
+            )
     def validate_synthesis_quality(self, synthesized_response: str) -> bool:
         """
-        Basic validation of synthesis quality
+        Performs a basic validation of the synthesis quality.
         Args:
-            synthesized_response: The synthesized response to validate
+            synthesized_response: The synthesized response to validate.
         Returns:
-            True if response meets basic quality criteria
+            `True` if the response meets basic quality criteria, `False` otherwise.
         """
-        if not synthesized_response or not synthesized_response.strip():
+        if not synthesized_response or not synthesized_response.strip() or len(synthesized_response.strip()) < 50:
             return False
-
-        # Check minimum length (synthesis should be substantial)
-        if len(synthesized_response.strip()) < 50:
-            return False
-
-        # Check it doesn't just repeat the prompt
         if "SYNTHESIS_PROMPT" in synthesized_response:
             return False
-
-        # Check for basic structure indicators
         if "I couldn't find" in synthesized_response and len(synthesized_response) < 100:
             return False
-
         return True
diff --git a/docpixie/ai/task_planner.py b/docpixie/ai/task_planner.py
index 4b704d9..e7b02c7 100644
--- a/docpixie/ai/task_planner.py
+++ b/docpixie/ai/task_planner.py
@@ -1,6 +1,9 @@
 """
-Adaptive task planner for DocPixie RAG Agent
-Creates and dynamically updates task plans based on agent findings
+Adaptive task planner for the DocPixie RAG Agent.
+
+This module provides the `TaskPlanner` class, which is responsible for creating
+an initial plan of action to answer a user's query and for dynamically
+updating that plan based on the findings of the agent.
 """
 import json
@@ -24,11 +27,24 @@
 class TaskPlanner:
     """
-    Adaptive task planner that can create and modify task plans based on findings
-    Key feature: Agent can add/remove/modify tasks based on what it learns
+    An adaptive task planner that creates and modifies task plans for the agent.
+
+    This class is a key component of the agent's adaptive behavior. It can
+    create an initial set of tasks to address a user's query and, more
+    importantly, it can revise the plan by adding, removing, or modifying tasks
+    based on the information gathered during the execution of previous tasks.
+
+    Attributes:
+        provider: The AI provider for making API calls.
     """
     def __init__(self, provider: BaseProvider):
+        """
+        Initializes the TaskPlanner.
+
+        Args:
+            provider: An instance of a `BaseProvider` subclass.
+        """
         self.provider = provider
     async def create_initial_plan(
@@ -37,58 +53,34 @@ async def create_initial_plan(
         documents: Optional[List[Document]] = None
     ) -> TaskPlan:
         """
-        Create initial task plan from user query with document selection
+        Creates an initial task plan based on the user's query and available documents.
         Args:
-            query: User's question/request
-            documents: Available documents (required for document selection)
+            query: The user's question or request.
+            documents: A list of available documents for analysis.
         Returns:
-            TaskPlan with 2-4 initial tasks, each with assigned documents
+            A `TaskPlan` object containing an initial set of tasks.
         Raises:
-            TaskPlanningError: If task planning fails
+            TaskPlanningError: If the task planning process fails.
         """
         try:
             logger.info(f"Creating initial task plan for query: {query[:50]}...")
-
-            # Build context about available documents with full summaries
-            documents_text = ""
-            if documents:
-                doc_list = []
-                for doc in documents:
-                    summary = doc.summary or f"Document with {len(doc.pages)} pages"
-                    doc_list.append(f"{doc.id}: {doc.name}\nSummary: {summary}")
-                documents_text = "\n\n".join(doc_list)
-            else:
-                documents_text = "No documents available"
-
-            # Generate initial plan
-            prompt = ADAPTIVE_INITIAL_PLANNING_PROMPT.format(
-                query=query,
-                documents=documents_text
-            )
-
+            documents_text = self._format_documents_for_prompt(documents)
+            prompt = ADAPTIVE_INITIAL_PLANNING_PROMPT.format(query=query, documents=documents_text)
             messages = [
                 {"role": "system", "content": SYSTEM_ADAPTIVE_PLANNER},
                 {"role": "user", "content": prompt}
             ]
-
             result = await self.provider.process_text_messages(
-                messages=messages,
-                max_tokens=8192,
-                temperature=0.3
+                messages=messages, max_tokens=8192, temperature=0.3
             )
-
-            # Parse and create task plan
             task_plan = self._parse_initial_plan(result, query, documents)
-
             logger.info(f"Created initial plan with {len(task_plan.tasks)} tasks")
             for task in task_plan.tasks:
                 logger.debug(f"Task: {task.name} - Document: {task.document}")
-
             return task_plan
-
         except Exception as e:
             logger.error(f"Failed to create initial plan: {e}")
             raise TaskPlanningError(f"Failed to create initial task plan: {e}")
@@ -101,40 +93,23 @@ async def update_plan(
         documents: Optional[List[Document]] = None
     ) -> TaskPlan:
         """
-        Adaptively update task plan based on latest findings
-        This is the key adaptive feature - agent can modify its own plan
+        Adaptively updates the task plan based on the latest findings.
         Args:
-            current_plan: Current task plan
-            latest_result: Result from the task just completed
-            original_query: Original user query for context
-            documents: Available documents (for new task assignments)
+            current_plan: The current `TaskPlan`.
+            latest_result: The result from the most recently completed task.
+            original_query: The original user query for context.
+            documents: The list of available documents.
         Returns:
-            Updated task plan (may have added/removed/modified tasks)
+            An updated `TaskPlan` object.
         """
         result = None
         try:
             logger.info(f"Updating task plan after completing: {latest_result.task.name}")
-
-            # Build current plan status
             plan_status = self._build_plan_status(current_plan)
-
-            # Build progress summary from completed tasks
             progress_summary = self._build_progress_summary(current_plan, latest_result)
-
-            # Build available documents text with full summaries
-            available_documents = ""
-            if documents:
-                doc_list = []
-                for doc in documents:
-                    summary = doc.summary or f"Document with {len(doc.pages)} pages"
-                    doc_list.append(f"{doc.id}: {doc.name}\nSummary: {summary}")
-                available_documents = "\n\n".join(doc_list)
-            else:
-                available_documents = "No documents available"
-
-            # Ask agent to evaluate and update plan
+            available_documents = self._format_documents_for_prompt(documents)
             prompt = ADAPTIVE_PLAN_UPDATE_PROMPT.format(
                 original_query=original_query,
                 available_documents=available_documents,
@@ -143,144 +118,95 @@ async def update_plan(
                 task_findings=latest_result.analysis,
                 progress_summary=progress_summary
             )
-
             messages = [
                 {"role": "system", "content": SYSTEM_ADAPTIVE_PLANNER},
                 {"role": "user", "content": prompt}
             ]
-
             result = await self.provider.process_text_messages(
-                messages=messages,
-                max_tokens=8192,
-                temperature=0.3
+                messages=messages, max_tokens=8192, temperature=0.3
             )
-
-            # Apply plan updates
-            updated_plan = self._apply_plan_updates(current_plan, result, latest_result)
-
+            updated_plan = self._apply_plan_updates(current_plan, result)
             logger.info(f"Plan updated - now has {len(updated_plan.tasks)} tasks")
             return updated_plan
-
         except Exception as e:
             logger.error(f"Failed to update plan: {e}")
             raise TaskPlanningError(f"Failed to update task plan: {e}. \nRaw response: \n{result}")
     def _parse_initial_plan(self, result: str, query: str, documents: Optional[List[Document]] = None) -> TaskPlan:
-        """Parse initial planning response and create TaskPlan with document assignments"""
+        """Parses the initial planning response and creates a TaskPlan."""
         try:
             plan_data = json.loads(sanitize_llm_json(result))
-            tasks = []
-
-            # Create map of available document IDs for validation
-            valid_doc_ids = set()
-            if documents:
-                valid_doc_ids = {doc.id for doc in documents}
-
-            for task_data in plan_data.get("tasks", []):
-                # Parse and validate single document assignment
-                assigned_doc = task_data.get("document", "")
-                valid_assigned_doc = assigned_doc if assigned_doc in valid_doc_ids else ""
-
-                task = AgentTask(
+            valid_doc_ids = {doc.id for doc in documents} if documents else set()
+            tasks = [
+                AgentTask(
                     id=str(uuid.uuid4()),
                     name=task_data.get("name", "Unnamed Task"),
                     description=task_data.get("description", ""),
-                    document=valid_assigned_doc,
+                    document=task_data.get("document", "") if task_data.get("document", "") in valid_doc_ids else "",
                     status=TaskStatus.PENDING
                 )
-                tasks.append(task)
-
-            # Limit to reasonable number of initial tasks
+                for task_data in plan_data.get("tasks", [])
+            ]
             if len(tasks) > 4:
                 tasks = tasks[:4]
                 logger.debug("Limited initial tasks to 4")
-
-            return TaskPlan(
-                initial_query=query,
-                tasks=tasks,
-                current_iteration=0
-            )
-
+            return TaskPlan(initial_query=query, tasks=tasks)
         except (json.JSONDecodeError, KeyError) as e:
             logger.error(f"Failed to parse initial plan: {e}")
             raise TaskPlanningError(f"Failed to parse task plan JSON: {e}, Raw response: {result}")
-    def _apply_plan_updates(
-        self,
-        current_plan: TaskPlan,
-        update_result: str,
-        latest_result: TaskResult
-    ) -> TaskPlan:
-        """Apply updates to the current plan based on agent's decision"""
+    def _apply_plan_updates(self, current_plan: TaskPlan, update_result: str) -> TaskPlan:
+        """Applies updates to the current plan based on the agent's decision."""
         try:
             update_data = json.loads(sanitize_llm_json(update_result))
             action = update_data.get("action", "continue")
             reason = update_data.get("reason", "No reason provided")
-
             logger.debug(f"Plan update action: {action} - {reason}")
-            if action == "continue":
-                # No changes needed
-                logger.info("Continuing with current plan unchanged")
-
-            elif action == "add_tasks":
-                # Add new tasks
-                new_tasks_data = update_data.get("new_tasks", [])
-                for task_data in new_tasks_data:
-                    assigned_doc = task_data.get("document", "")
-                    new_task = AgentTask(
+            if action == "add_tasks":
+                for task_data in update_data.get("new_tasks", []):
+                    current_plan.add_task(AgentTask(
                         name=task_data.get("name", "New Task"),
                         description=task_data.get("description", ""),
-                        document=assigned_doc,
+                        document=task_data.get("document", ""),
                         status=TaskStatus.PENDING
-                    )
-                    current_plan.add_task(new_task)
-                    logger.info(f"Added new task: {new_task.name} - Document: {assigned_doc}")
-
+                    ))
+                    logger.info(f"Added new task: {task_data.get('name')} - Document: {task_data.get('document')}")
             elif action == "remove_tasks":
-                # Remove specified tasks
-                task_ids_to_remove = update_data.get("tasks_to_remove", [])
-                for task_id in task_ids_to_remove:
+                for task_id in update_data.get("tasks_to_remove", []):
                     if current_plan.remove_task(task_id):
                         logger.info(f"Removed task: {task_id}")
-
             elif action == "modify_tasks":
-                # Modify existing tasks
-                modifications = update_data.get("modified_tasks", [])
-                for modification in modifications:
-                    task_id = modification.get("task_id")
-                    task = next((t for t in current_plan.tasks if t.id == task_id), None)
+                for modification in update_data.get("modified_tasks", []):
+                    task = next((t for t in current_plan.tasks if t.id == modification.get("task_id")), None)
                     if task and task.status == TaskStatus.PENDING:
-                        old_name = task.name
-                        old_doc = task.document
                         task.name = modification.get("new_name", task.name)
                         task.description = modification.get("new_description", task.description)
                        task.document = modification.get("new_document", task.document)
-                        logger.info(f"Modified task '{old_name}' -> '{task.name}' (Document: {old_doc} -> {task.document})")
+                        logger.info(f"Modified task '{task.name}'")
            current_plan.current_iteration += 1
            return current_plan
-
        except (json.JSONDecodeError, KeyError) as e:
            logger.error(f"Failed to parse plan updates: {e}")
            raise TaskPlanningError(f"Failed to parse plan update JSON: {e}")

+    def _format_documents_for_prompt(self, documents: Optional[List[Document]]) -> str:
+        """Formats a list of documents into a string for the prompt."""
+        if not documents:
+            return "No documents available"
+        return "\n\n".join([
+            f"{doc.id}: {doc.name}\nSummary: {doc.summary or f'Document with {len(doc.pages)} pages'}"
+            for doc in documents
+        ])
+
     def _build_plan_status(self, plan: TaskPlan) -> str:
-        """Build text summary of current plan status"""
-        status_lines = []
-        for task in plan.tasks:
-            status_lines.append(f"- {task.name}: {task.status.value}")
-        return "\n".join(status_lines)
+        """Builds a text summary of the current plan status."""
+        return "\n".join([f"- {task.name}: {task.status.value}" for task in plan.tasks])

     def _build_progress_summary(self, plan: TaskPlan, latest_result: TaskResult) -> str:
-        """Build summary of progress so far"""
+        """Builds a summary of the progress so far."""
         completed_tasks = plan.get_completed_tasks()
-
         if not completed_tasks:
             return f"Just completed first task: {latest_result.task.name}"
-
-        summary_parts = []
-        for task in completed_tasks:
-            summary_parts.append(f"✓ {task.name}")
-
-        return "Completed tasks:\n" + "\n".join(summary_parts)
+        return "Completed tasks:\n" + "\n".join([f"✓ {task.name}" for task in completed_tasks])
diff --git a/docpixie/cli/app.py b/docpixie/cli/app.py
index b91d3a4..645b657 100644
--- a/docpixie/cli/app.py
+++ b/docpixie/cli/app.py
@@ -1,6 +1,11 @@
 #!/usr/bin/env python3
 """
-DocPixie Textual CLI - Modern terminal interface for document chat
+The main application file for the DocPixie Textual CLI.
+
+This module defines the main user interface for the DocPixie command-line
+application, built using the Textual framework. It includes the main app class,
+a setup screen for initial configuration, and custom widgets for the chat
+interface.
 """

 import asyncio
@@ -31,7 +36,9 @@


 class ChatInput(TextArea):
-    """Custom TextArea for chat input with Enter to submit"""
+    """
+    A custom `TextArea` for chat input that submits on Enter.
+    """

     BINDINGS = [
         Binding("ctrl+d", "show_documents", "Documents", priority=True),
@@ -43,36 +50,29 @@ class ChatInput(TextArea):
     ]

     def action_submit_message(self) -> None:
-        """Submit on Enter"""
-        app = self.app
-        if hasattr(app, 'submit_chat_message'):
-            asyncio.create_task(app.submit_chat_message())
+        """Submits the chat message when the Enter key is pressed."""
+        if hasattr(self.app, 'submit_chat_message'):
+            asyncio.create_task(self.app.submit_chat_message())

     def action_add_newline(self) -> None:
-        """Add a newline on Shift+Enter"""
+        """Inserts a newline character on Shift+Enter."""
         self.insert("\n")

     def action_show_documents(self) -> None:
-        """Forward Ctrl+D to app's document manager action"""
-        app = self.app
-        try:
-            if hasattr(app, 'action_show_documents'):
-                app.action_show_documents()
-        except Exception:
-            pass
+        """Forwards the 'show_documents' action to the main app."""
+        if hasattr(self.app, 'action_show_documents'):
+            self.app.action_show_documents()

     def action_toggle_palette(self) -> None:
-        """Forward Ctrl+/ to app's command palette toggle"""
-        app = self.app
-        try:
-            if hasattr(app, 'action_toggle_palette'):
-                app.action_toggle_palette()
-        except Exception:
-            pass
+        """Forwards the 'toggle_palette' action to the main app."""
+        if hasattr(self.app, 'action_toggle_palette'):
+            self.app.action_toggle_palette()


 class SetupScreen(Screen):
-    """First-time setup screen for API key configuration"""
+    """
+    A screen for first-time setup, specifically for API key configuration.
+    """

     CSS = SETUP_SCREEN_CSS
     BINDINGS = [
@@ -80,34 +80,20 @@ class SetupScreen(Screen):
     ]

     def compose(self) -> ComposeResult:
+        """Composes the UI for the setup screen."""
         with Container(id="setup-container"):
             yield Static("[bold]Welcome to DocPixie![/bold]", classes="title")
-            yield Static(
-                "DocPixie needs an OpenRouter API key to work with documents.",
-                classes="setup-text",
-            )
-            yield Static(
-                "Get your API key from: https://openrouter.ai/keys",
-                classes="setup-text",
-            )
-            yield Input(
-                placeholder="Enter your OpenRouter API key...",
-                id="api-input",
-                password=True,
-            )
-            yield Static(
-                "Press Enter to confirm • Press Esc to quit (only if key empty)",
-                id="setup-hint",
-            )
+            yield Static("DocPixie needs an OpenRouter API key to work with documents.", classes="setup-text")
+            yield Static("Get your API key from: https://openrouter.ai/keys", classes="setup-text")
+            yield Input(placeholder="Enter your OpenRouter API key...", id="api-input", password=True)
+            yield Static("Press Enter to confirm • Press Esc to quit (only if key empty)", id="setup-hint")

     async def on_mount(self) -> None:
-        # Focus the input when screen shows
-        try:
-            self.query_one("#api-input", Input).focus()
-        except Exception:
-            pass
+        """Called when the screen is mounted."""
+        self.query_one("#api-input", Input).focus()

     def action_confirm(self) -> None:
+        """Confirms and saves the entered API key."""
         api_input = self.query_one("#api-input", Input)
         api_key = api_input.value.strip()

@@ -115,28 +101,20 @@ def action_confirm(self) -> None:
             api_input.placeholder = "API key cannot be empty!"
             return

-        # Save API key
-        config_manager = get_config_manager()
-        config_manager.set_api_key(api_key)
-
-        # Return to main app and initialize
+        get_config_manager().set_api_key(api_key)
         self.app.pop_screen()
         asyncio.create_task(self.app.docpixie_manager.initialize_docpixie())

     def action_quit_if_empty(self) -> None:
+        """Quits the application if the API key input is empty."""
         api_input = self.query_one("#api-input", Input)
         if not api_input.value.strip():
             self.app.exit()
         else:
-            # N only quits when empty; hint the user
-            try:
-                hint = self.query_one("#setup-hint", Static)
-                hint.update("Clear the key to quit with Esc, or press Enter to save.")
-            except Exception:
-                pass
+            self.query_one("#setup-hint", Static).update("Clear the key to quit with Esc, or press Enter to save.")

     async def on_input_submitted(self, event: Input.Submitted) -> None:
-        # Submit on Enter while focused in the input
+        """Handles the submission of the input field."""
         self.action_confirm()


@@ -147,10 +125,22 @@ class DocPixieTUI(
     ModelEventMixin,
     DocumentEventMixin
 ):
-    """Main DocPixie Terminal UI Application"""
+    """
+    The main DocPixie Terminal UI application class.

-    CSS = MAIN_APP_CSS
+    This class orchestrates the entire CLI, managing state, handling user
+    input, and displaying the chat interface.

+    Attributes:
+        docpixie: An optional instance of the `DocPixie` main class.
+        state_manager: Manages the application's state.
+        config_manager: Manages configuration settings.
+        command_handler: Handles user commands.
+        docpixie_manager: Manages the `DocPixie` instance and its operations.
+        task_display_manager: Manages the display of agent task updates.
+    """
+
+    CSS = MAIN_APP_CSS
     BINDINGS = [
         ("ctrl+n", "new_conversation", "New Conversation"),
         ("ctrl+l", "show_conversations", "Conversations"),
@@ -161,6 +151,7 @@ class DocPixieTUI(
     ]

     def __init__(self):
+        """Initializes the DocPixieTUI application."""
         super().__init__()
         self.docpixie: Optional[DocPixie] = None
         self.state_manager = AppStateManager()
@@ -170,258 +161,155 @@ def __init__(self):
         self.task_display_manager = TaskDisplayManager(self, self.state_manager)

     def compose(self) -> ComposeResult:
-        """Create the main UI layout"""
+        """Composes the main UI layout of the application."""
         yield Header(show_clock=True)
-
         with Container(id="chat-container"):
             yield ChatArea(id="chat-log")
-
         with Horizontal(id="status-bar"):
             yield Label(self.state_manager.get_status_text(), id="status-label")
-
         with Horizontal(id="input-container"):
             yield Static(">", id="prompt-indicator")
-            text_area = ChatInput(
-                "",
-                id="chat-input",
-                language=None,
-                tab_behavior="indent"
-            )
-            text_area.show_line_numbers = False
-            yield text_area
-
+            yield ChatInput("", id="chat-input", language=None, tab_behavior="indent", show_line_numbers=False)
         yield Label(self.state_manager.default_input_hint, id="input-hint")
-
-
         yield CommandPalette(id="command-palette")
-
         yield Footer()
-

     async def on_mount(self) -> None:
-        """Initialize the app when mounted"""
+        """Called when the application is mounted."""
         self.set_timer(0.1, self.deferred_init)
-        try:
-            self.call_after_refresh(
-                lambda: self.query_one("#chat-input", ChatInput).focus()
-            )
-        except Exception:
-            pass
+        self.call_after_refresh(lambda: self.query_one("#chat-input", ChatInput).focus())

     async def deferred_init(self) -> None:
-        """Deferred initialization to allow UI to render"""
+        """Performs deferred initialization after the UI has been rendered."""
         if not self.config_manager.has_api_key():
             await self.push_screen(SetupScreen())
         else:
             await self.docpixie_manager.initialize_docpixie()
-
-
-
-
-

     def show_welcome_message(self) -> None:
-        """Display welcome message and instructions"""
-        chat_log = self.query_one("#chat-log", ChatArea)
-
+        """Displays the welcome message and instructions."""
         from rich.panel import Panel
         from rich.align import Align
         from rich.text import Text

-        # Create colorful ASCII art using pyfiglet
-        ascii_art = Text()
         figlet_text = pyfiglet.figlet_format("DocPixie CLI", font="big")
-
-        # Define a pink-forward gradient (deep → light)
-        colors = [
-            "#ff4da6",  # deep pink
-            "#ff66b3",
-            "#ff80bf",
-            "#ff99cc",  # brand pink
-            "#ffb3d9",
-            "#ffcce6"  # light pink
-        ]
-
-        lines = figlet_text.split("\n")
-        for line in lines:
+        colors = ["#ff4da6", "#ff66b3", "#ff80bf", "#ff99cc", "#ffb3d9", "#ffcce6"]
+        ascii_art = Text()
+        for line in figlet_text.split("\n"):
             if line.strip():
                 colored_line = Text()
-                chars = list(line)
-                for i, char in enumerate(chars):
-                    if char != " " and char != "\n":
-                        color_index = (i * (len(colors) - 1)) // max(len(chars) - 1, 1)
-                        colored_line.append(char, style=colors[color_index] + " bold")
+                for i, char in enumerate(line):
+                    if char.strip():
+                        color_index = (i * (len(colors) - 1)) // max(len(line) - 1, 1)
+                        colored_line.append(char, style=f"{colors[color_index]} bold")
                     else:
                         colored_line.append(char)
-                ascii_art.append(colored_line)
-                ascii_art.append("\n")
-
-        welcome_content = Text()
-        welcome_content.append("\n")
-        welcome_content.append(ascii_art)
-        welcome_content.append("\n\n")
+                ascii_art.append(colored_line).append("\n")
+        welcome_content = Text("\n", justify="center").append(ascii_art).append("\n\n")
         if self.state_manager.indexed_documents:
             welcome_content.append(f"{len(self.state_manager.indexed_documents)} document(s) indexed and ready!\n\n", style="bold green")
         else:
-            welcome_content.append("No documents indexed yet\n", style="yellow")
-            welcome_content.append("Add PDFs to ./documents and type ", style="dim")
-            welcome_content.append(" to get started\n\n", style="dim")
-
-        welcome_content.append("Start chatting with your documents or type ", style="white")
-        welcome_content.append("/", style="bold cyan")
-        welcome_content.append(" to see all commands", style="white")
-
-        panel = Panel(
-            Align.center(welcome_content),
-            title="[bold #ff99cc]DocPixie[/]",
-            border_style="#ff99cc",
-            padding=(1, 2),
-            expand=False
-        )
-
-        chat_log.write(panel)
-        chat_log.add_static_text("\n")
+            welcome_content.append("No documents indexed yet\n", style="yellow").append("Add PDFs to ./documents and type ", style="dim").append("/index", style="bold cyan").append(" to get started\n\n", style="dim")
+        welcome_content.append("Start chatting with your documents or type ").append("/", style="bold cyan").append(" to see all commands")
+        self.query_one("#chat-log", ChatArea).write(Panel(
+            Align.center(welcome_content), title="[bold #ff99cc]DocPixie[/]", border_style="#ff99cc", padding=(1, 2)
+        ))

     async def submit_chat_message(self) -> None:
-        """Submit the chat message from the TextArea"""
+        """Submits the chat message from the input `TextArea`."""
         if self.state_manager.command_palette_active:
             command_palette = self.query_one("#command-palette", CommandPalette)
-            selected_command = command_palette.select_current_command()
-            if selected_command:
+            if selected_command := command_palette.select_current_command():
                 command_palette.hide()
                 self.state_manager.command_palette_active = False
-                text_area = self.query_one("#chat-input", ChatInput)
-                text_area.clear()
+                self.query_one("#chat-input", ChatInput).clear()
                 await self.handle_command(selected_command)
                 return

         text_area = self.query_one("#chat-input", ChatInput)
-        user_input = text_area.text.strip()
-
-        if user_input:
+        if user_input := text_area.text.strip():
             text_area.clear()
             await self.submit_text(user_input)

-
     async def submit_text(self, user_input: str) -> None:
-        """Handle text submission from TextArea"""
-        if self.state_manager.processing:
-            return
-
-        chat_log = self.query_one("#chat-log", ChatArea)
-
-        if not user_input:
+        """Handles the submission of text from the input `TextArea`."""
+        if self.state_manager.processing or not user_input:
             return

         if self.state_manager.command_palette_active:
-            command_palette = self.query_one("#command-palette", CommandPalette)
-            command_palette.hide()
+            self.query_one("#command-palette", CommandPalette).hide()
             self.state_manager.command_palette_active = False

         if user_input.startswith("/"):
             await self.handle_command(user_input.lower())
             return

-        chat_log.add_user_message(user_input)
-
+        self.query_one("#chat-log", ChatArea).add_user_message(user_input)
         self.set_chat_input_enabled(False)

         try:
-            async def task_callback(event_type: str, data: Any):
-                def _update():
-                    try:
-                        self.task_display_manager.display_task_update(event_type, data)
-                    except Exception:
-                        pass
-                try:
-                    self.call_from_thread(_update)
-                except Exception:
-                    _update()
-
-            await self.docpixie_manager.process_query(user_input, task_callback)
+            await self.docpixie_manager.process_query(user_input, self._task_callback)
         finally:
             self.set_chat_input_enabled(True)

-    def set_chat_input_enabled(self, enabled: bool) -> None:
-        """Enable or disable the chat input and update hint text."""
-        try:
-            text_area = self.query_one("#chat-input", ChatInput)
-            hint = self.query_one("#input-hint", Label)
-        except Exception:
-            return
-
-        try:
-            text_area.disabled = not enabled
-        except Exception:
-            pass
+    async def _task_callback(self, event_type: str, data: Any):
+        """A callback function to handle task updates from the agent."""
+        self.call_from_thread(self.task_display_manager.display_task_update, event_type, data)
+    def set_chat_input_enabled(self, enabled: bool) -> None:
+        """Enables or disables the chat input and updates the hint text."""
+        text_area = self.query_one("#chat-input", ChatInput)
+        hint = self.query_one("#input-hint", Label)
+        text_area.disabled = not enabled
         if enabled:
             hint.update(self.state_manager.default_input_hint)
-            try:
-                self.call_after_refresh(lambda: text_area.focus())
-            except Exception:
-                try:
-                    text_area.focus()
-                except Exception:
-                    pass
+            self.call_after_refresh(text_area.focus)
         else:
             hint.update("⏳ Agent is working… input disabled until response.")
-
-
-
-
-
-
-

     async def handle_command(self, command: str) -> None:
-        """Handle slash commands"""
+        """Handles the execution of slash commands."""
         await self.command_handler.handle_command(command)
-
-

     def action_quit(self) -> None:
-        """Quit the application"""
+        """Quits the application."""
         self.exit()

     def action_new_conversation(self) -> None:
-        """Start a new conversation"""
+        """Starts a new conversation."""
         asyncio.create_task(self.handle_command("/new"))

     def action_show_conversations(self) -> None:
-        """Show conversation list"""
+        """Shows the conversation list."""
         asyncio.create_task(self.handle_command("/conversations"))

     def action_show_models(self) -> None:
-        """Show model selector"""
+        """Shows the model selector."""
         asyncio.create_task(self.handle_command("/model"))

     def action_show_documents(self) -> None:
-        """Show document manager"""
+        """Shows the document manager."""
         asyncio.create_task(self.handle_command("/documents"))

     def action_toggle_palette(self) -> None:
-        """Toggle command palette"""
+        """Toggles the visibility of the command palette."""
         if self.state_manager.processing:
             return
         command_palette = self.query_one("#command-palette", CommandPalette)
         text_area = self.query_one("#chat-input", ChatInput)
-
+        self.state_manager.command_palette_active = not self.state_manager.command_palette_active
         if self.state_manager.command_palette_active:
-            command_palette.hide()
-            self.state_manager.command_palette_active = False
-        else:
             command_palette.show("/")
-            self.state_manager.command_palette_active = True
             text_area.text = "/"
             text_area.cursor_location = (0, 1)
+        else:
+            command_palette.hide()


 def main():
-    """Main entry point for Textual CLI"""
-    app = DocPixieTUI()
-    app.run()
+    """The main entry point for the Textual CLI."""
+    DocPixieTUI().run()


 if __name__ == "__main__":
diff --git a/docpixie/cli/commands.py b/docpixie/cli/commands.py
index 5559ee3..d241c76 100644
--- a/docpixie/cli/commands.py
+++ b/docpixie/cli/commands.py
@@ -1,5 +1,9 @@
 """
-Command handling for DocPixie CLI
+Command handling for the DocPixie CLI.
+
+This module provides the `CommandHandler` class, which is responsible for
+processing all slash commands (e.g., `/new`, `/help`) entered by the user in
+the CLI application.
 """

 from typing import TYPE_CHECKING, Optional
@@ -16,98 +20,103 @@


 class CommandHandler:
-    """Handles all slash commands for the CLI application"""
-
+    """
+    Handles all slash commands for the CLI application.
+
+    This class centralizes the logic for interpreting and executing user
+    commands, acting as a bridge between the main application, the state
+    manager, and the various dialogs and widgets.
+
+    Attributes:
+        app: The main `DocPixieTUI` application instance.
+        state_manager: The application's state manager.
+    """
+
     def __init__(self, app: 'DocPixieTUI', state_manager: AppStateManager):
+        """
+        Initializes the CommandHandler.
+
+        Args:
+            app: The main `DocPixieTUI` application instance.
+            state_manager: The application's state manager.
+        """
         self.app = app
         self.state_manager = state_manager
-
+
     async def handle_command(self, command: str) -> None:
-        """Handle slash commands"""
+        """
+        Handles the execution of a given slash command.
+
+        Args:
+            command: The command string entered by the user.
+        """
         chat_log = self.app.query_one("#chat-log", ChatArea)
-
-        if command == "/exit":
-            self.state_manager.save_current_conversation()
-            self.app.exit()
-
-        elif command == "/new":
-            await self._handle_new_command(chat_log)
-
-        elif command == "/clear":
-            self._handle_clear_command(chat_log)
-
-        elif command == "/save":
-            self._handle_save_command(chat_log)
-
-        elif command == "/conversations":
-            await self._handle_conversations_command()
-
-        elif command == "/model":
-            await self._handle_model_command()
-
-        elif command == "/documents":
-            await self._handle_documents_command()
-
-        elif command == "/help":
-            self._handle_help_command(chat_log)
-
+        command_map = {
+            "/exit": self._handle_exit_command,
+            "/new": self._handle_new_command,
+            "/clear": self._handle_clear_command,
+            "/save": self._handle_save_command,
+            "/conversations": self._handle_conversations_command,
+            "/model": self._handle_model_command,
+            "/documents": self._handle_documents_command,
+            "/help": self._handle_help_command,
+        }
+        handler = command_map.get(command)
+        if handler:
+            await handler(chat_log)
         else:
-            chat_log.write(f"[warning]Unknown command: {command}[/warning]\n")
-            chat_log.write("Type /help for available commands\n\n")
-
+            chat_log.write(f"[warning]Unknown command: {command}[/warning]\nType /help for available commands\n\n")
+
+    async def _handle_exit_command(self, chat_log: ChatArea) -> None:
+        """Handles the /exit command."""
+        self.state_manager.save_current_conversation()
+        self.app.exit()
+
     async def _handle_new_command(self, chat_log: ChatArea) -> None:
-        """Handle /new command"""
+        """Handles the /new command."""
         self.state_manager.save_current_conversation()
         self.state_manager.create_new_conversation()
         self.state_manager.clear_task_plan()
-
         chat_log.clear()
         self.app.show_welcome_message()
         chat_log.write("[green bold]●[/green bold] Started new conversation\n\n")
-
-        status_label = self.app.query_one("#status-label")
-        status_label.update(self.state_manager.get_status_text())
-
-    def _handle_clear_command(self, chat_log: ChatArea) -> None:
-        """Handle /clear command"""
+        self.app.query_one("#status-label").update(self.state_manager.get_status_text())
+
+    async def _handle_clear_command(self, chat_log: ChatArea) -> None:
+        """Handles the /clear command."""
         self.state_manager.clear_task_plan()
         chat_log.clear()
         self.app.show_welcome_message()
-
-    def _handle_save_command(self, chat_log: ChatArea) -> None:
-        """Handle /save command"""
+
+    async def _handle_save_command(self, chat_log: ChatArea) -> None:
+        """Handles the /save command."""
         if self.state_manager.current_conversation_id and self.state_manager.conversation_history:
             self.state_manager.save_current_conversation()
             chat_log.write("[green bold]●[/green bold] Conversation saved!\n\n")
         else:
             chat_log.write("[warning]No conversation to save[/warning]\n\n")
-
-    async def _handle_conversations_command(self) -> None:
-        """Handle /conversations command"""
-        await self.app.push_screen(ConversationManagerDialog(
-            self.state_manager.current_conversation_id
-        ))
-
-    async def _handle_model_command(self) -> None:
-        """Handle /model command"""
+
+    async def _handle_conversations_command(self, chat_log: ChatArea) -> None:
+        """Handles the /conversations command."""
+        await self.app.push_screen(ConversationManagerDialog(self.state_manager.current_conversation_id))
+
+    async def _handle_model_command(self, chat_log: ChatArea) -> None:
+        """Handles the /model command."""
         await self.app.push_screen(ModelSelectorDialog())
-
-    async def _handle_documents_command(self) -> None:
-        """Handle /documents command"""
-        await self.app.push_screen(DocumentManagerDialog(
-            self.state_manager.documents_folder,
-            self.app.docpixie
-        ))
-
-    def _handle_help_command(self, chat_log: ChatArea) -> None:
-        """Handle /help command"""
-        chat_log.write("\n[bold]Available Commands:[/bold]\n")
-        chat_log.write(" /new - Start a new conversation (Ctrl+N)\n")
-        chat_log.write(" /conversations - Switch between conversations (Ctrl+L)\n")
-        chat_log.write(" /save - Save current conversation\n")
-        chat_log.write(" /clear - Clear the chat display\n")
-        chat_log.write(" /model - Configure AI models (Ctrl+O)\n")
-        chat_log.write(" /documents - Manage and index documents (Ctrl+D)\n")
-        chat_log.write(" /help - Show this help message\n")
-        chat_log.write(" /exit - Exit the program (Ctrl+Q)\n\n")
-        chat_log.write("[dim]Press Ctrl+/ to open command palette[/dim]\n\n")
\ No newline at end of file
+
+    async def _handle_documents_command(self, chat_log: ChatArea) -> None:
+        """Handles the /documents command."""
+        await self.app.push_screen(DocumentManagerDialog(self.state_manager.documents_folder, self.app.docpixie))
+
+    async def _handle_help_command(self, chat_log: ChatArea) -> None:
+        """Handles the /help command."""
+        chat_log.write("\n[bold]Available Commands:[/bold]\n"
+                       " /new - Start a new conversation (Ctrl+N)\n"
+                       " /conversations - Switch between conversations (Ctrl+L)\n"
+                       " /save - Save current conversation\n"
+                       " /clear - Clear the chat display\n"
+                       " /model - Configure AI models (Ctrl+O)\n"
+                       " /documents - Manage and index documents (Ctrl+D)\n"
+                       " /help - Show this help message\n"
+                       " /exit - Exit the program (Ctrl+Q)\n\n"
+                       "[dim]Press Ctrl+/ to open command palette[/dim]\n\n")
\ No newline at end of file
diff --git a/docpixie/cli/config.py b/docpixie/cli/config.py
index d50c638..6c2429c 100644
--- a/docpixie/cli/config.py
+++ b/docpixie/cli/config.py
@@ -1,6 +1,8 @@
 """
-Global configuration manager for DocPixie CLI
-Handles API keys, model preferences, and user settings
+Global configuration management for the DocPixie CLI.
+
+This module handles the loading, saving, and accessing of configuration data
+for the CLI, including API keys, model preferences, and other user settings.
 """

 import json
@@ -9,93 +11,82 @@ from typing import Optional, Dict, Any
 from dataclasses import dataclass, asdict, field

-
+# Lists of recommended models for different purposes
 PLANNING_MODELS = [
-    "anthropic/claude-opus-4.1",
-    "anthropic/claude-sonnet-4",
-    "anthropic/claude-3.5-haiku",
-    "google/gemini-2.5-flash",
-    "google/gemini-2.5-pro",
-    "openai/gpt-5",
-    "openai/gpt-5-mini",
-    "openai/gpt-4.1",
-    "openai/gpt-4.1-mini",
-    "qwen/qwen-max",
-    "qwen/qwen-plus",
-    "nousresearch/hermes-4-70b",
-    "deepseek/deepseek-chat-v3.1",
-    "mistralai/mistral-medium-3.1",
+    "anthropic/claude-opus-4.1", "anthropic/claude-sonnet-4", "anthropic/claude-3.5-haiku",
+    "google/gemini-2.5-flash", "google/gemini-2.5-pro", "openai/gpt-5", "openai/gpt-5-mini",
+    "openai/gpt-4.1", "openai/gpt-4.1-mini", "qwen/qwen-max", "qwen/qwen-plus",
+    "nousresearch/hermes-4-70b", "deepseek/deepseek-chat-v3.1", "mistralai/mistral-medium-3.1",
 ]
-
 VISION_MODELS = [
-    "google/gemini-2.5-pro",
-    "google/gemini-2.5-flash",
-    "google/gemini-2.5-flash-lite",
-    "openai/gpt-4.1",
-    "openai/gpt-4.1-mini",
-    "openai/gpt-4.1-nano",
-    "anthropic/claude-sonnet-4",
+    "google/gemini-2.5-pro", "google/gemini-2.5-flash", "google/gemini-2.5-flash-lite",
+    "openai/gpt-4.1", "openai/gpt-4.1-mini", "openai/gpt-4.1-nano", "anthropic/claude-sonnet-4",
 ]


 @dataclass
 class CLIConfig:
-    """CLI configuration stored globally in ~/.docpixie/"""
+    """
+    A dataclass for storing the CLI configuration.

+    This configuration is typically stored in `~/.docpixie/config.json`.
+    """
     openrouter_api_key: Optional[str] = None
-
     text_model: str = "qwen/qwen-plus"
     vision_model: str = "google/gemini-2.5-flash"
-
     last_conversation_id: Optional[str] = None
     theme: str = "default"
-
     auto_index_on_startup: bool = True
     max_conversation_history: int = 20

     def to_dict(self) -> Dict[str, Any]:
-        """Convert config to dictionary for JSON serialization"""
+        """Converts the configuration to a dictionary for JSON serialization."""
         return asdict(self)

     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> 'CLIConfig':
-        """Create config from dictionary"""
+        """Creates a `CLIConfig` instance from a dictionary."""
         return cls(**data)


 class ConfigManager:
-    """Manages global DocPixie CLI configuration"""
+    """
+    Manages the global configuration for the DocPixie CLI.
+
+    This class handles all interactions with the configuration file, including
+    loading, saving, and providing access to various settings.
+
+    Attributes:
+        config_dir: The path to the configuration directory (`~/.docpixie`).
+        config_file: The path to the configuration file.
+        conversations_dir: The path to the directory where conversations are stored.
+        config: The loaded `CLIConfig` object.
+    """

     def __init__(self):
-        """Initialize config manager with global config directory"""
+        """Initializes the `ConfigManager`."""
         self.config_dir = Path.home() / ".docpixie"
         self.config_file = self.config_dir / "config.json"
         self.conversations_dir = self.config_dir / "conversations"
-
         self.config_dir.mkdir(exist_ok=True)
         self.conversations_dir.mkdir(exist_ok=True)
-
         self.config = self.load_config()

     def load_config(self) -> CLIConfig:
-        """Load configuration from file or create default"""
+        """Loads the configuration from the file, or creates a default one."""
         if self.config_file.exists():
             try:
                 with open(self.config_file, 'r') as f:
-                    data = json.load(f)
-                    return CLIConfig.from_dict(data)
+                    return CLIConfig.from_dict(json.load(f))
             except Exception as e:
                 print(f"Warning: Failed to load config: {e}")
-                return CLIConfig()
-        else:
-            env_key = os.getenv("OPENROUTER_API_KEY")
-            config = CLIConfig()
-            if env_key:
-                config.openrouter_api_key = env_key
-            return config
+        config = CLIConfig()
+        if env_key := os.getenv("OPENROUTER_API_KEY"):
+            config.openrouter_api_key = env_key
+        return config

     def save_config(self):
-        """Save current configuration to file"""
+        """Saves the current configuration to the file."""
         try:
             with open(self.config_file, 'w') as f:
                 json.dump(self.config.to_dict(), f, indent=2)
@@ -103,26 +94,28 @@ def save_config(self):
             print(f"Error saving config: {e}")

     def get_api_key(self) -> Optional[str]:
-        """Get OpenRouter API key from config or environment"""
-        if self.config.openrouter_api_key:
-            return self.config.openrouter_api_key
-        return os.getenv("OPENROUTER_API_KEY")
+        """
+        Retrieves the OpenRouter API key.
+
+        It first checks the config file, then falls back to the environment variable.
+        """
+        return self.config.openrouter_api_key or os.getenv("OPENROUTER_API_KEY")

     def set_api_key(self, api_key: str):
-        """Set and save OpenRouter API key"""
+        """Sets and saves the OpenRouter API key."""
         self.config.openrouter_api_key = api_key
         self.save_config()

     def has_api_key(self) -> bool:
-        """Check if API key is configured"""
+        """Checks if an API key is configured."""
         return bool(self.get_api_key())

     def get_models(self) -> tuple[str, str]:
-        """Get configured models (text, vision)"""
+        """Returns the configured text and vision models."""
         return self.config.text_model, self.config.vision_model

     def set_models(self, text_model: str = None, vision_model: str = None):
-        """Update model configuration"""
+        """Updates and saves the model configuration."""
         if text_model:
             self.config.text_model = text_model
         if vision_model:
@@ -130,31 +123,33 @@ def set_models(self, text_model: str = None, vision_model: str = None):
         self.save_config()

     def get_conversation_path(self, conversation_id: str) -> Path:
-        """Get path for a specific conversation file"""
+        """Returns the file path for a specific conversation."""
         return self.conversations_dir / f"{conversation_id}.json"

     def get_all_conversations(self) -> list[Path]:
-        """Get all conversation files"""
+        """Returns a list of all saved conversation files."""
         return list(self.conversations_dir.glob("*.json"))

     def validate_api_key(self, api_key: str) -> bool:
         """
-        Validate API key by making a test request
-        Returns True if valid, False otherwise
+        Performs a basic validation of an API key.
+
+        Args:
+            api_key: The API key to validate.
+
+        Returns:
+            `True` if the key is likely valid, `False` otherwise.
         """
-        try:
-            if api_key and len(api_key) > 10:
-                return True
-            return False
-        except Exception:
-            return False
+        return bool(api_key and len(api_key) > 10)


-_config_manager = None
+_config_manager: Optional[ConfigManager] = None


 def get_config_manager() -> ConfigManager:
-    """Get or create the global config manager instance"""
+    """
+    Returns the singleton instance of the `ConfigManager`.
+    """
     global _config_manager
     if _config_manager is None:
         _config_manager = ConfigManager()
diff --git a/docpixie/cli/conversation_storage.py b/docpixie/cli/conversation_storage.py
index 8f4a7ce..7fac5dc 100644
--- a/docpixie/cli/conversation_storage.py
+++ b/docpixie/cli/conversation_storage.py
@@ -1,6 +1,9 @@
 """
-Local conversation storage for DocPixie CLI
-Stores conversations per project directory
+Local conversation storage for the DocPixie CLI.
+
+This module provides classes for managing the storage of conversations on the
+local file system. Conversations are stored on a per-project basis within a
+`.docpixie` directory.
 """

 import json
@@ -16,7 +19,7 @@

 @dataclass
 class ConversationMetadata:
-    """Metadata for a conversation"""
+    """A dataclass for storing metadata about a conversation."""
     id: str
     name: str
     working_directory: str
@@ -28,260 +31,187 @@ class ConversationMetadata:


 class ConversationStorage:
-    """Manages local conversation storage in ./.docpixie/conversations/"""
-
+    """
+    Manages the local storage of conversations.
+
+    This class handles the creation, saving, loading, and deletion of
+    conversation files, which are stored in a `.docpixie/conversations`
+    directory within the current project.
+
+    Attributes:
+        base_path: The root path for the DocPixie CLI's local data.
+        conversations_dir: The directory where conversations are stored.
+        metadata_file: The file for storing metadata about all conversations.
+        working_directory: The current working directory.
+        current_conversation_id: The ID of the currently active conversation.
+    """
+
     def __init__(self):
-        """Initialize conversation storage for current directory"""
+        """Initializes the `ConversationStorage`."""
         self.base_path = Path("./.docpixie")
         self.conversations_dir = self.base_path / "conversations"
         self.metadata_file = self.conversations_dir / "metadata.json"
-
         self.conversations_dir.mkdir(parents=True, exist_ok=True)
-
         self.working_directory = str(Path.cwd().resolve())
-
         self.current_conversation_id: Optional[str] = None
-
         self._load_metadata()
-
+
     def _load_metadata(self) -> Dict[str, ConversationMetadata]:
-        """Load conversation metadata from file"""
+        """Loads conversation metadata from the metadata file."""
         if not self.metadata_file.exists():
             return {}
-
         try:
             with open(self.metadata_file, 'r') as f:
                 data = json.load(f)
-
-            metadata = {}
-            for conv_id, conv_data in data.items():
-                if 'total_cost' not in conv_data:
-                    conv_data['total_cost'] = 0.0
-                metadata[conv_id] = ConversationMetadata(**conv_data)
-
-            return metadata
+            return {
+                conv_id: ConversationMetadata(**{**conv_data, 'total_cost': conv_data.get('total_cost', 0.0)})
+                for conv_id, conv_data in data.items()
+            }
         except Exception as e:
             print(f"Warning: Failed to load conversation metadata: {e}")
             return {}
-
+
     def _save_metadata(self, metadata: Dict[str, ConversationMetadata]):
-        """Save conversation metadata to file"""
+        """Saves conversation metadata to the metadata file."""
         try:
-            data = {}
-            for conv_id, conv_meta in metadata.items():
-                data[conv_id] = asdict(conv_meta)
-
             with open(self.metadata_file, 'w') as f:
-                json.dump(data, f, indent=2)
+                json.dump({cid: asdict(cm) for cid, cm in metadata.items()}, f, indent=2)
         except Exception as e:
             print(f"Error saving conversation metadata: {e}")
-
+
     def _conversation_file_path(self, conversation_id: str) -> Path:
-        """Get path for conversation file"""
+        """Returns the file path for a specific conversation."""
         return self.conversations_dir / f"{conversation_id}.json"
-
+
     def _generate_conversation_name(self, messages: List[ConversationMessage]) -> str:
-        """Generate a conversation name from the first user message"""
-        if not messages:
-            return f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
-
-        first_user_message = None
-        for msg in messages:
-            if msg.role == "user":
-                first_user_message = msg
-                break
-
+        """Generates a name for a conversation from its first user message."""
+        first_user_message = next((msg for msg in messages if msg.role == "user"), None)
         if first_user_message:
             name = first_user_message.content.strip()[:50]
-            if len(first_user_message.content) > 50:
-                name += "..."
-            return name
-        else:
-            return f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
-
+            return name + "..." if len(first_user_message.content) > 50 else name
+        return f"Chat {datetime.now().strftime('%Y-%m-%d %H:%M')}"
+
     def create_new_conversation(self, indexed_documents: List[str] = None) -> str:
-        """Create a new conversation and return its ID"""
+        """Creates a new conversation and returns its ID."""
         conversation_id = str(uuid.uuid4())
         now = datetime.now().isoformat()
-
         metadata = ConversationMetadata(
-            id=conversation_id,
-            name="New Chat",
-            working_directory=self.working_directory,
-            created_at=now,
-            updated_at=now,
-            message_count=0,
-            indexed_documents=indexed_documents or [],
-            total_cost=0.0
+            id=conversation_id, name="New Chat", working_directory=self.working_directory,
+            created_at=now, updated_at=now, message_count=0,
+            indexed_documents=indexed_documents or [], total_cost=0.0
         )
-
-        conversation_data = {
-            "id": conversation_id,
-            "metadata": asdict(metadata),
-            "messages": []
-        }
-
-        conversation_file = self._conversation_file_path(conversation_id)
-        with open(conversation_file, 'w') as f:
-            json.dump(conversation_data, f, indent=2)
-
+        with open(self._conversation_file_path(conversation_id), 'w') as f:
+            json.dump({"id": conversation_id, "metadata": asdict(metadata), "messages": []}, f, indent=2)
         all_metadata = self._load_metadata()
         all_metadata[conversation_id] = metadata
self._save_metadata(all_metadata) - self.current_conversation_id = conversation_id return conversation_id - - def save_conversation(self, conversation_id: str, messages: List[ConversationMessage], - indexed_documents: List[str] = None): - """Save conversation messages""" + + def save_conversation(self, conversation_id: str, messages: List[ConversationMessage], indexed_documents: List[str] = None): + """Saves the messages of a conversation.""" try: - now = datetime.now().isoformat() - messages_data = [] total_cost = 0.0 for msg in messages: - msg_dict = { - "role": msg.role, - "content": msg.content, - "timestamp": msg.timestamp.isoformat() - } msg_cost = getattr(msg, 'cost', 0.0) or 0.0 - msg_dict["cost"] = msg_cost total_cost += msg_cost - messages_data.append(msg_dict) - + messages_data.append({ + "role": msg.role, "content": msg.content, + "timestamp": msg.timestamp.isoformat(), "cost": msg_cost + }) all_metadata = self._load_metadata() - if conversation_id in all_metadata: - conv_metadata = all_metadata[conversation_id] + now = datetime.now().isoformat() + if conv_metadata := all_metadata.get(conversation_id): conv_metadata.updated_at = now conv_metadata.message_count = len(messages) conv_metadata.total_cost = total_cost if indexed_documents is not None: conv_metadata.indexed_documents = indexed_documents - if conv_metadata.name == "New Chat" and messages: conv_metadata.name = self._generate_conversation_name(messages) else: conv_metadata = ConversationMetadata( - id=conversation_id, - name=self._generate_conversation_name(messages), - working_directory=self.working_directory, - created_at=now, - updated_at=now, - message_count=len(messages), - indexed_documents=indexed_documents or [], + id=conversation_id, name=self._generate_conversation_name(messages), + working_directory=self.working_directory, created_at=now, updated_at=now, + message_count=len(messages), indexed_documents=indexed_documents or [], total_cost=total_cost ) all_metadata[conversation_id] = 
conv_metadata - - conversation_data = { - "id": conversation_id, - "metadata": asdict(conv_metadata), - "messages": messages_data - } - - conversation_file = self._conversation_file_path(conversation_id) - with open(conversation_file, 'w') as f: - json.dump(conversation_data, f, indent=2) - + with open(self._conversation_file_path(conversation_id), 'w') as f: + json.dump({"id": conversation_id, "metadata": asdict(conv_metadata), "messages": messages_data}, f, indent=2) self._save_metadata(all_metadata) - except Exception as e: print(f"Error saving conversation: {e}") - + def load_conversation(self, conversation_id: str) -> Optional[tuple[ConversationMetadata, List[ConversationMessage]]]: - """Load conversation by ID""" + """Loads a conversation from a file by its ID.""" try: conversation_file = self._conversation_file_path(conversation_id) if not conversation_file.exists(): return None - with open(conversation_file, 'r') as f: data = json.load(f) - metadata = ConversationMetadata(**data["metadata"]) - - messages = [] - for msg_data in data["messages"]: - message = ConversationMessage( - role=msg_data["role"], - content=msg_data["content"], - timestamp=datetime.fromisoformat(msg_data["timestamp"]), - cost=msg_data.get("cost", 0.0) - ) - messages.append(message) - + messages = [ + ConversationMessage( + role=msg_data["role"], content=msg_data["content"], + timestamp=datetime.fromisoformat(msg_data["timestamp"]), cost=msg_data.get("cost", 0.0) + ) for msg_data in data["messages"] + ] self.current_conversation_id = conversation_id return metadata, messages - except Exception as e: print(f"Error loading conversation: {e}") return None - + def list_local_conversations(self) -> List[ConversationMetadata]: - """List conversations from current working directory only""" + """Lists all conversations in the current working directory.""" all_metadata = self._load_metadata() - - local_conversations = [] - for conv_id, metadata in all_metadata.items(): - if 
metadata.working_directory == self.working_directory: - local_conversations.append(metadata) - + local_conversations = [md for md in all_metadata.values() if md.working_directory == self.working_directory] local_conversations.sort(key=lambda x: x.updated_at, reverse=True) return local_conversations - + def delete_conversation(self, conversation_id: str) -> bool: - """Delete a conversation""" + """Deletes a conversation by its ID.""" try: - conversation_file = self._conversation_file_path(conversation_id) - if conversation_file.exists(): + if (conversation_file := self._conversation_file_path(conversation_id)).exists(): conversation_file.unlink() - all_metadata = self._load_metadata() if conversation_id in all_metadata: del all_metadata[conversation_id] self._save_metadata(all_metadata) - if self.current_conversation_id == conversation_id: self.current_conversation_id = None - return True except Exception as e: print(f"Error deleting conversation: {e}") return False - + def rename_conversation(self, conversation_id: str, new_name: str) -> bool: - """Rename a conversation""" + """Renames a conversation.""" try: all_metadata = self._load_metadata() if conversation_id not in all_metadata: return False - all_metadata[conversation_id].name = new_name all_metadata[conversation_id].updated_at = datetime.now().isoformat() - - conversation_file = self._conversation_file_path(conversation_id) - if conversation_file.exists(): - with open(conversation_file, 'r') as f: + if (conversation_file := self._conversation_file_path(conversation_id)).exists(): + with open(conversation_file, 'r+') as f: data = json.load(f) - - data["metadata"]["name"] = new_name - data["metadata"]["updated_at"] = all_metadata[conversation_id].updated_at - - with open(conversation_file, 'w') as f: + data["metadata"]["name"] = new_name + data["metadata"]["updated_at"] = all_metadata[conversation_id].updated_at + f.seek(0) json.dump(data, f, indent=2) - + f.truncate() self._save_metadata(all_metadata) return 
True - except Exception as e: print(f"Error renaming conversation: {e}") return False - + def get_last_conversation(self) -> Optional[str]: - """Get the most recently updated conversation ID from current directory""" - conversations = self.list_local_conversations() - if conversations: + """Returns the ID of the most recently updated conversation.""" + if conversations := self.list_local_conversations(): return conversations[0].id return None \ No newline at end of file diff --git a/docpixie/cli/docpixie_manager.py b/docpixie/cli/docpixie_manager.py index 894e09d..3a300c2 100644 --- a/docpixie/cli/docpixie_manager.py +++ b/docpixie/cli/docpixie_manager.py @@ -1,5 +1,10 @@ """ -DocPixie integration manager for CLI application +DocPixie integration manager for the CLI application. + +This module provides the `DocPixieManager` class, which acts as a bridge +between the Textual UI and the core `DocPixie` library. It handles the +initialization of the `DocPixie` instance, document management, and query +processing. """ import asyncio @@ -19,219 +24,193 @@ class DocPixieManager: - """Manages DocPixie instance and all related operations""" + """ + Manages the `DocPixie` instance and all related operations for the CLI. + + Attributes: + app: The main `DocPixieTUI` application instance. + state_manager: The application's state manager. + config_manager: The application's configuration manager. + docpixie: An optional instance of the `DocPixie` main class. + """ def __init__(self, app: 'DocPixieTUI', state_manager: AppStateManager): + """ + Initializes the `DocPixieManager`. + + Args: + app: The main `DocPixieTUI` application instance. + state_manager: The application's state manager. + """ self.app = app self.state_manager = state_manager self.config_manager = get_config_manager() self.docpixie: Optional[DocPixie] = None async def create_docpixie_instance(self) -> bool: + """ + Creates an instance of the `DocPixie` class based on the current configuration. 
+ + Returns: + `True` if the instance was created successfully, `False` otherwise. + """ try: - api_key = self.config_manager.get_api_key() - if not api_key: + if not (api_key := self.config_manager.get_api_key()): return False - text_model, vision_model = self.config_manager.get_models() - config = DocPixieConfig( - provider="openrouter", - model=text_model, - vision_model=vision_model, - storage_type="local", - local_storage_path="./.docpixie/documents", - openrouter_api_key=api_key, - jpeg_quality=85, - max_pages_per_task=4 + provider="openrouter", model=text_model, vision_model=vision_model, + storage_type="local", local_storage_path="./.docpixie/documents", + openrouter_api_key=api_key, jpeg_quality=85, max_pages_per_task=4 ) - self.docpixie = DocPixie(config=config) self.app.docpixie = self.docpixie return True - except Exception as e: try: - chat_log = self.app.query_one("#chat-log", ChatArea) - chat_log.write(f"[error]❌ Failed to create DocPixie instance: {e}[/error]") - except: + self.app.query_one("#chat-log", ChatArea).write(f"[error]❌ Failed to create DocPixie instance: {e}[/error]") + except Exception: pass return False async def initialize_docpixie(self, show_welcome: bool = True) -> None: - chat_log = self.app.query_one("#chat-log", ChatArea) + """ + Initializes the DocPixie instance and the application state. + This method orchestrates the startup sequence, including creating the + `DocPixie` instance, checking for documents, and loading the initial + conversation. + + Args: + show_welcome: Whether to show the welcome message after initialization. + """ + chat_log = self.app.query_one("#chat-log", ChatArea) if not await self.create_docpixie_instance(): chat_log.write("[error]❌ No API key configured. 
Please restart and configure.[/error]") return - try: await self.check_and_prompt_for_documents() await self.load_or_create_conversation() - if show_welcome: self.app.show_welcome_message() - if self.state_manager.current_conversation_id and self.state_manager.conversation_history: chat_log.add_static_text("[dim]━━━ Restored previous conversation ━━━[/dim]\n\n") - for msg in self.state_manager.conversation_history: - if msg.role == "user": - chat_log.add_user_message(msg.content) - else: - chat_log.add_assistant_message(msg.content) - + chat_log.add_user_message(msg.content) if msg.role == "user" else chat_log.add_assistant_message(msg.content) chat_log.add_static_text("[dim]━━━ Continue your conversation below ━━━[/dim]\n\n") - except Exception as e: chat_log.write(f"[error]❌ Failed to initialize: {e}[/error]") async def switch_models(self) -> None: + """ + Re-creates the DocPixie instance when the models are switched. + + This is necessary because the model configuration is set at the time + of `DocPixie` instantiation. + """ await self.create_docpixie_instance() async def check_and_prompt_for_documents(self) -> None: - chat_log = self.app.query_one("#chat-log", ChatArea) + """ + Checks for documents in the local `documents` folder. + If new documents are found or if no documents are present, it prompts + the user to manage them by opening the `DocumentManagerDialog`. 
+ """ + chat_log = self.app.query_one("#chat-log", ChatArea) if not self.state_manager.documents_folder.exists(): self.state_manager.documents_folder.mkdir(parents=True) chat_log.write(f"[green bold]●[/green bold] Created documents folder: {self.state_manager.documents_folder.absolute()}\n") chat_log.write("[blue bold]●[/blue bold] Add PDF files to the ./documents folder or use /documents to manage them.\n") - # Auto-open the Document Manager when the folder is first created - await self.app.push_screen(DocumentManagerDialog( - self.state_manager.documents_folder, - self.docpixie - )) + await self.app.push_screen(DocumentManagerDialog(self.state_manager.documents_folder, self.docpixie)) return self.state_manager.clear_documents() - try: existing_docs = await self.docpixie.list_documents() indexed_names = {doc['name'] for doc in existing_docs} - for doc_meta in existing_docs: - doc = await self.docpixie.get_document(doc_meta['id']) - if doc: + if doc := await self.docpixie.get_document(doc_meta['id']): self.state_manager.add_document(doc) - except Exception as e: indexed_names = set() - chat_log.write(f"[dim]Note: Could not load existing documents: {e}[/dim]\\n") + chat_log.write(f"[dim]Note: Could not load existing documents: {e}[/dim]\n") pdf_files = list(self.state_manager.documents_folder.glob("*.pdf")) - - if not pdf_files: - # Auto-open the Document Manager when there are no PDFs yet - await self.app.push_screen(DocumentManagerDialog( - self.state_manager.documents_folder, - self.docpixie - )) - return - - new_pdf_files = [ - pdf for pdf in pdf_files - if pdf.stem not in indexed_names - ] - - if new_pdf_files: - chat_log.write(f"[blue bold]●[/blue bold] Found {len(new_pdf_files)} new PDF file(s)\n") - await self.app.push_screen(DocumentManagerDialog( - self.state_manager.documents_folder, - self.docpixie - )) + if not pdf_files or any(pdf.stem not in indexed_names for pdf in pdf_files): + await 
self.app.push_screen(DocumentManagerDialog(self.state_manager.documents_folder, self.docpixie)) async def load_or_create_conversation(self) -> None: + """Loads the last active conversation or creates a new one.""" try: - doc_ids = [doc.id for doc in self.state_manager.indexed_documents] - last_conversation_id = self.state_manager.get_last_conversation_id() - - if last_conversation_id: - if self.state_manager.load_conversation(last_conversation_id): - status_label = self.app.query_one("#status-label") - status_label.update(self.state_manager.get_status_text()) + if last_id := self.state_manager.get_last_conversation_id(): + if self.state_manager.load_conversation(last_id): + self.app.query_one("#status-label").update(self.state_manager.get_status_text()) return - self.state_manager.create_new_conversation() - status_label = self.app.query_one("#status-label") - status_label.update(self.state_manager.get_status_text()) - + self.app.query_one("#status-label").update(self.state_manager.get_status_text()) except Exception as e: print(f"Error loading conversation: {e}") self.state_manager.set_current_conversation(None) async def process_query(self, query: str, task_callback: Optional[Callable] = None) -> None: + """ + Processes a user query and displays the results in the chat area. + + Args: + query: The user's query string. + task_callback: An optional callback function to receive real-time + updates on the agent's task progress. + """ chat_log = self.app.query_one("#chat-log", ChatArea) - if not self.docpixie: - chat_log.write("[error]❌ DocPixie not initialized[/error]\\n") + chat_log.write("[error]❌ DocPixie not initialized[/error]\n") return - if not self.state_manager.has_documents(): - chat_log.write("[warning]⚠️ No documents indexed yet. Use /documents to add and index documents first.[/warning]\\n") + chat_log.write("[warning]⚠️ No documents indexed yet. 
Use /documents to add and index documents first.[/warning]\n") return self.state_manager.set_processing(True) - try: chat_log.show_processing_status() - - result = await asyncio.get_event_loop().run_in_executor( - None, - self.docpixie.query_sync, - query, - None, # mode - None, # document_ids - None, # max_pages - self.state_manager.conversation_history, - task_callback + result = await asyncio.to_thread( + self.docpixie.query_sync, query, conversation_history=self.state_manager.conversation_history, task_callback=task_callback ) - chat_log.add_assistant_message(result.answer) - - if hasattr(result, 'get_pages_by_document'): - pages_by_doc = result.get_pages_by_document() - if pages_by_doc: - chat_log.write("[dim]Analyzed documents:[/dim]\n") - for doc_name, page_nums in pages_by_doc.items(): - pages_str = ", ".join(str(p) for p in page_nums) - chat_log.write(f"[dim] • {doc_name}: Pages {pages_str}[/dim]\n") - elif hasattr(result, 'page_numbers') and result.page_numbers: - chat_log.write(f"[dim]Analyzed pages: {result.page_numbers}[/dim]\n") - - if hasattr(result, 'processing_time') and result.processing_time > 0: - chat_log.write(f"[dim]Processing time: {result.processing_time:.2f}s[/dim]\n") - - cost = getattr(result, 'total_cost', 0.0) or 0.0 - if cost < 0.01: - chat_log.write(f"[dim]Cost: ${cost:.6f}[/dim]\n") - else: - chat_log.write(f"[dim]Cost: ${cost:.4f}[/dim]\n") - - chat_log.write("\n") - - self.state_manager.add_conversation_message( - ConversationMessage(role="user", content=query) - ) - self.state_manager.add_conversation_message( - ConversationMessage(role="assistant", content=result.answer, - cost=getattr(result, 'total_cost', 0.0) or 0.0) - ) - + self._display_query_metadata(chat_log, result) + self.state_manager.add_conversation_message(ConversationMessage(role="user", content=query)) + self.state_manager.add_conversation_message(ConversationMessage(role="assistant", content=result.answer, cost=getattr(result, 'total_cost', 0.0) or 0.0)) 
self.state_manager.limit_conversation_history() self.state_manager.save_current_conversation() - - status_label = self.app.query_one("#status-label") - status_label.update(self.state_manager.get_status_text()) - + self.app.query_one("#status-label").update(self.state_manager.get_status_text()) except Exception as e: chat_log.write(f"[red bold]●[/red bold] Error: {e}\n\n") finally: self.state_manager.set_processing(False) + def _display_query_metadata(self, chat_log: ChatArea, result: Any) -> None: + """Displays metadata from the query result in the chat log.""" + if pages_by_doc := getattr(result, 'get_pages_by_document', lambda: None)(): + chat_log.write("[dim]Analyzed documents:[/dim]\n") + for doc_name, page_nums in pages_by_doc.items(): + chat_log.write(f"[dim] • {doc_name}: Pages {', '.join(map(str, page_nums))}[/dim]\n") + elif page_numbers := getattr(result, 'page_numbers', None): + chat_log.write(f"[dim]Analyzed pages: {page_numbers}[/dim]\n") + + if (proc_time := getattr(result, 'processing_time', 0.0) or 0.0) > 0: + chat_log.write(f"[dim]Processing time: {proc_time:.2f}s[/dim]\n") + + cost = getattr(result, 'total_cost', 0.0) or 0.0 + chat_log.write(f"[dim]Cost: ${format(cost, '.6f' if cost < 0.01 else '.4f')}[/dim]\n\n") + def delete_document_sync(self, document_id: str) -> bool: - if self.docpixie: - try: - return self.docpixie.delete_document_sync(document_id) - except Exception: - return False - return False + """ + Synchronously deletes a document. + + Args: + document_id: The ID of the document to delete. + + Returns: + `True` if deletion was successful, `False` otherwise. + """ + return self.docpixie.delete_document_sync(document_id) if self.docpixie else False diff --git a/docpixie/cli/event_handlers.py b/docpixie/cli/event_handlers.py index e7201f8..3d6f6d4 100644 --- a/docpixie/cli/event_handlers.py +++ b/docpixie/cli/event_handlers.py @@ -1,5 +1,9 @@ """ -Event handling mixins for DocPixie CLI +Event handling mixins for the DocPixie CLI.
+ +This module defines several mixin classes that encapsulate the logic for +handling different types of events within the Textual application, such as +command, conversation, model, and document-related events. """ from typing import TYPE_CHECKING @@ -17,186 +21,127 @@ class CommandEventMixin: - """Handles command palette and text input events""" + """A mixin for handling command palette and text input events.""" async def on_text_area_changed(self: 'DocPixieTUI', event: TextArea.Changed) -> None: - """Handle text area changes for command palette""" + """Handles changes in the text area to show/hide the command palette.""" if event.text_area.id != "chat-input": return - - lines = event.text_area.text.split('\\n') - if lines: - current_line = lines[-1] if lines else "" - - if current_line.startswith("/"): - command_palette = self.query_one("#command-palette", CommandPalette) - if not self.state_manager.command_palette_active: - self.state_manager.command_palette_active = True - command_palette.show(current_line) - else: - command_palette.update_filter(current_line) + current_line = event.text_area.text.split('\n')[-1] + command_palette = self.query_one("#command-palette", CommandPalette) + if current_line.startswith("/"): + if not self.state_manager.command_palette_active: + self.state_manager.command_palette_active = True + command_palette.show(current_line) else: - if self.state_manager.command_palette_active: - command_palette = self.query_one("#command-palette", CommandPalette) - command_palette.hide() - self.state_manager.command_palette_active = False + command_palette.update_filter(current_line) + elif self.state_manager.command_palette_active: + command_palette.hide() + self.state_manager.command_palette_active = False async def on_key(self: 'DocPixieTUI', event: events.Key) -> None: - """Handle key events for command palette navigation""" + """Handles key events for command palette navigation.""" if self.state_manager.command_palette_active: 
command_palette = self.query_one("#command-palette", CommandPalette) - - if event.key == "escape": - command_palette.hide() - self.state_manager.command_palette_active = False - text_area = self.query_one("#chat-input") - text_area.clear() - event.prevent_default() - - elif event.key == "up": - command_palette.move_selection_up() - event.prevent_default() - - elif event.key == "down": - command_palette.move_selection_down() - event.prevent_default() - - elif event.key == "tab": - selected = command_palette.get_selected_command() - if selected: - text_area = self.query_one("#chat-input") - text_area.text = selected.command - text_area.cursor_location = (0, len(selected.command)) + if event.key in ("up", "down", "escape", "tab"): event.prevent_default() + if event.key == "up": + command_palette.move_selection_up() + elif event.key == "down": + command_palette.move_selection_down() + elif event.key == "escape": + command_palette.hide() + self.state_manager.command_palette_active = False + self.query_one("#chat-input").clear() + elif event.key == "tab": + if selected := command_palette.get_selected_command(): + text_area = self.query_one("#chat-input") + text_area.text = selected.command + text_area.cursor_location = (0, len(selected.command)) async def on_command_selected(self: 'DocPixieTUI', event: CommandSelected) -> None: - """Handle command selection from palette""" - command_palette = self.query_one("#command-palette", CommandPalette) - command_palette.hide() + """Handles a command selection from the palette.""" + self.query_one("#command-palette", CommandPalette).hide() self.state_manager.command_palette_active = False - - text_area = self.query_one("#chat-input") - text_area.clear() - + self.query_one("#chat-input").clear() await self.handle_command(event.command) async def on_command_auto_complete(self: 'DocPixieTUI', event: CommandAutoComplete) -> None: - """Handle command auto-completion""" + """Handles a command auto-completion event.""" text_area = 
self.query_one("#chat-input") text_area.text = event.command text_area.cursor_location = (0, len(event.command)) class ConversationEventMixin: - """Handles conversation-related events""" + """A mixin for handling conversation-related events.""" async def on_conversation_selected(self: 'DocPixieTUI', event: ConversationSelected) -> None: - """Handle conversation selection from dialog""" - chat_log = self.query_one("#chat-log", ChatArea) - + """Handles a conversation selection from the dialog.""" if event.conversation_id == "new": await self.handle_command("/new") return - + chat_log = self.query_one("#chat-log", ChatArea) try: self.state_manager.save_current_conversation() - if self.state_manager.load_conversation(event.conversation_id): - conversations = self.state_manager.conversation_storage.list_local_conversations() - metadata = next( - (conv for conv in conversations if conv.id == event.conversation_id), - None - ) - chat_log.clear() - for msg in self.state_manager.conversation_history: - if msg.role == "user": - chat_log.add_user_message(msg.content) - else: - chat_log.add_assistant_message(msg.content) - - status_label = self.query_one("#status-label", Label) - status_label.update(self.state_manager.get_status_text()) - - conv_name = metadata.name if metadata else "Unknown" + (chat_log.add_user_message if msg.role == "user" else chat_log.add_assistant_message)(msg.content) + self.query_one("#status-label", Label).update(self.state_manager.get_status_text()) + conv_name = next((c.name for c in self.state_manager.conversation_storage.list_local_conversations() if c.id == event.conversation_id), "Unknown") chat_log.write(f"[green bold]●[/green bold] Loaded conversation: {conv_name}\n\n") else: chat_log.write("[red bold]●[/red bold] Failed to load conversation\n\n") - except Exception as e: chat_log.write(f"[red bold]●[/red bold] Error loading conversation: {e}\n\n") async def on_conversation_deleted(self: 'DocPixieTUI', event: ConversationDeleted) -> None: - 
"""Handle conversation deletion""" - chat_log = self.query_one("#chat-log", ChatArea) - chat_log.write("[green bold]●[/green bold] Conversation deleted\n\n") + """Handles a conversation deletion event.""" + self.query_one("#chat-log", ChatArea).write("[green bold]●[/green bold] Conversation deleted\n\n") class ModelEventMixin: - """Handles model selection events""" + """A mixin for handling model selection events.""" async def on_model_selected(self: 'DocPixieTUI', event: ModelSelected) -> None: - """Handle model selection""" + """Handles a model selection event.""" chat_log = self.query_one("#chat-log", ChatArea) - + model_changed = False if event.old_text_model and event.text_model != event.old_text_model: chat_log.write(f"[green bold]●[/green bold] Action model switched to {event.text_model}\n\n") - await self.docpixie_manager.switch_models() - elif event.old_vision_model and event.vision_model != event.old_vision_model: + model_changed = True + if event.old_vision_model and event.vision_model != event.old_vision_model: chat_log.write(f"[green bold]●[/green bold] Vision model switched to {event.vision_model}\n\n") + model_changed = True + if model_changed: await self.docpixie_manager.switch_models() else: chat_log.write("[dim]No model changes made[/dim]\n\n") - - status_label = self.query_one("#status-label", Label) - status_label.update(self.state_manager.get_status_text()) + self.query_one("#status-label", Label).update(self.state_manager.get_status_text()) class DocumentEventMixin: - """Handles document management events""" + """A mixin for handling document management events.""" async def on_document_removed(self: 'DocPixieTUI', event: DocumentRemoved) -> None: - """Handle document removal""" + """Handles a document removal event.""" chat_log = self.query_one("#chat-log", ChatArea) - - removed_count = 0 + removed_count = sum(1 for doc_id in event.document_ids if self.state_manager.remove_document(doc_id)) for doc_id in event.document_ids: - if 
self.state_manager.remove_document(doc_id): - removed_count += 1 - - if self.docpixie: - try: - success = self.docpixie_manager.delete_document_sync(doc_id) - if not success: - doc_name = f"Document {doc_id}" # Fallback name - chat_log.write(f"[warning]Warning: Could not delete {doc_name} from storage[/warning]\n") - except Exception as e: - doc_name = f"Document {doc_id}" # Fallback name - chat_log.write(f"[error]Error deleting {doc_name}: {e}[/error]\n") - - if removed_count == 1: - chat_log.write(f"[green bold]●[/green bold] Removed 1 document from index\n\n") - else: - chat_log.write(f"[green bold]●[/green bold] Removed {removed_count} documents from index\n\n") - - status_label = self.query_one("#status-label", Label) - status_label.update(self.state_manager.get_status_text()) + try: + if self.docpixie and not self.docpixie_manager.delete_document_sync(doc_id): + chat_log.write(f"[warning]Warning: Could not delete Document {doc_id} from storage[/warning]\n") + except Exception as e: + chat_log.write(f"[error]Error deleting Document {doc_id}: {e}[/error]\n") + chat_log.write(f"[green bold]●[/green bold] Removed {removed_count} document(s) from index\n\n") + self.query_one("#status-label", Label).update(self.state_manager.get_status_text()) async def on_documents_indexed(self: 'DocPixieTUI', event: DocumentsIndexed) -> None: - """Handle documents being indexed""" + """Handles a documents indexed event.""" chat_log = self.query_one("#chat-log", ChatArea) - - indexed_count = 0 + indexed_count = sum(1 for doc in event.documents if not any(d.id == doc.id for d in self.state_manager.indexed_documents)) for doc in event.documents: - if not any(existing.id == doc.id for existing in self.state_manager.indexed_documents): - self.state_manager.add_document(doc) - indexed_count += 1 - - if indexed_count == 1: - chat_log.write(f"[green bold]●[/green bold] Successfully indexed 1 document\n\n") - else: - chat_log.write(f"[green bold]●[/green bold] Successfully indexed 
{indexed_count} documents\n\n") - - status_label = self.query_one("#status-label", Label) - status_label.update(self.state_manager.get_status_text()) + self.state_manager.add_document(doc) + chat_log.write(f"[green bold]●[/green bold] Successfully indexed {indexed_count} document(s)\n\n") + self.query_one("#status-label", Label).update(self.state_manager.get_status_text()) diff --git a/docpixie/cli/state_manager.py b/docpixie/cli/state_manager.py index 4d8c117..bdfc6f9 100644 --- a/docpixie/cli/state_manager.py +++ b/docpixie/cli/state_manager.py @@ -1,5 +1,9 @@ """ -State management for DocPixie CLI application +State management for the DocPixie CLI application. + +This module provides the `AppStateManager` class, which centralizes the +management of the application's state, including conversations, documents, +and UI-related states. """ from pathlib import Path @@ -11,172 +15,166 @@ class AppStateManager: - """Manages application state including conversations, documents, and UI state""" - + """ + Manages the application state, including conversations, documents, and UI state. + + This class serves as a single source of truth for the application's state, + making it easier to manage and reason about the data and UI. + + Attributes: + indexed_documents: A list of currently indexed documents. + conversation_history: A list of messages in the current conversation. + current_conversation_id: The ID of the active conversation. + documents_folder: The path to the folder where documents are stored. + processing: A boolean indicating if the app is currently processing a query. + command_palette_active: A boolean for the command palette's visibility. + partial_command: The partial command string for command palette filtering. + default_input_hint: The default hint text for the input area. + current_plan: The current task plan from the RAG agent. + completed_tasks: A set of completed task names. + config_manager: An instance of the `ConfigManager`. 
+ conversation_storage: An instance of the `ConversationStorage`. + """ + def __init__(self): + """Initializes the `AppStateManager`.""" self.indexed_documents: List[Document] = [] self.conversation_history: List[ConversationMessage] = [] self.current_conversation_id: Optional[str] = None self.documents_folder = Path("./documents") self.processing = False - self.command_palette_active = False self.partial_command = "" - self.default_input_hint = ( - "Press / for commands • Shift+Enter: new line • Shift+Tab: switch panel" - ) - + self.default_input_hint = "Press / for commands • Shift+Enter: new line • Shift+Tab: switch panel" self.current_plan: Optional[Any] = None - self.completed_tasks: Set = set() - + self.completed_tasks: Set[str] = set() self.config_manager = get_config_manager() self.conversation_storage = ConversationStorage() - + def get_status_text(self) -> str: - """Get current status bar text with emoji prefixes""" + """Generates the text for the status bar.""" text_model, vision_model = self.config_manager.get_models() - doc_count = len(self.indexed_documents) - segments = [ - f"📄: {doc_count}", + f"📄: {len(self.indexed_documents)}", f"🧠: {text_model.split('/')[-1]}", f"👁️: {vision_model.split('/')[-1]}", ] - if self.current_conversation_id: - conversations = self.conversation_storage.list_local_conversations() - current_conv = next( - (conv for conv in conversations if conv.id == self.current_conversation_id), - None, - ) - if current_conv: - # Conversation name (truncate to 20 chars, add ellipsis if longer) - conv_name = current_conv.name[:20] + ("..." 
if len(current_conv.name) > 20 else "") + conv = next((c for c in self.conversation_storage.list_local_conversations() if c.id == self.current_conversation_id), None) + if conv: + conv_name = (conv.name[:20] + "...") if len(conv.name) > 20 else conv.name segments.append(f"💬: {conv_name}") - - # Total cost formatting - total_cost = getattr(current_conv, "total_cost", 0.0) or 0.0 - if total_cost < 0.01: - segments.append(f"💰: {total_cost:.6f}") - else: - segments.append(f"💰: {total_cost:.4f}") - + cost = getattr(conv, "total_cost", 0.0) or 0.0 + segments.append(f"💰: {cost:.6f}" if cost < 0.01 else f"💰: {cost:.4f}") return " | ".join(segments) - + def add_document(self, document: Document) -> None: - """Add a document to the indexed documents list""" - if not any(existing.id == document.id for existing in self.indexed_documents): + """Adds a document to the list of indexed documents.""" + if not any(d.id == document.id for d in self.indexed_documents): self.indexed_documents.append(document) - + def remove_document(self, document_id: str) -> bool: - """Remove a document from the indexed documents list""" - for doc in self.indexed_documents[:]: - if doc.id == document_id: - self.indexed_documents.remove(doc) - return True - return False - + """Removes a document from the list of indexed documents.""" + initial_len = len(self.indexed_documents) + self.indexed_documents = [d for d in self.indexed_documents if d.id != document_id] + return len(self.indexed_documents) < initial_len + def clear_documents(self) -> None: - """Clear all indexed documents""" + """Clears all indexed documents.""" self.indexed_documents.clear() - + def add_conversation_message(self, message: ConversationMessage) -> None: - """Add a message to conversation history""" + """Adds a message to the current conversation history.""" self.conversation_history.append(message) - + def limit_conversation_history(self, max_messages: int = 20) -> None: - """Limit conversation history to maximum number of 
messages""" + """Limits the conversation history to a maximum number of messages.""" if len(self.conversation_history) > max_messages: self.conversation_history = self.conversation_history[-max_messages:] - + def clear_conversation_history(self) -> None: - """Clear conversation history""" - self.conversation_history = [] - + """Clears the current conversation history.""" + self.conversation_history.clear() + def set_current_conversation(self, conversation_id: Optional[str]) -> None: - """Set the current conversation ID""" + """Sets the ID of the current conversation.""" self.current_conversation_id = conversation_id - + def create_new_conversation(self) -> str: - """Create a new conversation and return its ID""" + """Creates a new conversation and returns its ID.""" doc_ids = [doc.id for doc in self.indexed_documents] self.current_conversation_id = self.conversation_storage.create_new_conversation(doc_ids) self.conversation_history = [] return self.current_conversation_id - + def load_conversation(self, conversation_id: str) -> bool: - """Load a conversation by ID""" - result = self.conversation_storage.load_conversation(conversation_id) - if result: - metadata, messages = result + """Loads a conversation by its ID.""" + if result := self.conversation_storage.load_conversation(conversation_id): + _, self.conversation_history = result self.current_conversation_id = conversation_id - self.conversation_history = messages return True return False - + def save_current_conversation(self) -> None: - """Save the current conversation if it exists""" + """Saves the current conversation to storage.""" if self.current_conversation_id and self.conversation_history: - doc_ids = [doc.id for doc in self.indexed_documents] self.conversation_storage.save_conversation( - self.current_conversation_id, - self.conversation_history, - doc_ids + self.current_conversation_id, self.conversation_history, [d.id for d in self.indexed_documents] ) - + def get_last_conversation_id(self) -> 
Optional[str]: - """Get the ID of the last conversation""" + """Gets the ID of the last active conversation.""" return self.conversation_storage.get_last_conversation() - + def set_processing(self, processing: bool) -> None: - """Set processing state""" + """Sets the application's processing state.""" self.processing = processing - + def is_processing(self) -> bool: - """Check if currently processing""" + """Checks if the application is currently processing.""" return self.processing - + def set_command_palette_active(self, active: bool) -> None: - """Set command palette active state""" + """Sets the active state of the command palette.""" self.command_palette_active = active - + def is_command_palette_active(self) -> bool: - """Check if command palette is active""" + """Checks if the command palette is active.""" return self.command_palette_active - + def set_partial_command(self, command: str) -> None: - """Set partial command text""" + """Sets the partial command text for filtering.""" self.partial_command = command - + def get_partial_command(self) -> str: - """Get partial command text""" + """Gets the partial command text.""" return self.partial_command - + def set_current_plan(self, plan: Optional[Any]) -> None: - """Set current task plan""" + """Sets the current task plan from the agent.""" self.current_plan = plan - + def get_current_plan(self) -> Optional[Any]: - """Get current task plan""" + """Gets the current task plan.""" return self.current_plan - + def clear_task_plan(self) -> None: - """Clear current task plan and completed tasks""" + """Clears the current task plan and completed tasks.""" self.current_plan = None self.completed_tasks.clear() - + def add_completed_task(self, task_name: str) -> None: - """Mark a task as completed""" + """Adds a task name to the set of completed tasks.""" self.completed_tasks.add(task_name) - + def get_completed_tasks(self) -> List[str]: - """Get list of completed task names""" + """Gets a list of completed task 
names.""" return list(self.completed_tasks) - + def has_documents(self) -> bool: - """Check if any documents are indexed""" - return len(self.indexed_documents) > 0 - + """Checks if there are any indexed documents.""" + return bool(self.indexed_documents) + def has_conversation_history(self) -> bool: - """Check if conversation history exists""" - return len(self.conversation_history) > 0 + """Checks if there is any conversation history.""" + return bool(self.conversation_history) diff --git a/docpixie/cli/task_display.py b/docpixie/cli/task_display.py index cec8090..34ca01d 100644 --- a/docpixie/cli/task_display.py +++ b/docpixie/cli/task_display.py @@ -1,5 +1,9 @@ """ -Task display management for DocPixie CLI +Task display management for the DocPixie CLI. + +This module provides the `TaskDisplayManager` class, which is responsible for +rendering updates about the RAG agent's task plan and progress in the chat +interface. """ from typing import TYPE_CHECKING, Any @@ -11,70 +15,102 @@ class TaskDisplayManager: - """Manages task plan and progress display in the chat interface""" - + """ + Manages the display of the agent's task plan and progress. + + This class receives task-related events from the `DocPixieManager` and + translates them into visual updates in the `ChatArea` widget. + + Attributes: + app: The main `DocPixieTUI` application instance. + state_manager: The application's state manager. + """ + def __init__(self, app: 'DocPixieTUI', state_manager: AppStateManager): + """ + Initializes the `TaskDisplayManager`. + + Args: + app: The main `DocPixieTUI` application instance. + state_manager: The application's state manager. + """ self.app = app self.state_manager = state_manager - + def display_task_update(self, event_type: str, data: Any) -> None: - """Display task plan updates""" + """ + Displays updates related to the agent's task plan and execution. + + Args: + event_type: The type of the task-related event (e.g., + 'plan_created', 'task_started'). 
+ data: The data associated with the event. + """ chat_log = self.app.query_one("#chat-log", ChatArea) - - if event_type == 'plan_created': - plan = data - self.state_manager.current_plan = plan - self.state_manager.completed_tasks.clear() - chat_log.hide_processing_status(mark_done=True, final_text="Planning") - chat_log.show_plan(plan) - - elif event_type == 'plan_updated': - plan = data - self.state_manager.current_plan = plan - chat_log.show_plan(plan, is_update=True, completed_tasks=list(self.state_manager.completed_tasks)) - - elif event_type == 'task_started': - task = data['task'] - task_name = task.name if hasattr(task, 'name') else str(task) - - doc_name = self._get_document_name_for_task(task) - chat_log.show_task_progress(task_name, None, doc_name) - - elif event_type == 'pages_selected': - task = data['task'] - page_numbers = data.get('page_numbers', []) - task_name = task.name if hasattr(task, 'name') else str(task) - - doc_name = self._get_document_name_for_task(task) - pages_count = len(page_numbers) if isinstance(page_numbers, (list, tuple)) else 0 - chat_log.show_task_progress(task_name, pages_count, doc_name) - - elif event_type == 'task_completed': - task = data['task'] - task_name = task.name if hasattr(task, 'name') else str(task) - - chat_log.update_task_status(task_name, done=True) - self.state_manager.completed_tasks.add(task_name) - - if self.state_manager.current_plan: - chat_log.show_plan( - self.state_manager.current_plan, - is_update=True, - completed_tasks=list(self.state_manager.completed_tasks) - ) - - def _get_document_name_for_task(self, task) -> str: - """Extract document name from task, with fallback to 'document'""" - doc_name = 'document' + event_handlers = { + 'plan_created': self._handle_plan_created, + 'plan_updated': self._handle_plan_updated, + 'task_started': self._handle_task_started, + 'pages_selected': self._handle_pages_selected, + 'task_completed': self._handle_task_completed, + } + if handler := 
event_handlers.get(event_type): + handler(chat_log, data) + + def _handle_plan_created(self, chat_log: ChatArea, data: Any): + """Handles the 'plan_created' event.""" + self.state_manager.current_plan = data + self.state_manager.completed_tasks.clear() + chat_log.hide_processing_status(mark_done=True, final_text="Planning") + chat_log.show_plan(data) + + def _handle_plan_updated(self, chat_log: ChatArea, data: Any): + """Handles the 'plan_updated' event.""" + self.state_manager.current_plan = data + chat_log.show_plan(data, is_update=True, completed_tasks=list(self.state_manager.completed_tasks)) + + def _handle_task_started(self, chat_log: ChatArea, data: Any): + """Handles the 'task_started' event.""" + task = data['task'] + task_name = getattr(task, 'name', str(task)) + doc_name = self._get_document_name_for_task(task) + chat_log.show_task_progress(task_name, None, doc_name) + + def _handle_pages_selected(self, chat_log: ChatArea, data: Any): + """Handles the 'pages_selected' event.""" + task = data['task'] + page_numbers = data.get('page_numbers', []) + task_name = getattr(task, 'name', str(task)) + doc_name = self._get_document_name_for_task(task) + chat_log.show_task_progress(task_name, len(page_numbers) if isinstance(page_numbers, (list, tuple)) else 0, doc_name) + + def _handle_task_completed(self, chat_log: ChatArea, data: Any): + """Handles the 'task_completed' event.""" + task = data['task'] + task_name = getattr(task, 'name', str(task)) + chat_log.update_task_status(task_name, done=True) + self.state_manager.completed_tasks.add(task_name) + if self.state_manager.current_plan: + chat_log.show_plan( + self.state_manager.current_plan, is_update=True, + completed_tasks=list(self.state_manager.completed_tasks) + ) + + def _get_document_name_for_task(self, task: Any) -> str: + """ + Extracts the document name for a given task. + + Args: + task: The task object. + + Returns: + The name of the document associated with the task, or a default + string if not found. 
+ """ try: - task_doc_id = getattr(task, 'document', '') - if task_doc_id: - doc = next( - (d for d in self.state_manager.indexed_documents if d.id == task_doc_id), - None - ) - if doc and getattr(doc, 'name', None): - doc_name = doc.name + if task_doc_id := getattr(task, 'document', ''): + if doc := next((d for d in self.state_manager.indexed_documents if d.id == task_doc_id), None): + return getattr(doc, 'name', None) or 'document' except Exception: pass - return doc_name \ No newline at end of file + return 'document' \ No newline at end of file diff --git a/docpixie/core/config.py b/docpixie/core/config.py index f54c829..cb4975a 100644 --- a/docpixie/core/config.py +++ b/docpixie/core/config.py @@ -1,6 +1,9 @@ """ -DocPixie Configuration -Simplified version of production config without embedding/vector DB settings +Configuration for the DocPixie application. + +This module defines the `DocPixieConfig` dataclass, which centralizes all +the configuration options for the application. It includes settings for +document processing, storage, AI providers, and the RAG agent. """ import os @@ -11,103 +14,111 @@ @dataclass class DocPixieConfig: - """DocPixie configuration with sensible defaults""" - - # Document Processing with PyMuPDF - pdf_render_scale: float = 2.0 # Higher scale = better quality, larger files + """ + A dataclass for storing DocPixie configuration, with sensible defaults. + + This class provides a centralized way to manage all configuration settings. + It can be instantiated with default values, and can also load settings + from environment variables or a dictionary. + + Attributes: + pdf_render_scale: The scale factor for rendering PDF pages. Higher + values result in better quality but larger files. + pdf_max_image_size: The maximum dimensions for rendered PDF page images. + jpeg_quality: The quality setting for JPEG images (1-100). + thumbnail_size: The dimensions for generated page thumbnails. 
+ vision_detail: The level of detail for vision model analysis. + storage_type: The type of storage backend to use ('local', 'memory'). + local_storage_path: The file path for the local storage backend. + provider: The AI provider to use ('openai', 'anthropic', 'openrouter'). + model: The primary language model for all operations. + vision_model: The vision-capable model for multimodal analysis. + openai_api_key: The API key for OpenAI. + anthropic_api_key: The API key for Anthropic. + openrouter_api_key: The API key for OpenRouter. + max_agent_iterations: The maximum number of adaptive planning + iterations for the agent. + max_pages_per_task: The maximum number of pages to analyze per task. + max_tasks_per_plan: The maximum number of tasks in the initial plan. + max_conversation_turns: The number of conversation turns before + summarization is triggered. + turns_to_summarize: The number of initial turns to include in a summary. + turns_to_keep_full: The number of recent turns to keep in full. + log_level: The logging level for the application. + log_requests: Whether to log API requests. 
+ """ + + # Document Processing + pdf_render_scale: float = 2.0 pdf_max_image_size: Tuple[int, int] = (1200, 1200) jpeg_quality: int = 90 - thumbnail_size: Tuple[int, int] = (256, 256) # For quick page selection + thumbnail_size: Tuple[int, int] = (256, 256) # Processing settings - vision_detail: str = "high" # Use full resolution for best quality + vision_detail: str = "high" # Storage - storage_type: str = "local" # local, memory, s3 + storage_type: str = "local" local_storage_path: str = "./docpixie_data" - # AI Provider Settings (Provider-agnostic) - provider: str = "openai" # openai, anthropic, openrouter - model: str = "gpt-4o" # Primary model for all operations - vision_model: str = "gpt-4o" # Vision model for multimodal analysis + # AI Provider Settings + provider: str = "openai" + model: str = "gpt-4o" + vision_model: str = "gpt-4o" - # API keys loaded from environment variables only + # API keys openai_api_key: Optional[str] = None anthropic_api_key: Optional[str] = None openrouter_api_key: Optional[str] = None # Agent Settings - max_agent_iterations: int = 5 # Maximum adaptive planning iterations - max_pages_per_task: int = 6 # Maximum pages to analyze per task - max_tasks_per_plan: int = 4 # Maximum tasks in initial plan + max_agent_iterations: int = 5 + max_pages_per_task: int = 6 + max_tasks_per_plan: int = 4 - # Conversation Processing Settings - max_conversation_turns: int = 8 # When to start summarizing conversation - turns_to_summarize: int = 5 # How many turns to summarize - turns_to_keep_full: int = 3 # How many recent turns to keep in full + # Conversation Processing + max_conversation_turns: int = 8 + turns_to_summarize: int = 5 + turns_to_keep_full: int = 3 # Logging log_level: str = "INFO" log_requests: bool = False def __post_init__(self): - """Initialize and validate configuration""" - # Create storage directory if it doesn't exist + """Initializes and validates the configuration after instantiation.""" if self.storage_type == "local": 
Path(self.local_storage_path).mkdir(parents=True, exist_ok=True) - # Load API keys from environment if not provided - if not self.openai_api_key: - self.openai_api_key = os.getenv("OPENAI_API_KEY") - - if not self.anthropic_api_key: - self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") - - if not self.openrouter_api_key: - self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY") + self.openai_api_key = self.openai_api_key or os.getenv("OPENAI_API_KEY") + self.anthropic_api_key = self.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY") + self.openrouter_api_key = self.openrouter_api_key or os.getenv("OPENROUTER_API_KEY") - # Set provider-specific default models if using defaults self._set_provider_defaults() - - # Skip validation with test API keys (for testing) - if self.openai_api_key != "test-key" and self.anthropic_api_key != "test-key" and self.openrouter_api_key != "test-key": - # Validate required settings based on provider - if self.provider == "openai" and not self.openai_api_key: - raise ValueError("OpenAI API key is required when using OpenAI provider") - - if self.provider == "anthropic" and not self.anthropic_api_key: - raise ValueError("Anthropic API key is required when using Anthropic provider") - - if self.provider == "openrouter" and not self.openrouter_api_key: - raise ValueError("OpenRouter API key is required when using OpenRouter provider") - - # Validate image settings + self._validate() + + def _validate(self): + """Performs validation of the configuration settings.""" + if self.provider == "openai" and not self.openai_api_key: + raise ValueError("OpenAI API key is required for the OpenAI provider.") + if self.provider == "anthropic" and not self.anthropic_api_key: + raise ValueError("Anthropic API key is required for the Anthropic provider.") + if self.provider == "openrouter" and not self.openrouter_api_key: + raise ValueError("OpenRouter API key is required for the OpenRouter provider.") if self.pdf_render_scale <= 0: - raise 
ValueError("PDF render scale must be positive") - - if self.jpeg_quality < 1 or self.jpeg_quality > 100: - raise ValueError("JPEG quality must be between 1 and 100") + raise ValueError("PDF render scale must be a positive number.") + if not 1 <= self.jpeg_quality <= 100: + raise ValueError("JPEG quality must be between 1 and 100.") def _set_provider_defaults(self): - """Set appropriate default models based on provider""" + """Sets default models based on the selected provider, if not already set.""" provider_defaults = { - "openai": { - "model": "gpt-4o", - "vision_model": "gpt-4o" - }, - "anthropic": { - "model": "claude-3-opus-20240229", - "vision_model": "claude-3-opus-20240229" - }, - "openrouter": { - "model": "openai/gpt-4o", - "vision_model": "openai/gpt-4o" - } + "openai": {"model": "gpt-4o", "vision_model": "gpt-4o"}, + "anthropic": {"model": "claude-3-opus-20240229", "vision_model": "claude-3-opus-20240229"}, + "openrouter": {"model": "openai/gpt-4o", "vision_model": "openai/gpt-4o"} } - if self.provider in provider_defaults: defaults = provider_defaults[self.provider] - # Only update if still using OpenAI defaults (means user didn't specify custom models) if self.model == "gpt-4o": self.model = defaults["model"] if self.vision_model == "gpt-4o": @@ -115,52 +126,32 @@ def _set_provider_defaults(self): @classmethod def from_dict(cls, config_dict: Dict[str, Any]) -> 'DocPixieConfig': - """Create config from dictionary""" + """Creates a `DocPixieConfig` instance from a dictionary.""" return cls(**config_dict) @classmethod def from_env(cls) -> 'DocPixieConfig': - """Create config from environment variables""" + """Creates a `DocPixieConfig` instance from environment variables.""" config_dict = {} - - # Map environment variables to config fields env_mapping = { 'DOCPIXIE_PROVIDER': 'provider', 'DOCPIXIE_STORAGE_PATH': 'local_storage_path', - 'DOCPIXIE_JPEG_QUALITY': 'jpeg_quality', + 'DOCPIXIE_JPEG_QUALITY': ('jpeg_quality', int), 'DOCPIXIE_LOG_LEVEL': 
'log_level', } - - for env_var, config_field in env_mapping.items(): - value = os.getenv(env_var) - if value is not None: - # Convert string values to appropriate types - if config_field in ['jpeg_quality']: - config_dict[config_field] = int(value) - elif config_field in ['enable_cache']: - config_dict[config_field] = value.lower() in ('true', '1', 'yes') + for env_var, field_info in env_mapping.items(): + if (value := os.getenv(env_var)) is not None: + if isinstance(field_info, tuple): + field, type_cast = field_info + config_dict[field] = type_cast(value) else: - config_dict[config_field] = value - + config_dict[field_info] = value return cls(**config_dict) def get_query_config(self) -> Dict[str, Any]: - """Get configuration for query processing""" - return { - 'vision_detail': self.vision_detail, - 'model': self.model - } + """Returns a dictionary of settings relevant to query processing.""" + return {'vision_detail': self.vision_detail, 'model': self.model} def validate_provider_config(self) -> None: - """Validate provider-specific configuration""" - if self.provider == "openai": - if not self.openai_api_key: - raise ValueError("OpenAI API key is required") - elif self.provider == "anthropic": - if not self.anthropic_api_key: - raise ValueError("Anthropic API key is required") - elif self.provider == "openrouter": - if not self.openrouter_api_key: - raise ValueError("OpenRouter API key is required") - else: - raise ValueError(f"Unsupported provider: {self.provider}") + """Validates that the required configuration for the selected provider is present.""" + self._validate() diff --git a/docpixie/core/utils.py b/docpixie/core/utils.py index 99baba6..f9954a5 100644 --- a/docpixie/core/utils.py +++ b/docpixie/core/utils.py @@ -1,35 +1,33 @@ """ -Core utility functions for DocPixie +Core utility functions for the DocPixie application. + +This module provides common utility functions that are used across different +parts of the application. 
""" import re def sanitize_llm_json(response: str) -> str: """ - Sanitize JSON response from LLM by removing markdown code blocks and extra whitespace. - - LLMs sometimes wrap JSON responses with markdown code blocks like: - ```json - {"key": "value"} - ``` - - This function strips those wrappers and returns clean JSON. - + Sanitizes a JSON response from a language model. + + This function cleans up a string that is expected to be a JSON object by + removing common artifacts from language model outputs, such as markdown + code blocks (e.g., ```json ... ```) and extra whitespace. + Args: - response: Raw response string from LLM - + response: The raw response string from the language model. + Returns: - Sanitized JSON string ready for json.loads() + A sanitized JSON string that is ready to be parsed. """ - # Strip leading/trailing whitespace cleaned = response.strip() - + # Remove markdown code block wrappers - # Matches ```json...``` or ```...``` patterns code_block_pattern = r'^```(?:json)?\s*\n?(.*?)\n?```$' match = re.match(code_block_pattern, cleaned, re.DOTALL | re.IGNORECASE) - + if match: cleaned = match.group(1).strip() - + return cleaned \ No newline at end of file diff --git a/docpixie/docpixie.py b/docpixie/docpixie.py index 435dac3..08cfec3 100644 --- a/docpixie/docpixie.py +++ b/docpixie/docpixie.py @@ -1,6 +1,9 @@ """ -Main DocPixie API class -Simplified multimodal RAG without embeddings or vector databases +The main DocPixie API class. + +This module provides the primary interface for interacting with the DocPixie +system. It offers a simplified, multimodal RAG (Retrieval-Augmented Generation) +solution that does not require embeddings or vector databases. 
""" import asyncio @@ -9,7 +12,7 @@ import logging from .models.document import ( - Document, Page, QueryResult, QueryMode, + Document, Page, QueryResult, QueryMode, DocumentProcessRequest, QueryRequest, DocumentStatus ) from .models.agent import ConversationMessage @@ -21,19 +24,25 @@ from .ai.summarizer import PageSummarizer from .ai.agent import PixieRAGAgent from .providers import create_provider -from .utils.async_helpers import sync_wrapper, make_sync_version +from .utils.async_helpers import sync_wrapper logger = logging.getLogger(__name__) class DocPixie: """ - Main DocPixie API class for multimodal RAG - - Provides both Flash (quick) and Pro (comprehensive) modes - without requiring vector databases or embeddings. + The main DocPixie API class for multimodal RAG. + + This class provides a high-level interface for adding, managing, and + querying documents. It integrates all the components of the DocPixie + system, including document processing, storage, and the AI-powered RAG + agent. + + The class offers both asynchronous methods for use in async applications + and synchronous wrappers for easier adoption in traditional synchronous + code. """ - + def __init__( self, config: Optional[DocPixieConfig] = None, @@ -41,47 +50,39 @@ def __init__( api_key: Optional[str] = None ): """ - Initialize DocPixie - + Initializes the DocPixie instance. + Args: - config: Configuration object (uses defaults if None) - storage: Storage backend (uses local storage if None) - api_key: API key for AI provider (can also use env vars) + config: A `DocPixieConfig` object. If `None`, a default + configuration is used. + storage: A storage backend that inherits from `BaseStorage`. If + `None`, a backend is created based on the configuration. + api_key: The API key for the AI provider. This can also be set via + environment variables. 
""" - # Initialize configuration if config is None: config = DocPixieConfig() - - # Override API key if provided + if api_key: if config.provider == "openai": config.openai_api_key = api_key elif config.provider == "anthropic": config.anthropic_api_key = api_key - + self.config = config - - # Initialize components self.processor_factory = ProcessorFactory(config) - - # Initialize storage + if storage is None: - if config.storage_type == "memory": - self.storage = InMemoryStorage(config) - else: - self.storage = LocalStorage(config) + self.storage = InMemoryStorage(config) if config.storage_type == "memory" else LocalStorage(config) else: self.storage = storage - - # Initialize AI components + self.provider = create_provider(config) self.summarizer = PageSummarizer(config) self.agent = PixieRAGAgent(self.provider, self.storage, config) - + logger.info(f"Initialized DocPixie with {config.provider} provider and {type(self.storage).__name__} storage") - - # Document Management - + async def add_document( self, file_path: Union[str, Path], @@ -89,56 +90,88 @@ async def add_document( document_name: Optional[str] = None ) -> Document: """ - Add a document to the RAG system - + Adds a document to the DocPixie system. + + This method processes the document file, generates a summary, and saves + it to the configured storage backend. + Args: - file_path: Path to document file (PDF, image, etc.) - document_id: Optional custom document ID - document_name: Optional custom document name - + file_path: The path to the document file (e.g., PDF, image). + document_id: An optional custom ID for the document. + document_name: An optional custom name for the document. + Returns: - Processed Document object with summary + The processed `Document` object, including its summary. 
""" file_path = str(file_path) logger.info(f"Adding document: {file_path}") - - # Process document + processor = self.processor_factory.get_processor(file_path) document = await processor.process(file_path, document_id) - - # Override name if provided + if document_name: document.name = document_name - - # Always generate document summary + logger.info(f"Generating document summary for {document.name}") document = await self.summarizer.summarize_document(document) - - # Save to storage + document.status = DocumentStatus.COMPLETED await self.storage.save_document(document) - + logger.info(f"Successfully added document {document.id}: {document.name}") return document - + async def get_document(self, document_id: str) -> Optional[Document]: - """Get a document by ID""" + """ + Retrieves a document by its ID. + + Args: + document_id: The ID of the document to retrieve. + + Returns: + The `Document` object if found, otherwise `None`. + """ return await self.storage.get_document(document_id) - + async def list_documents(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """List all documents with metadata""" + """ + Lists all documents with their metadata. + + Args: + limit: The maximum number of documents to return. + + Returns: + A list of dictionaries, where each dictionary contains the + metadata of a document. + """ return await self.storage.list_documents(limit) - + async def delete_document(self, document_id: str) -> bool: - """Delete a document and its associated files""" + """ + Deletes a document and its associated files. + + Args: + document_id: The ID of the document to delete. + + Returns: + `True` if the document was successfully deleted, `False` otherwise. + """ return await self.storage.delete_document(document_id) - + async def search_documents(self, query: str, limit: int = 10) -> List[Dict[str, Any]]: - """Search documents by name and summary""" + """ + Searches for documents by their name and summary. + + Args: + query: The search query. 
+ limit: The maximum number of results to return. + + Returns: + A list of dictionaries containing the metadata of matching + documents. + """ return await self.storage.search_documents(query, limit) - - # Query Processing (Phase 2 - Vision-based RAG with Adaptive Agent) - async def query( self, question: str, @@ -147,46 +180,43 @@ async def query( max_pages: Optional[int] = None, stream: bool = False, conversation_history: Optional[List[ConversationMessage]] = None, - task_update_callback: Optional[Any] = None + task_update_callback: Optional[Callable] = None ) -> QueryResult: """ - Query documents with a question using adaptive vision-based RAG - + Queries the documents with a question using the adaptive RAG agent. + Args: - question: User's question - mode: Query mode (Flash, Pro, or Auto) - currently all use adaptive mode - document_ids: Specific documents to search (None = all) - currently searches all - max_pages: Maximum pages to analyze (uses config setting) - stream: Whether to stream the response (not implemented) - conversation_history: Previous conversation context - + question: The user's question. + mode: The query mode (Flash, Pro, or Auto); all modes currently use the adaptive agent. + document_ids: Specific document IDs to search. Not yet honored; + all documents are currently searched. + max_pages: The maximum number of pages to analyze (uses the config setting). + stream: Whether to stream the response (not yet implemented). + conversation_history: The history of the current conversation. + task_update_callback: An optional callback function to receive + updates on the agent's progress. + Returns: - QueryResult with answer and metadata + A `QueryResult` object containing the answer and metadata. 
""" logger.info(f"Processing query with adaptive RAG agent: {question}") - try: - # Use the adaptive RAG agent for processing agent_result = await self.agent.process_query(question, conversation_history, task_update_callback) - - # Convert AgentQueryResult to public API QueryResult format return QueryResult( query=agent_result.query, answer=agent_result.answer, selected_pages=agent_result.get_unique_pages(), - mode=mode, # Keep the requested mode for compatibility + mode=mode, confidence=self._calculate_confidence(agent_result), processing_time=agent_result.processing_time_seconds, - total_cost=agent_result.total_cost, # Include the cost + total_cost=agent_result.total_cost, metadata={ 'agent_iterations': agent_result.total_iterations, 'tasks_completed': len(agent_result.task_results), 'total_pages_analyzed': agent_result.get_total_pages_analyzed(), 'agent_mode': 'adaptive', - 'phase': 'Phase 2 - Adaptive Vision RAG' } ) - except Exception as e: logger.error(f"Query processing failed: {e}") return QueryResult( @@ -199,23 +229,16 @@ async def query( total_cost=0.0, metadata={'error': str(e)} ) - - def _calculate_confidence(self, agent_result) -> float: - """Calculate confidence score based on agent execution""" - # Simple confidence calculation based on successful completion + + def _calculate_confidence(self, agent_result: Any) -> float: + """Calculates a confidence score based on the agent's execution.""" if not agent_result.task_results: return 0.0 - - # Base confidence on task completion and page analysis - task_success_rate = len([r for r in agent_result.task_results - if r.analysis and not r.analysis.startswith("Task execution failed")]) / len(agent_result.task_results) - - # Boost confidence if we analyzed pages + successful_tasks = [r for r in agent_result.task_results if r.analysis and not r.analysis.startswith("Task execution failed")] + task_success_rate = len(successful_tasks) / len(agent_result.task_results) page_boost = min(0.2, 
agent_result.get_total_pages_analyzed() * 0.02) - - # Cap at 1.0 return min(1.0, 0.6 + (task_success_rate * 0.3) + page_boost) - + async def query_with_conversation( self, question: str, @@ -223,38 +246,52 @@ async def query_with_conversation( mode: QueryMode = QueryMode.AUTO ) -> QueryResult: """ - Convenience method for conversation-aware queries - + A convenience method for performing conversation-aware queries. + Args: - question: Current user question - conversation_history: Previous conversation messages - mode: Query mode - + question: The current user question. + conversation_history: The history of previous messages in the + conversation. + mode: The query mode. + Returns: - QueryResult with conversation context + A `QueryResult` object that takes the conversation context into + account. """ return await self.query( - question=question, - mode=mode, + question=question, + mode=mode, conversation_history=conversation_history ) - - # Convenience Methods - + def supports_file(self, file_path: str) -> bool: - """Check if file type is supported""" + """ + Checks if a given file type is supported. + + Args: + file_path: The path to the file. + + Returns: + `True` if the file type is supported, `False` otherwise. + """ return self.processor_factory.supports_file(file_path) - + def get_supported_extensions(self) -> Dict[str, str]: - """Get all supported file extensions""" + """ + Gets a dictionary of all supported file extensions. + + Returns: + A dictionary mapping file extensions to processor types. + """ return self.processor_factory.get_supported_extensions() - + def get_stats(self) -> Dict[str, Any]: - """Get system statistics""" - storage_stats = self.storage.get_storage_stats() - summarizer_stats = self.summarizer.get_summary_stats() - agent_stats = self.agent.get_agent_stats() - + """ + Retrieves statistics about the DocPixie system. + + Returns: + A dictionary containing system statistics. 
+ """ return { 'docpixie_version': '0.1.0', 'config': { @@ -263,105 +300,80 @@ def get_stats(self) -> Dict[str, Any]: 'max_agent_iterations': self.config.max_agent_iterations, 'max_pages_per_task': self.config.max_pages_per_task }, - 'storage': storage_stats, - 'summarizer': summarizer_stats, - 'agent': agent_stats, + 'storage': self.storage.get_storage_stats(), + 'summarizer': self.summarizer.get_summary_stats(), + 'agent': self.agent.get_agent_stats(), 'supported_extensions': list(self.get_supported_extensions().keys()), 'features': ['adaptive_rag', 'vision_page_selection', 'task_planning', 'conversation_aware'] } - - # Synchronous API for easier adoption - - def add_document_sync( - self, - file_path: Union[str, Path], - document_id: Optional[str] = None, - document_name: Optional[str] = None - ) -> Document: - """Synchronous version of add_document""" + + # Synchronous API Wrappers + def add_document_sync(self, file_path: Union[str, Path], document_id: Optional[str] = None, document_name: Optional[str] = None) -> Document: + """Synchronous version of `add_document`.""" return sync_wrapper(self.add_document(file_path, document_id, document_name)) - + def get_document_sync(self, document_id: str) -> Optional[Document]: - """Synchronous version of get_document""" + """Synchronous version of `get_document`.""" return sync_wrapper(self.get_document(document_id)) - + def list_documents_sync(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """Synchronous version of list_documents""" + """Synchronous version of `list_documents`.""" return sync_wrapper(self.list_documents(limit)) - + def delete_document_sync(self, document_id: str) -> bool: - """Synchronous version of delete_document""" + """Synchronous version of `delete_document`.""" return sync_wrapper(self.delete_document(document_id)) - - def query_sync( - self, - question: str, - mode: QueryMode = QueryMode.AUTO, - document_ids: Optional[List[str]] = None, - max_pages: Optional[int] = None, - 
conversation_history: Optional[List[ConversationMessage]] = None, - task_update_callback: Optional[Any] = None - ) -> QueryResult: - """Synchronous version of query""" + + def query_sync(self, question: str, mode: QueryMode = QueryMode.AUTO, document_ids: Optional[List[str]] = None, max_pages: Optional[int] = None, conversation_history: Optional[List[ConversationMessage]] = None, task_update_callback: Optional[Callable] = None) -> QueryResult: + """Synchronous version of `query`.""" return sync_wrapper(self.query(question, mode, document_ids, max_pages, stream=False, conversation_history=conversation_history, task_update_callback=task_update_callback)) - + def search_documents_sync(self, query: str, limit: int = 10) -> List[Dict[str, Any]]: - """Synchronous version of search_documents""" + """Synchronous version of `search_documents`.""" return sync_wrapper(self.search_documents(query, limit)) - - def query_with_conversation_sync( - self, - question: str, - conversation_history: List[ConversationMessage], - mode: QueryMode = QueryMode.AUTO - ) -> QueryResult: - """Synchronous version of query_with_conversation""" + + def query_with_conversation_sync(self, question: str, conversation_history: List[ConversationMessage], mode: QueryMode = QueryMode.AUTO) -> QueryResult: + """Synchronous version of `query_with_conversation`.""" return sync_wrapper(self.query_with_conversation(question, conversation_history, mode)) - - # Context manager support - + + # Context Manager Support async def __aenter__(self): - """Async context manager entry""" + """Asynchronous context manager entry.""" return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit""" - # Cleanup if needed + """Asynchronous context manager exit.""" pass - + def __enter__(self): - """Sync context manager entry""" + """Synchronous context manager entry.""" return self - + def __exit__(self, exc_type, exc_val, exc_tb): - """Sync context manager exit""" - # Cleanup if 
needed + """Synchronous context manager exit.""" pass -# Convenience factory functions - def create_docpixie( provider: str = "openai", api_key: Optional[str] = None, storage_path: Optional[str] = None ) -> DocPixie: """ - Create a DocPixie instance with simple configuration - + A factory function to create a `DocPixie` instance with a simple configuration. + Args: - provider: AI provider ("openai" or "anthropic") - api_key: API key for the provider - storage_path: Local storage path (uses default if None) - + provider: The AI provider to use ('openai', 'anthropic', 'openrouter'). + api_key: The API key for the provider. + storage_path: The path for the local storage backend. + Returns: - Configured DocPixie instance + A configured `DocPixie` instance. """ config = DocPixieConfig( provider=provider, local_storage_path=storage_path or "./docpixie_data" ) - return DocPixie(config=config, api_key=api_key) @@ -370,18 +382,20 @@ def create_memory_docpixie( api_key: Optional[str] = None ) -> DocPixie: """ - Create DocPixie instance with in-memory storage for testing - + A factory function to create a `DocPixie` instance with in-memory storage. + + This is useful for testing and scenarios where data persistence is not + required. + Args: - provider: AI provider - api_key: API key for the provider - + provider: The AI provider to use. + api_key: The API key for the provider. + Returns: - DocPixie instance with memory storage + A `DocPixie` instance configured with in-memory storage. """ config = DocPixieConfig( provider=provider, storage_type="memory" ) - return DocPixie(config=config, api_key=api_key) \ No newline at end of file diff --git a/docpixie/models/agent.py b/docpixie/models/agent.py index f1fa4dd..5f30bb7 100644 --- a/docpixie/models/agent.py +++ b/docpixie/models/agent.py @@ -1,5 +1,8 @@ """ -Agent models and data structures for DocPixie RAG Agent +Agent models and data structures for DocPixie RAG Agent. 
+ +This module defines the data structures used by the RAG agent to manage tasks, +plans, and results. """ import uuid @@ -12,7 +15,7 @@ class TaskStatus(str, Enum): - """Agent task status""" + """Enumeration for the status of an agent task.""" PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" @@ -21,31 +24,47 @@ class TaskStatus(str, Enum): @dataclass class ConversationMessage: - """Represents a single conversation message""" - role: str # "user" or "assistant" + """Represents a single message in a conversation. + + Attributes: + role: The role of the message sender: "system", "user", or "assistant". + content: The text content of the message. + timestamp: The time the message was created. + cost: The cost associated with generating this message, particularly for + assistant messages which may involve multiple API calls. + """ + role: str content: str timestamp: datetime = field(default_factory=datetime.now) - cost: float = 0.0 # Cost for this message (agent pipeline total for assistant messages) + cost: float = 0.0 def __post_init__(self): - """Validate message data""" + """Validates the message data after initialization.""" if self.role not in ["system", "user", "assistant"]: - raise ValueError("Role must be 'user' or 'assistant'") + raise ValueError("Role must be 'system', 'user', or 'assistant'") if not self.content.strip(): raise ValueError("Content cannot be empty") @dataclass class AgentTask: - """Represents a single task in the agent's plan""" + """Represents a single, discrete task for the agent to perform. + + Attributes: + id: A unique identifier for the task. + name: A short, descriptive name for the task. + description: A more detailed description of what the task entails. + status: The current status of the task. + document: The ID of the document this task is associated with.
+ """ id: str = field(default_factory=lambda: str(uuid.uuid4())) name: str = "" description: str = "" status: TaskStatus = TaskStatus.PENDING - document: str = "" # Single document ID assigned to this task + document: str = "" def __post_init__(self): - """Validate task data""" + """Validates the task data after initialization.""" if not self.name.strip(): raise ValueError("Task name cannot be empty") if not self.description.strip(): @@ -54,21 +73,46 @@ def __post_init__(self): @dataclass class TaskPlan: - """Represents the agent's current task plan""" + """Represents the agent's plan of action to address a user query. + + Attributes: + initial_query: The user's query that this plan is designed to address. + tasks: A list of `AgentTask` objects that make up the plan. + current_iteration: The current iteration number in the agent's execution + loop. + """ initial_query: str tasks: List[AgentTask] = field(default_factory=list) current_iteration: int = 0 def get_next_pending_task(self) -> Optional[AgentTask]: - """Get the next task that needs to be executed""" + """ + Retrieves the next task in the plan that has a 'pending' status. + + Returns: + The next pending `AgentTask`, or `None` if no tasks are pending. + """ return next((task for task in self.tasks if task.status == TaskStatus.PENDING), None) def has_pending_tasks(self) -> bool: - """Check if there are any pending tasks""" + """ + Checks if there are any tasks in the plan with a 'pending' status. + + Returns: + `True` if there are pending tasks, `False` otherwise. + """ return any(task.status == TaskStatus.PENDING for task in self.tasks) def mark_task_completed(self, task_id: str) -> bool: - """Mark a task as completed""" + """ + Marks a specific task as completed. + + Args: + task_id: The ID of the task to mark as completed. + + Returns: + `True` if the task was found and marked, `False` otherwise. 
+ """ task = next((t for t in self.tasks if t.id == task_id), None) if task: task.status = TaskStatus.COMPLETED @@ -76,46 +120,96 @@ def mark_task_completed(self, task_id: str) -> bool: return False def add_task(self, task: AgentTask): - """Add a new task to the plan""" + """ + Adds a new task to the plan. + + Args: + task: The `AgentTask` to add to the plan. + """ self.tasks.append(task) def remove_task(self, task_id: str) -> bool: - """Remove a task from the plan""" + """ + Removes a task from the plan by its ID. + + Args: + task_id: The ID of the task to remove. + + Returns: + `True` if a task was removed, `False` otherwise. + """ original_length = len(self.tasks) self.tasks = [t for t in self.tasks if t.id != task_id] return len(self.tasks) < original_length def get_completed_tasks(self) -> List[AgentTask]: - """Get all completed tasks""" + """ + Retrieves a list of all completed tasks in the plan. + + Returns: + A list of `AgentTask` objects with a 'completed' status. + """ return [task for task in self.tasks if task.status == TaskStatus.COMPLETED] @dataclass class TaskResult: - """Represents the result of executing a single task""" + """Represents the result of executing a single agent task. + + Attributes: + task: The `AgentTask` that was executed. + selected_pages: A list of `Page` objects that were selected as relevant + by the task. + analysis: A textual summary or analysis produced by the task. + pages_analyzed: The number of pages that were analyzed to produce this + result. + """ task: AgentTask selected_pages: List[Page] analysis: str pages_analyzed: int = 0 def __post_init__(self): - """Calculate pages analyzed""" + """Calculates the number of pages analyzed after initialization.""" self.pages_analyzed = len(self.selected_pages) @dataclass class AgentQueryResult: - """Represents the final result of processing a user query through the agent pipeline""" + """ + Represents the final result of processing a user query through the agent. 
+ + Attributes: + query: The initial user query. + answer: The final, synthesized answer to the query. + selected_pages: A list of `Page` objects that were used to generate the + answer. + task_results: A list of `TaskResult` objects from the execution of the + plan. + total_iterations: The total number of iterations the agent performed. + processing_time_seconds: The total time taken to process the query, in + seconds. + total_cost: The total monetary cost of all API calls made to process + the query. + """ query: str answer: str selected_pages: List[Page] task_results: List[TaskResult] = field(default_factory=list) total_iterations: int = 0 processing_time_seconds: float = 0.0 - total_cost: float = 0.0 # Total cost of all API calls for this query + total_cost: float = 0.0 def get_unique_pages(self) -> List[Page]: - """Get unique pages from all task results""" + """ + Retrieves a list of unique pages from all task results. + + This method filters out duplicate pages that may have been selected by + multiple tasks. + + Returns: + A list of unique `Page` objects. + """ seen_paths = set() unique_pages = [] @@ -127,5 +221,10 @@ def get_unique_pages(self) -> List[Page]: return unique_pages def get_total_pages_analyzed(self) -> int: - """Get total number of pages analyzed across all tasks""" + """ + Calculates the total number of pages analyzed across all tasks. + + Returns: + The total count of pages analyzed. + """ return sum(result.pages_analyzed for result in self.task_results) diff --git a/docpixie/models/document.py b/docpixie/models/document.py index 5dc0f12..4da4ceb 100644 --- a/docpixie/models/document.py +++ b/docpixie/models/document.py @@ -1,6 +1,9 @@ """ -Document models and data structures for DocPixie -Simplified version of schemas from production DocPixie +Document models and data structures for DocPixie. + +This module defines the core data structures for representing documents, pages, +and query results within the DocPixie system. 
These models are simplified +versions of the schemas used in the production DocPixie application. """ from dataclasses import dataclass, field @@ -12,12 +15,12 @@ class QueryMode(str, Enum): - """Query processing modes""" + """Enumeration for query processing modes.""" AUTO = "auto" # Standard adaptive processing class DocumentStatus(str, Enum): - """Document processing status""" + """Enumeration for the status of document processing.""" PENDING = "pending" PROCESSING = "processing" COMPLETED = "completed" @@ -26,7 +29,16 @@ class DocumentStatus(str, Enum): @dataclass class Page: - """Represents a single document page""" + """ + Represents a single page of a document. + + Attributes: + page_number: The one-based index of the page within the document. + image_path: The file path to the image representation of the page. + metadata: A dictionary for storing any additional page-specific data. + document_name: The name of the document this page belongs to. + document_id: The unique identifier of the document this page belongs to. + """ page_number: int image_path: str metadata: Dict[str, Any] = field(default_factory=dict) @@ -34,7 +46,7 @@ class Page: document_id: Optional[str] = None def __post_init__(self): - """Validate page data""" + """Validates the page data after initialization.""" if self.page_number <= 0: raise ValueError("Page number must be positive") if not self.image_path: @@ -43,7 +55,18 @@ def __post_init__(self): @dataclass class Document: - """Represents a processed document with pages""" + """ + Represents a processed document, including its pages and metadata. + + Attributes: + id: A unique identifier for the document. + name: The name of the document. + pages: A list of `Page` objects that make up the document. + summary: A text summary of the document's content. + status: The current processing status of the document. + metadata: A dictionary for storing any additional document-level data. + created_at: The timestamp when the document was created. 
+ """ id: str name: str pages: List[Page] @@ -53,7 +76,9 @@ class Document: created_at: datetime = field(default_factory=datetime.now) def __post_init__(self): - """Generate ID if not provided and validate data""" + """ + Generates a unique ID if not provided and validates the document data. + """ if not self.id: self.id = str(uuid.uuid4()) if not self.name: @@ -63,25 +88,56 @@ def __post_init__(self): @property def page_count(self) -> int: - """Get total number of pages""" + """Returns the total number of pages in the document.""" return len(self.pages) def get_page(self, page_number: int) -> Optional[Page]: - """Get specific page by number""" + """ + Retrieves a specific page from the document by its page number. + + Args: + page_number: The number of the page to retrieve. + + Returns: + The `Page` object if found, otherwise `None`. + """ for page in self.pages: if page.page_number == page_number: return page return None def get_pages_range(self, start: int, end: int) -> List[Page]: - """Get pages in a range""" + """ + Retrieves a list of pages within a specified range (inclusive). + + Args: + start: The starting page number. + end: The ending page number. + + Returns: + A list of `Page` objects within the specified range. + """ return [p for p in self.pages if start <= p.page_number <= end] @dataclass class QueryResult: - """Result of a RAG query""" + """ + Represents the result of a RAG (Retrieval-Augmented Generation) query. + + Attributes: + query: The original query string. + answer: The generated answer to the query. + selected_pages: A list of `Page` objects used to generate the answer. + mode: The `QueryMode` used for processing the query. + confidence: A score between 0 and 1 indicating the confidence in the + answer. + processing_time: The time taken to process the query, in seconds. + metadata: A dictionary for storing any additional metadata about the + query result. + total_cost: The total monetary cost of all API calls made for this query. 
+ """ query: str answer: str selected_pages: List[Page] @@ -89,10 +145,10 @@ class QueryResult: confidence: float = 0.0 processing_time: float = 0.0 metadata: Dict[str, Any] = field(default_factory=dict) - total_cost: float = 0.0 # Total cost of all API calls for this query + total_cost: float = 0.0 def __post_init__(self): - """Validate result data""" + """Validates the query result data after initialization.""" if not self.query: raise ValueError("Query is required") if not self.answer: @@ -102,16 +158,22 @@ def __post_init__(self): @property def page_count(self) -> int: - """Number of pages used for the answer""" + """Returns the number of pages used to generate the answer.""" return len(self.selected_pages) @property def page_numbers(self) -> List[int]: - """Page numbers used for the answer""" + """Returns a list of the page numbers used to generate the answer.""" return [p.page_number for p in self.selected_pages] def get_pages_by_document(self) -> Dict[str, List[int]]: - """Get pages grouped by document name""" + """ + Groups the selected pages by their document name. + + Returns: + A dictionary where keys are document names and values are sorted + lists of page numbers. + """ pages_by_doc = {} for page in self.selected_pages: doc_name = page.document_name or "Unknown Document" @@ -128,13 +190,20 @@ def get_pages_by_document(self) -> Dict[str, List[int]]: @dataclass class DocumentProcessRequest: - """Request to process a document""" + """ + Represents a request to process a new document. + + Attributes: + file_path: The path to the document file. + document_id: An optional custom ID for the document. + document_name: An optional custom name for the document. 
+ """ file_path: str document_id: Optional[str] = None document_name: Optional[str] = None def __post_init__(self): - """Validate and set defaults""" + """Validates the request and sets default values.""" if not self.file_path or not Path(self.file_path).exists(): raise FileNotFoundError(f"File not found: {self.file_path}") @@ -147,7 +216,17 @@ def __post_init__(self): @dataclass class QueryRequest: - """Request to query documents""" + """ + Represents a request to query documents. + + Attributes: + query: The user's question. + mode: The `QueryMode` to use for the query. + document_ids: A list of specific document IDs to search. If `None`, all + documents are searched. + max_pages: The maximum number of pages to analyze. + stream: Whether to stream the response. + """ query: str mode: QueryMode = QueryMode.AUTO document_ids: Optional[List[str]] = None @@ -155,7 +234,7 @@ class QueryRequest: stream: bool = False def __post_init__(self): - """Validate query request""" + """Validates the query request and sets default values.""" if not self.query.strip(): raise ValueError("Query cannot be empty") diff --git a/docpixie/processors/base.py b/docpixie/processors/base.py index c476a2c..acd4c29 100644 --- a/docpixie/processors/base.py +++ b/docpixie/processors/base.py @@ -1,5 +1,9 @@ """ -Base processor interface for document processing +Base processor interface for document processing. + +This module defines the abstract base class for document processors in +DocPixie. It establishes a common interface for handling different file types +and converting them into a standardized `Document` object. """ from abc import ABC, abstractmethod @@ -14,43 +18,80 @@ class BaseProcessor(ABC): - """Base class for document processors""" - + """ + An abstract base class for document processors. 
+ + This class defines the core interface for processing various file types + (e.g., PDF, images) and converting them into a structured `Document` format + that can be used by the rest of the DocPixie system. + + Attributes: + config: The DocPixie configuration object. + """ + def __init__(self, config: DocPixieConfig): + """ + Initializes the BaseProcessor. + + Args: + config: The DocPixie configuration object. + """ self.config = config - + @abstractmethod def supports(self, file_path: str) -> bool: - """Check if this processor supports the given file type""" + """ + Checks if this processor can handle the given file type. + + Args: + file_path: The path to the file. + + Returns: + `True` if the processor supports the file type, `False` otherwise. + """ pass - + @abstractmethod async def process(self, file_path: str, document_id: Optional[str] = None) -> Document: """ - Process a document file into pages - + Processes a document file and converts it into a `Document` object. + Args: - file_path: Path to the document file - document_id: Optional custom document ID - + file_path: The path to the document file. + document_id: An optional custom ID for the document. + Returns: - Document with processed pages + A `Document` object containing the processed pages. """ pass - + def get_supported_extensions(self) -> List[str]: - """Get list of supported file extensions""" + """ + Returns a list of file extensions supported by this processor. + + Returns: + A list of strings, each representing a supported file extension. + """ return [] - + def _create_document( - self, - file_path: str, - pages: List[Page], + self, + file_path: str, + pages: List[Page], document_id: Optional[str] = None ) -> Document: - """Create a Document object from processed pages""" + """ + Creates a `Document` object from a list of processed pages. + + Args: + file_path: The path to the original document file. + pages: A list of `Page` objects. + document_id: An optional custom document ID. 
+ + Returns: + A new `Document` object. + """ document_name = Path(file_path).stem - return Document( id=document_id or self._generate_document_id(file_path), name=document_name, @@ -61,14 +102,31 @@ def _create_document( 'file_size': Path(file_path).stat().st_size if Path(file_path).exists() else 0 } ) - + def _generate_document_id(self, file_path: str) -> str: - """Generate a document ID from file path""" + """ + Generates a unique document ID based on the file path. + + Args: + file_path: The path to the document file. + + Returns: + A unique MD5 hash of the file path. + """ import hashlib return hashlib.md5(file_path.encode()).hexdigest() - + def _validate_file(self, file_path: str) -> None: - """Validate that file exists and is readable""" + """ + Validates that a file exists, is a file, and is not empty. + + Args: + file_path: The path to the file. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the path is not a file or if the file is empty. + """ path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {file_path}") @@ -79,9 +137,23 @@ def _validate_file(self, file_path: str) -> None: class ProcessingError(Exception): - """Exception raised during document processing""" - + """ + A custom exception raised for errors that occur during document processing. + + Attributes: + file_path: The path to the file being processed. + page_number: The page number where the error occurred, if applicable. + """ + def __init__(self, message: str, file_path: str, page_number: Optional[int] = None): + """ + Initializes the ProcessingError. + + Args: + message: The error message. + file_path: The path to the file being processed. + page_number: The optional page number where the error occurred. 
+ """ self.file_path = file_path self.page_number = page_number super().__init__(message) \ No newline at end of file diff --git a/docpixie/processors/factory.py b/docpixie/processors/factory.py index 4312e48..4a8f9f8 100644 --- a/docpixie/processors/factory.py +++ b/docpixie/processors/factory.py @@ -1,5 +1,10 @@ """ -Processor factory for selecting appropriate document processor +A factory for creating and managing document processors. + +This module provides a `ProcessorFactory` class that is responsible for +instantiating the correct document processor based on the file type. It +maintains a mapping of file extensions to processor classes and allows for +the registration of custom processors. """ from typing import Optional, Dict, Type @@ -15,103 +20,126 @@ class ProcessorFactory: - """Factory for creating document processors""" - + """ + A factory class for creating document processors. + + This class provides an interface to get the appropriate processor for a + given file type. It automatically maps file extensions to the registered + processor classes. + + Attributes: + config: The DocPixie configuration object. + """ + def __init__(self, config: DocPixieConfig): + """ + Initializes the ProcessorFactory. + + Args: + config: The DocPixie configuration object. 
+ """ self.config = config self._processors: Dict[str, Type[BaseProcessor]] = { 'pdf': PDFProcessor, 'image': ImageProcessor } - - # Map file extensions to processor types self._extension_map: Dict[str, str] = {} self._build_extension_map() - + def _build_extension_map(self): - """Build mapping from file extensions to processor types""" - # Create processor instances to get supported extensions + """Builds the mapping from file extensions to processor types.""" for processor_type, processor_class in self._processors.items(): processor = processor_class(self.config) for ext in processor.get_supported_extensions(): self._extension_map[ext.lower()] = processor_type - logger.debug(f"Built extension map: {self._extension_map}") - + def get_processor(self, file_path: str) -> BaseProcessor: """ - Get appropriate processor for file - + Gets the appropriate processor for a given file. + Args: - file_path: Path to file - + file_path: The path to the file. + Returns: - Processor instance - + An instance of the appropriate `BaseProcessor` subclass. + Raises: - ValueError: If file type is not supported + ValueError: If the file type is not supported or the file has no + extension. """ file_extension = Path(file_path).suffix.lower() - if not file_extension: raise ValueError(f"File has no extension: {file_path}") - + processor_type = self._extension_map.get(file_extension) - if not processor_type: supported_exts = list(self._extension_map.keys()) raise ValueError( f"Unsupported file type '{file_extension}'. " f"Supported extensions: {supported_exts}" ) - + processor_class = self._processors[processor_type] processor = processor_class(self.config) - logger.debug(f"Selected {processor_class.__name__} for {file_path}") return processor - + def supports_file(self, file_path: str) -> bool: - """Check if file type is supported""" + """ + Checks if a given file type is supported. + + Args: + file_path: The path to the file. 
+ + Returns: + `True` if the file type is supported, `False` otherwise. + """ file_extension = Path(file_path).suffix.lower() return file_extension in self._extension_map - + def get_supported_extensions(self) -> Dict[str, str]: - """Get all supported extensions and their processor types""" + """ + Gets a dictionary of all supported extensions and their processor types. + + Returns: + A dictionary mapping file extensions to processor type names. + """ return self._extension_map.copy() - + def register_processor(self, processor_type: str, processor_class: Type[BaseProcessor]): """ - Register a custom processor - + Registers a new custom processor. + Args: - processor_type: Unique identifier for processor - processor_class: Processor class + processor_type: A unique identifier for the processor. + processor_class: The processor class to register, which must be a + subclass of `BaseProcessor`. """ self._processors[processor_type] = processor_class - - # Update extension mapping processor = processor_class(self.config) for ext in processor.get_supported_extensions(): self._extension_map[ext.lower()] = processor_type - logger.info(f"Registered custom processor: {processor_type}") - + def list_processors(self) -> Dict[str, Type[BaseProcessor]]: - """Get all registered processors""" + """ + Gets a dictionary of all registered processors. + + Returns: + A dictionary mapping processor type names to their classes. + """ return self._processors.copy() - + def create_processor(self, processor_type: str) -> Optional[BaseProcessor]: """ - Create processor by type - + Creates a processor instance by its type name. + Args: - processor_type: Type of processor to create - + processor_type: The type of processor to create. + Returns: - Processor instance or None if type not found + An instance of the processor, or `None` if the type is not found. 
""" processor_class = self._processors.get(processor_type) - if processor_class: - return processor_class(self.config) - return None \ No newline at end of file + return processor_class(self.config) if processor_class else None \ No newline at end of file diff --git a/docpixie/processors/image.py b/docpixie/processors/image.py index e9badd9..99d97bc 100644 --- a/docpixie/processors/image.py +++ b/docpixie/processors/image.py @@ -1,6 +1,9 @@ """ -Image processor for direct image files -Handles JPG, PNG, WebP, and other image formats +Image processor for direct image files. + +This module provides a processor for handling various image formats, such as +JPG, PNG, and WebP. It processes a single image file into a one-page +`Document` object. """ import asyncio @@ -20,95 +23,101 @@ class ImageProcessor(BaseProcessor): - """Processor for image files""" - + """ + A processor for handling various image file formats. + + This class implements the `BaseProcessor` interface to process image files. + It converts a single image into a one-page `Document`, optimizing the image + for storage and analysis by converting it to JPEG, resizing if necessary, + and handling transparency. + + Attributes: + SUPPORTED_EXTENSIONS: A list of supported image file extensions. + temp_dir: The temporary directory used for storing processed images. + """ + SUPPORTED_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff', '.tif'] - + def __init__(self, config: DocPixieConfig): + """ + Initializes the ImageProcessor. + + Args: + config: The DocPixie configuration object. 
+ """ super().__init__(config) self.temp_dir = None - + def supports(self, file_path: str) -> bool: - """Check if file is a supported image format""" + """Checks if the given file is a supported image format.""" return Path(file_path).suffix.lower() in self.SUPPORTED_EXTENSIONS - + def get_supported_extensions(self) -> List[str]: - """Get supported file extensions""" + """Returns the list of supported image file extensions.""" return self.SUPPORTED_EXTENSIONS.copy() - + async def process(self, file_path: str, document_id: Optional[str] = None) -> Document: """ - Process image file into a single-page document - + Processes an image file into a single-page `Document`. + Args: - file_path: Path to image file - document_id: Optional custom document ID - + file_path: The path to the image file. + document_id: An optional custom ID for the document. + Returns: - Document with single page + A `Document` object containing a single page representing the image. + + Raises: + ProcessingError: If the image processing fails. 
""" self._validate_file(file_path) logger.info(f"Processing image: {file_path}") - + try: - # Create temporary directory for processed image self.temp_dir = tempfile.mkdtemp(prefix="docpixie_img_") - - # Process image in thread pool - page = await asyncio.get_event_loop().run_in_executor( - None, - self._process_image_sync, - file_path - ) - - # Create document with single page + page = await asyncio.to_thread(self._process_image_sync, file_path) document = self._create_document(file_path, [page], document_id) document.status = DocumentStatus.COMPLETED - - # Update page with document info for page in document.pages: page.document_name = document.name page.document_id = document.id - logger.info(f"Successfully processed image: {file_path}") return document - + except Exception as e: logger.error(f"Failed to process image {file_path}: {e}") - # Clean up temp directory on error if self.temp_dir and os.path.exists(self.temp_dir): import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) raise ProcessingError(f"Image processing failed: {e}", file_path) - + def _process_image_sync(self, file_path: str) -> Page: - """Synchronous image processing""" + """ + Performs the synchronous processing of the image file. + + Args: + file_path: The path to the image file. + + Returns: + A `Page` object representing the processed image. + + Raises: + ProcessingError: If the image cannot be identified or processed. 
+ """ try: - # Open and process image with Image.open(file_path) as img: - # Get original dimensions original_width, original_height = img.size - - # Optimize image optimized_img = self._optimize_image(img) - - # Save optimized image - output_filename = "page_001.jpg" - output_path = os.path.join(self.temp_dir, output_filename) - + + output_path = os.path.join(self.temp_dir, "page_001.jpg") optimized_img.save( - output_path, - 'JPEG', - quality=self.config.jpeg_quality, - optimize=True + output_path, 'JPEG', quality=self.config.jpeg_quality, optimize=True ) - - # Get final image dimensions and file size + final_width, final_height = optimized_img.size file_size = os.path.getsize(output_path) - - # Create page object - page = Page( + + return Page( page_number=1, image_path=output_path, metadata={ @@ -120,68 +129,77 @@ def _process_image_sync(self, file_path: str) -> Page: 'original_format': img.format } ) - - return page - except Image.UnidentifiedImageError as e: raise ProcessingError(f"Unrecognized image format: {e}", file_path) except Exception as e: raise ProcessingError(f"Failed to process image: {e}", file_path) - + def _optimize_image(self, img: Image.Image) -> Image.Image: """ - Optimize image for storage and processing - Same logic as PDF processor + Optimizes an image for storage and processing. + + This involves converting the image to RGB, handling transparency, and + resizing if it exceeds the configured maximum dimensions. + + Args: + img: The PIL `Image` object to optimize. + + Returns: + The optimized PIL `Image` object. 
""" - # Convert to RGB if necessary - if img.mode in ('RGBA', 'LA', 'P'): - # Create white background for transparency - rgb_img = Image.new('RGB', img.size, (255, 255, 255)) - if img.mode == 'RGBA': - rgb_img.paste(img, mask=img.split()[-1]) # Use alpha channel as mask - elif img.mode == 'P' and 'transparency' in img.info: - # Handle palette mode with transparency - img = img.convert('RGBA') - rgb_img.paste(img, mask=img.split()[-1]) + if img.mode != 'RGB': + if img.mode in ('RGBA', 'LA', 'P'): + rgb_img = Image.new('RGB', img.size, (255, 255, 255)) + if img.mode == 'RGBA' or (img.mode == 'P' and 'transparency' in img.info): + img = img.convert('RGBA') + rgb_img.paste(img, mask=img.split()[-1]) + else: + rgb_img.paste(img) + img = rgb_img else: - rgb_img.paste(img) - img = rgb_img - elif img.mode != 'RGB': - img = img.convert('RGB') - - # Resize if image is too large + img = img.convert('RGB') + max_width, max_height = self.config.pdf_max_image_size if img.width > max_width or img.height > max_height: - # Calculate new size maintaining aspect ratio ratio = min(max_width / img.width, max_height / img.height) - new_width = int(img.width * ratio) - new_height = int(img.height * ratio) - - img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - logger.debug(f"Resized image to {new_width}x{new_height}") - + new_size = (int(img.width * ratio), int(img.height * ratio)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + logger.debug(f"Resized image to {new_size[0]}x{new_size[1]}") + return img - + def create_thumbnail(self, image_path: str) -> str: - """Create thumbnail for quick page selection""" + """ + Creates a thumbnail for an image. + + Args: + image_path: The path to the image file. + + Returns: + The path to the created thumbnail, or the original image path if + thumbnail creation fails. 
+ """ try: with Image.open(image_path) as img: - # Create thumbnail thumbnail = img.copy() thumbnail.thumbnail(self.config.thumbnail_size, Image.Resampling.LANCZOS) - - # Save thumbnail thumb_path = image_path.replace('.jpg', '_thumb.jpg') thumbnail.save(thumb_path, 'JPEG', quality=85, optimize=True) - return thumb_path - except Exception as e: logger.error(f"Failed to create thumbnail for {image_path}: {e}") - return image_path # Return original if thumbnail creation fails - + return image_path + def get_image_metadata(self, file_path: str) -> dict: - """Extract image metadata""" + """ + Extracts metadata from an image file. + + Args: + file_path: The path to the image file. + + Returns: + A dictionary containing image metadata. + """ try: with Image.open(file_path) as img: metadata = { @@ -191,14 +209,9 @@ def get_image_metadata(self, file_path: str) -> dict: 'height': img.height, 'has_transparency': img.mode in ('RGBA', 'LA') or 'transparency' in img.info } - - # Add EXIF data if available - if hasattr(img, '_getexif') and img._getexif() is not None: - exif = img._getexif() + if hasattr(img, '_getexif') and (exif := img._getexif()): metadata['exif'] = exif - return metadata - except Exception as e: logger.error(f"Failed to extract image metadata: {e}") return {} \ No newline at end of file diff --git a/docpixie/processors/pdf.py b/docpixie/processors/pdf.py index ede612e..9f6e6bb 100644 --- a/docpixie/processors/pdf.py +++ b/docpixie/processors/pdf.py @@ -1,6 +1,9 @@ """ -PyMuPDF-based PDF processor -Replacement for pdf2image with better performance and quality +PyMuPDF-based PDF processor. + +This module provides a high-performance PDF processor that uses the PyMuPDF +library to rasterize PDF pages into images. It serves as a more efficient +alternative to `pdf2image`. 
""" import asyncio @@ -9,6 +12,7 @@ from pathlib import Path import tempfile import os +import io from PIL import Image import fitz # PyMuPDF @@ -21,115 +25,103 @@ class PDFProcessor(BaseProcessor): - """PDF processor using PyMuPDF for better performance""" - + """ + A PDF processor that uses PyMuPDF for rendering pages. + + This class implements the `BaseProcessor` interface to handle PDF files. It + iterates through the pages of a PDF, renders each one as an image, and + creates a `Document` object containing these pages. + + Attributes: + SUPPORTED_EXTENSIONS: A list of supported file extensions (only '.pdf'). + temp_dir: The temporary directory for storing rendered page images. + """ + SUPPORTED_EXTENSIONS = ['.pdf'] - + def __init__(self, config: DocPixieConfig): + """ + Initializes the PDFProcessor. + + Args: + config: The DocPixie configuration object. + """ super().__init__(config) self.temp_dir = None - + def supports(self, file_path: str) -> bool: - """Check if file is a PDF""" + """Checks if the given file is a PDF.""" return Path(file_path).suffix.lower() in self.SUPPORTED_EXTENSIONS - + def get_supported_extensions(self) -> List[str]: - """Get supported file extensions""" + """Returns the list of supported file extensions.""" return self.SUPPORTED_EXTENSIONS.copy() - + async def process(self, file_path: str, document_id: Optional[str] = None) -> Document: """ - Process PDF into document pages using PyMuPDF - + Processes a PDF file into a `Document` with multiple pages. + Args: - file_path: Path to PDF file - document_id: Optional custom document ID - + file_path: The path to the PDF file. + document_id: An optional custom ID for the document. + Returns: - Document with processed pages + A `Document` object with its pages populated from the PDF. + + Raises: + ProcessingError: If the PDF processing fails. 
""" self._validate_file(file_path) logger.info(f"Processing PDF: {file_path}") - + try: - # Create temporary directory for page images self.temp_dir = tempfile.mkdtemp(prefix="docpixie_pdf_") - - # Process PDF in thread pool (PyMuPDF is not async) - pages = await asyncio.get_event_loop().run_in_executor( - None, - self._process_pdf_sync, - file_path - ) - - # Create document + pages = await asyncio.to_thread(self._process_pdf_sync, file_path) document = self._create_document(file_path, pages, document_id) document.status = DocumentStatus.COMPLETED - - # Update pages with document info for page in document.pages: page.document_name = document.name page.document_id = document.id - logger.info(f"Successfully processed PDF: {len(pages)} pages") return document - + except Exception as e: logger.error(f"Failed to process PDF {file_path}: {e}") - # Clean up temp directory on error if self.temp_dir and os.path.exists(self.temp_dir): import shutil shutil.rmtree(self.temp_dir, ignore_errors=True) raise ProcessingError(f"PDF processing failed: {e}", file_path) - + def _process_pdf_sync(self, file_path: str) -> List[Page]: - """Synchronous PDF processing with PyMuPDF""" + """ + Performs the synchronous processing of the PDF file. + + Args: + file_path: The path to the PDF file. + + Returns: + A list of `Page` objects, one for each page in the PDF. + + Raises: + ProcessingError: If the PDF is invalid or a page fails to process. 
+ """ pages = [] - try: - # Open PDF document pdf_doc = fitz.open(file_path) - total_pages = pdf_doc.page_count - - logger.info(f"Processing {total_pages} pages from PDF") - - for page_num in range(total_pages): + logger.info(f"Processing {pdf_doc.page_count} pages from PDF") + for page_num in range(pdf_doc.page_count): try: - # Get page page = pdf_doc[page_num] - - # Create transformation matrix for scaling - matrix = fitz.Matrix( - self.config.pdf_render_scale, - self.config.pdf_render_scale - ) - - # Render page to pixmap - pix = page.get_pixmap( - matrix=matrix, - alpha=False # No transparency for JPEG - ) - - # Convert to PIL Image - img_data = pix.tobytes("ppm") - img = Image.open(io.BytesIO(img_data)) - - # Optimize image + matrix = fitz.Matrix(self.config.pdf_render_scale, self.config.pdf_render_scale) + pix = page.get_pixmap(matrix=matrix, alpha=False) + img = Image.open(io.BytesIO(pix.tobytes("ppm"))) optimized_img = self._optimize_image(img) - - # Save page image - page_filename = f"page_{page_num + 1:03d}.jpg" - page_image_path = os.path.join(self.temp_dir, page_filename) - + + page_image_path = os.path.join(self.temp_dir, f"page_{page_num + 1:03d}.jpg") optimized_img.save( - page_image_path, - 'JPEG', - quality=self.config.jpeg_quality, - optimize=True + page_image_path, 'JPEG', quality=self.config.jpeg_quality, optimize=True ) - - # Create page object - page_obj = Page( + pages.append(Page( page_number=page_num + 1, image_path=page_image_path, metadata={ @@ -137,98 +129,75 @@ def _process_pdf_sync(self, file_path: str) -> List[Page]: 'height': pix.height, 'file_size': os.path.getsize(page_image_path) } - ) - - pages.append(page_obj) - + )) except Exception as e: logger.error(f"Failed to process page {page_num + 1}: {e}") - raise ProcessingError( - f"Failed to process page {page_num + 1}: {e}", - file_path, - page_num + 1 - ) - + raise ProcessingError(f"Failed to process page {page_num + 1}: {e}", file_path, page_num + 1) pdf_doc.close() return pages - 
- except fitz.FileDataError as e: - raise ProcessingError(f"Invalid PDF file: {e}", file_path) - except fitz.FileNotFoundError as e: - raise ProcessingError(f"PDF file not found: {e}", file_path) + except (fitz.FileDataError, fitz.FileNotFoundError) as e: + raise ProcessingError(f"Invalid or not found PDF file: {e}", file_path) except Exception as e: raise ProcessingError(f"Unexpected error processing PDF: {e}", file_path) - + def _optimize_image(self, img: Image.Image) -> Image.Image: """ - Optimize image for storage and processing - Adapted from existing resize_image_for_upload logic + Optimizes a rendered page image. + + Args: + img: The PIL `Image` object of the rendered page. + + Returns: + The optimized PIL `Image` object. """ - # Convert to RGB if necessary - if img.mode in ('RGBA', 'LA', 'P'): - # Create white background - rgb_img = Image.new('RGB', img.size, (255, 255, 255)) - if img.mode == 'RGBA': - rgb_img.paste(img, mask=img.split()[-1]) # Use alpha channel as mask - else: - rgb_img.paste(img) - img = rgb_img - elif img.mode != 'RGB': + if img.mode != 'RGB': img = img.convert('RGB') - - # Resize if image is too large + max_width, max_height = self.config.pdf_max_image_size if img.width > max_width or img.height > max_height: - # Calculate new size maintaining aspect ratio ratio = min(max_width / img.width, max_height / img.height) - new_width = int(img.width * ratio) - new_height = int(img.height * ratio) - - img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - logger.debug(f"Resized image to {new_width}x{new_height}") - + new_size = (int(img.width * ratio), int(img.height * ratio)) + img = img.resize(new_size, Image.Resampling.LANCZOS) + logger.debug(f"Resized image to {new_size[0]}x{new_size[1]}") return img - + def create_thumbnail(self, image_path: str) -> str: - """Create thumbnail for quick page selection""" + """ + Creates a thumbnail for a rendered page image. + + Args: + image_path: The path to the page image. 
+ + Returns: + The path to the thumbnail, or the original path on failure. + """ try: with Image.open(image_path) as img: - # Create thumbnail thumbnail = img.copy() thumbnail.thumbnail(self.config.thumbnail_size, Image.Resampling.LANCZOS) - - # Save thumbnail thumb_path = image_path.replace('.jpg', '_thumb.jpg') thumbnail.save(thumb_path, 'JPEG', quality=85, optimize=True) - return thumb_path - except Exception as e: logger.error(f"Failed to create thumbnail for {image_path}: {e}") - return image_path # Return original if thumbnail creation fails - + return image_path + def get_pdf_metadata(self, file_path: str) -> dict: - """Extract PDF metadata""" + """ + Extracts metadata from a PDF file. + + Args: + file_path: The path to the PDF file. + + Returns: + A dictionary containing the PDF's metadata. + """ try: - pdf_doc = fitz.open(file_path) - metadata = pdf_doc.metadata - page_count = pdf_doc.page_count - pdf_doc.close() - - return { - 'title': metadata.get('title', ''), - 'author': metadata.get('author', ''), - 'subject': metadata.get('subject', ''), - 'creator': metadata.get('creator', ''), - 'producer': metadata.get('producer', ''), - 'creation_date': metadata.get('creationDate', ''), - 'modification_date': metadata.get('modDate', ''), - 'page_count': page_count - } + with fitz.open(file_path) as pdf_doc: + return { + **pdf_doc.metadata, + 'page_count': pdf_doc.page_count + } except Exception as e: logger.error(f"Failed to extract PDF metadata: {e}") - return {} - - -# Import io for BytesIO -import io \ No newline at end of file + return {} \ No newline at end of file diff --git a/docpixie/providers/anthropic.py b/docpixie/providers/anthropic.py index 3dd1362..a389f85 100644 --- a/docpixie/providers/anthropic.py +++ b/docpixie/providers/anthropic.py @@ -1,5 +1,10 @@ """ -Anthropic Claude provider for raw API operations +Anthropic Claude provider for raw API operations. 
+ +This module provides an implementation of the `BaseProvider` for the Anthropic +Claude family of models. It handles the specific request and response formats +for the Anthropic API, including the conversion of image data to base64 for +multimodal inputs. """ import logging @@ -12,125 +17,183 @@ class AnthropicProvider(BaseProvider): - """Anthropic Claude provider for raw API operations""" - + """ + An AI provider for interacting with the Anthropic Claude API. + + This class implements the `BaseProvider` interface to send requests to the + Anthropic API. It supports both text-only and multimodal (text and image) + messages and handles the nuances of the Claude API, such as the special + handling of system messages. + + Attributes: + client: An instance of the `anthropic.AsyncAnthropic` client. + model: The specific Anthropic model to be used for API calls. + """ + def __init__(self, config: DocPixieConfig): + """ + Initializes the AnthropicProvider. + + Args: + config: The DocPixie configuration object. + + Raises: + ValueError: If the Anthropic API key is not provided in the config. + ImportError: If the `anthropic` library is not installed. + """ super().__init__(config) - + if not config.anthropic_api_key: raise ValueError("Anthropic API key is required") - - # Import here to make it optional dependency + try: import anthropic self.client = anthropic.AsyncAnthropic(api_key=config.anthropic_api_key) except ImportError: raise ImportError("Anthropic library not found. Install with: pip install anthropic") - - self.model = config.vision_model # Use vision model for multimodal operations - + + self.model = config.vision_model + async def process_text_messages( - self, - messages: List[Dict[str, Any]], - max_tokens: int = 300, + self, + messages: List[Dict[str, Any]], + max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process text-only messages through Anthropic API""" + """ + Processes text-only messages through the Anthropic API. 
+ + This method prepares the messages for the Claude API, including the + special handling of system messages, and sends them to the API. + + Args: + messages: A list of message dictionaries. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. + """ try: - # Convert system message format for Anthropic claude_messages = self._prepare_claude_text_messages(messages) - + response = await self.client.messages.create( model=self.model, max_tokens=max_tokens, temperature=temperature, messages=claude_messages ) - + result = response.content[0].text.strip() logger.debug(f"Anthropic text response: {result[:50]}...") - + return result - + except Exception as e: logger.error(f"Anthropic text processing failed: {e}") raise ProviderError(f"Text processing failed: {e}", "anthropic") - + async def process_multimodal_messages( - self, - messages: List[Dict[str, Any]], - max_tokens: int = 300, + self, + messages: List[Dict[str, Any]], + max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process multimodal messages (text + images) through Anthropic Vision API""" + """ + Processes multimodal messages (text and images) through the Anthropic API. + + This method converts image paths to base64 encoded data and formats the + messages for the Claude multimodal API. + + Args: + messages: A list of message dictionaries, potentially including + image paths. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. 
+ """ try: - # Process messages to convert image paths to base64 claude_messages = self._prepare_claude_multimodal_messages(messages) - + response = await self.client.messages.create( model=self.model, max_tokens=max_tokens, temperature=temperature, messages=claude_messages ) - + result = response.content[0].text.strip() logger.debug(f"Anthropic multimodal response: {result[:50]}...") - + return result - + except Exception as e: logger.error(f"Anthropic multimodal processing failed: {e}") raise ProviderError(f"Multimodal processing failed: {e}", "anthropic") - + def _prepare_claude_text_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Prepare text-only messages for Claude API (handle system messages)""" + """ + Prepares text-only messages for the Claude API. + + System messages are not sent to Claude as separate entries. Instead, + their content is prepended to the first user message. + + Args: + messages: The list of message dictionaries. + + Returns: + A new list of messages formatted for the Claude API. 
+ """ claude_messages = [] - + for message in messages: if message["role"] == "system": - # Claude handles system messages differently - we'll prepend to first user message continue else: claude_messages.append(message) - - # Prepend system message content to first user message if present - system_content = None - for message in messages: - if message["role"] == "system": - system_content = message["content"] - break - + + system_content = next((msg["content"] for msg in messages if msg["role"] == "system"), None) + if system_content and claude_messages and claude_messages[0]["role"] == "user": - # Prepend system content to first user message original_content = claude_messages[0]["content"] claude_messages[0]["content"] = f"{system_content}\n\n{original_content}" - + return claude_messages - + def _prepare_claude_multimodal_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Prepare multimodal messages for Claude API by converting image paths to base64""" + """ + Prepares multimodal messages for the Claude API. + + This method converts image paths to base64 encoded data and handles the + special formatting for system messages in multimodal requests. + + Args: + messages: The list of message dictionaries. + + Returns: + A new list of messages formatted for the Claude multimodal API. 
+ """ claude_messages = [] - system_content = None - - # Extract system message - for message in messages: - if message["role"] == "system": - system_content = message["content"] - break - - for message in messages: - if message["role"] == "system": - continue # Skip system message, will be prepended to user message - elif message["role"] == "user" and isinstance(message["content"], list): - # User message with multimodal content + system_content = next((msg["content"] for msg in messages if msg["role"] == "system"), None) + + user_messages = [msg for msg in messages if msg["role"] != "system"] + + for i, message in enumerate(user_messages): + if message["role"] == "user" and isinstance(message["content"], list): processed_content = [] - + for content_item in message["content"]: if content_item["type"] == "text": processed_content.append(content_item) elif content_item["type"] == "image_path": - # Convert image path to Claude format image_path = content_item["image_path"] if self._validate_image_path(image_path): encoded_image = self._encode_image(image_path) @@ -145,22 +208,16 @@ def _prepare_claude_multimodal_messages(self, messages: List[Dict[str, Any]]) -> else: logger.warning(f"Skipping invalid image path: {image_path}") else: - # Pass through other content types processed_content.append(content_item) - - # Prepend system content to first user message - if system_content and len(claude_messages) == 0: - processed_content.insert(0, { - "type": "text", - "text": system_content - }) - + + if system_content and i == 0: + processed_content.insert(0, {"type": "text", "text": system_content}) + claude_messages.append({ "role": message["role"], "content": processed_content }) else: - # Regular text message claude_messages.append(message) - + return claude_messages \ No newline at end of file diff --git a/docpixie/providers/base.py b/docpixie/providers/base.py index 6e27ead..42318c2 100644 --- a/docpixie/providers/base.py +++ b/docpixie/providers/base.py @@ -1,5 +1,10 
@@ """ -Base provider interface for vision AI operations +Base provider interface for vision AI operations. + +This module defines the abstract base class for AI providers, establishing a +common interface for processing text and multimodal messages. It also includes +helper methods for handling images and a custom exception class for provider- +related errors. """ import base64 @@ -16,15 +21,40 @@ @dataclass class APIResult: - """Container for API response with optional cost tracking""" + """ + A container for API responses that includes the generated text and optional + cost tracking. + + Attributes: + text: The text content of the API response. + cost: The optional cost associated with the API call. + """ text: str cost: Optional[float] = None class BaseProvider(ABC): - """Base class for AI vision providers""" + """ + An abstract base class for AI vision providers. + + This class defines the common interface that all AI providers in DocPixie + must implement. It provides a foundation for processing both text-only and + multimodal (text and image) messages, as well as handling cost tracking + and image encoding. + + Attributes: + config: The DocPixie configuration object. + last_api_cost: The cost of the most recent API call. + total_cost: The accumulated cost of all API calls. + """ def __init__(self, config: DocPixieConfig): + """ + Initializes the BaseProvider. + + Args: + config: The DocPixie configuration object. + """ self.config = config self.last_api_cost: Optional[float] = None self.total_cost: float = 0.0 @@ -36,7 +66,19 @@ async def process_text_messages( max_tokens: int = 512, temperature: float = 0.3 ) -> str: - """Process text-only messages through the provider API""" + """ + Processes a list of text-only messages through the provider's API. + + Args: + messages: A list of message dictionaries, typically with 'role' and + 'content' keys. + max_tokens: The maximum number of tokens to generate in the response. 
+ temperature: The sampling temperature for the generation, controlling + the randomness of the output. + + Returns: + The text content of the API response. + """ pass @abstractmethod @@ -46,26 +88,56 @@ async def process_multimodal_messages( max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process messages with text and images through the provider API""" + """ + Processes a list of messages containing both text and images. + + Args: + messages: A list of message dictionaries, which may include image + data. + max_tokens: The maximum number of tokens to generate in the response. + temperature: The sampling temperature for the generation. + + Returns: + The text content of the API response. + """ pass def get_last_cost(self) -> Optional[float]: - """Get the cost of the last API call (if available)""" + """ + Retrieves the cost of the last API call, if available. + + Returns: + The cost of the last API call, or `None` if not available. + """ return self.last_api_cost def get_total_cost(self) -> float: - """Get the total accumulated cost""" + """ + Retrieves the total accumulated cost of all API calls. + + Returns: + The total cost. + """ return self.total_cost def reset_cost_tracking(self): - """Reset cost tracking""" + """Resets the cost tracking attributes to zero.""" self.last_api_cost = None self.total_cost = 0.0 - # Helper methods for image handling (shared by all providers) - def _encode_image(self, image_path: str) -> str: - """Encode image to base64 for API calls""" + """ + Encodes an image file to a base64 string for API calls. + + Args: + image_path: The path to the image file. + + Returns: + The base64-encoded image string. + + Raises: + Exception: If the image file cannot be read or encoded. 
+ """ try: with open(image_path, 'rb') as image_file: encoded_string = base64.b64encode(image_file.read()).decode('utf-8') @@ -75,20 +147,51 @@ def _encode_image(self, image_path: str) -> str: raise def _create_image_data_url(self, image_path: str) -> str: - """Create data URL for image""" + """ + Creates a data URL for an image file. + + Args: + image_path: The path to the image file. + + Returns: + A data URL string for the image. + """ encoded_image = self._encode_image(image_path) return f"data:image/jpeg;base64,{encoded_image}" def _validate_image_path(self, image_path: str) -> bool: - """Validate image path exists and is readable""" + """ + Validates that an image path exists and points to a file. + + Args: + image_path: The path to the image file. + + Returns: + `True` if the path is valid, `False` otherwise. + """ path = Path(image_path) return path.exists() and path.is_file() class ProviderError(Exception): - """Exception raised by provider operations""" + """ + A custom exception raised for errors related to provider operations. + + Attributes: + provider: The name of the provider that raised the error. + image_path: The path to the image being processed when the error + occurred, if applicable. + """ def __init__(self, message: str, provider: str, image_path: str = None): + """ + Initializes the ProviderError. + + Args: + message: The error message. + provider: The name of the provider. + image_path: The optional path to the image being processed. + """ self.provider = provider self.image_path = image_path super().__init__(message) diff --git a/docpixie/providers/factory.py b/docpixie/providers/factory.py index 7120656..e577554 100644 --- a/docpixie/providers/factory.py +++ b/docpixie/providers/factory.py @@ -1,5 +1,9 @@ """ -Provider factory for creating AI vision providers +Provider factory for creating AI vision providers. + +This module provides a factory function to instantiate the appropriate AI +provider based on the application's configuration. 
It helps to decouple the +provider creation logic from the rest of the application. """ from typing import Union @@ -13,68 +17,71 @@ def create_provider(config: DocPixieConfig) -> BaseProvider: """ - Create AI provider based on configuration - + Creates and returns an AI provider instance based on the configuration. + + This factory function selects the appropriate provider class (e.g., + OpenAIProvider, AnthropicProvider) based on the `provider` attribute of the + `DocPixieConfig` object and initializes it with the given configuration. + Args: - config: DocPixie configuration - + config: The DocPixie configuration object. + Returns: - Configured provider instance - + An instance of a class that inherits from `BaseProvider`. + Raises: - ValueError: If provider is not supported + ValueError: If the specified provider is not supported. """ - if config.provider == "openai": - return OpenAIProvider(config) - elif config.provider == "anthropic": - return AnthropicProvider(config) - elif config.provider == "openrouter": - return OpenRouterProvider(config) + provider_map = { + "openai": OpenAIProvider, + "anthropic": AnthropicProvider, + "openrouter": OpenRouterProvider, + } + provider_class = provider_map.get(config.provider) + if provider_class: + return provider_class(config) else: raise ValueError(f"Unsupported provider: {config.provider}") def get_available_providers() -> list[str]: - """Get list of available provider names""" + """ + Returns a list of available provider names. + + Returns: + A list of strings, where each string is a supported provider name. + """ return ["openai", "anthropic", "openrouter"] def validate_provider_config(provider: str, config: DocPixieConfig) -> bool: """ - Validate provider configuration - + Validates the configuration for a specific provider. + + This function checks if the necessary configuration options (e.g., API keys, + models) are present in the `DocPixieConfig` object for the given provider. 
+ Args: - provider: Provider name - config: Configuration to validate - + provider: The name of the provider to validate. + config: The DocPixie configuration object. + Returns: - True if configuration is valid - + `True` if the configuration is valid for the provider. + Raises: - ValueError: If configuration is invalid + ValueError: If the provider is unknown or the configuration is invalid. """ if provider not in get_available_providers(): raise ValueError(f"Unknown provider: {provider}") - - if provider == "openai": - if not config.openai_api_key: - raise ValueError("OpenAI API key is required") - if not config.vision_model: - raise ValueError("Vision model is required") - return True - - elif provider == "anthropic": - if not config.anthropic_api_key: - raise ValueError("Anthropic API key is required") - if not config.vision_model: - raise ValueError("Vision model is required") - return True - - elif provider == "openrouter": - if not config.openrouter_api_key: - raise ValueError("OpenRouter API key is required") - if not config.vision_model: - raise ValueError("Vision model is required") - return True - - return False \ No newline at end of file + + required_keys = { + "openai": ["openai_api_key", "vision_model"], + "anthropic": ["anthropic_api_key", "vision_model"], + "openrouter": ["openrouter_api_key", "vision_model"], + } + + for key in required_keys.get(provider, []): + if not getattr(config, key, None): + raise ValueError(f"{key.replace('_', ' ').title()} is required for {provider}") + + return True \ No newline at end of file diff --git a/docpixie/providers/openai.py b/docpixie/providers/openai.py index 5a1031f..345b8e1 100644 --- a/docpixie/providers/openai.py +++ b/docpixie/providers/openai.py @@ -1,5 +1,10 @@ """ -OpenAI GPT-4V provider for raw API operations +OpenAI GPT-4V provider for raw API operations. + +This module provides an implementation of the `BaseProvider` for the OpenAI +GPT-4V model. 
It handles the specific request and response formats for the +OpenAI API, including the conversion of image data to data URLs for multimodal +inputs. """ import logging @@ -12,30 +17,62 @@ class OpenAIProvider(BaseProvider): - """OpenAI GPT-4V provider for raw API operations""" - + """ + An AI provider for interacting with the OpenAI API. + + This class implements the `BaseProvider` interface to send requests to the + OpenAI API, particularly for use with vision-capable models like GPT-4V. + It supports both text-only and multimodal (text and image) messages. + + Attributes: + client: An instance of the `openai.AsyncOpenAI` client. + model: The specific OpenAI model to be used for API calls. + """ + def __init__(self, config: DocPixieConfig): + """ + Initializes the OpenAIProvider. + + Args: + config: The DocPixie configuration object. + + Raises: + ValueError: If the OpenAI API key is not provided in the config. + ImportError: If the `openai` library is not installed. + """ super().__init__(config) - + if not config.openai_api_key: raise ValueError("OpenAI API key is required") - - # Import here to make it optional dependency + try: from openai import AsyncOpenAI self.client = AsyncOpenAI(api_key=config.openai_api_key) except ImportError: raise ImportError("OpenAI library not found. Install with: pip install openai") - + self.model = config.vision_model - + async def process_text_messages( - self, - messages: List[Dict[str, Any]], - max_tokens: int = 300, + self, + messages: List[Dict[str, Any]], + max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process text-only messages through OpenAI API""" + """ + Processes text-only messages through the OpenAI API. + + Args: + messages: A list of message dictionaries. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. 
+ """ try: response = await self.client.chat.completions.create( model=self.config.model, @@ -43,60 +80,82 @@ async def process_text_messages( max_tokens=max_tokens, temperature=temperature ) - + result = response.choices[0].message.content.strip() logger.debug(f"OpenAI text response: {result[:50]}...") - + return result - + except Exception as e: logger.error(f"OpenAI text processing failed: {e}") raise ProviderError(f"Text processing failed: {e}", "openai") - + async def process_multimodal_messages( - self, - messages: List[Dict[str, Any]], - max_tokens: int = 300, + self, + messages: List[Dict[str, Any]], + max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process multimodal messages (text + images) through OpenAI Vision API""" + """ + Processes multimodal messages (text and images) through the OpenAI API. + + This method converts image paths to data URLs and formats the messages + for the OpenAI multimodal API. + + Args: + messages: A list of message dictionaries, potentially including + image paths. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. 
+ """ try: - # Process messages to convert image paths to data URLs processed_messages = self._prepare_openai_messages(messages) - + response = await self.client.chat.completions.create( - model=self.model, # Use vision model + model=self.model, messages=processed_messages, max_tokens=max_tokens, temperature=temperature ) - + result = response.choices[0].message.content.strip() logger.debug(f"OpenAI multimodal response: {result[:50]}...") - + return result - + except Exception as e: logger.error(f"OpenAI multimodal processing failed: {e}") raise ProviderError(f"Multimodal processing failed: {e}", "openai") - + def _prepare_openai_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Prepare messages for OpenAI API by converting image paths to data URLs""" + """ + Prepares messages for the OpenAI API. + + This method converts any `image_path` content items into the data URL + format required by the OpenAI API for multimodal inputs. + + Args: + messages: The list of message dictionaries. + + Returns: + A new list of messages formatted for the OpenAI API. 
+ """ processed_messages = [] - + for message in messages: - if message["role"] == "system": - # System messages are text-only - processed_messages.append(message) - elif message["role"] == "user" and isinstance(message["content"], list): - # User message with multimodal content + if message["role"] == "user" and isinstance(message["content"], list): processed_content = [] - + for content_item in message["content"]: if content_item["type"] == "text": processed_content.append(content_item) elif content_item["type"] == "image_path": - # Convert image path to OpenAI format image_path = content_item["image_path"] if self._validate_image_path(image_path): image_data_url = self._create_image_data_url(image_path) @@ -110,15 +169,13 @@ def _prepare_openai_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[ else: logger.warning(f"Skipping invalid image path: {image_path}") else: - # Pass through other content types processed_content.append(content_item) - + processed_messages.append({ "role": message["role"], "content": processed_content }) else: - # Regular text message processed_messages.append(message) - + return processed_messages \ No newline at end of file diff --git a/docpixie/providers/openrouter.py b/docpixie/providers/openrouter.py index 7296556..3a4b329 100644 --- a/docpixie/providers/openrouter.py +++ b/docpixie/providers/openrouter.py @@ -1,6 +1,10 @@ """ -OpenRouter provider for raw API operations -Uses OpenAI client with OpenRouter's API endpoint +OpenRouter provider for raw API operations. + +This module provides an implementation of the `BaseProvider` for the OpenRouter +service. It utilizes the OpenAI client library to interact with the OpenRouter +API endpoint, which allows for access to a wide variety of language and vision +models from different providers. 
""" import logging @@ -13,15 +17,37 @@ class OpenRouterProvider(BaseProvider): - """OpenRouter provider for raw API operations""" + """ + An AI provider for interacting with the OpenRouter API. + + This class leverages the `openai` library to communicate with the OpenRouter + API, which acts as a gateway to numerous AI models. It supports both text + and multimodal messages and includes cost tracking for OpenRouter's usage + reporting. + + Attributes: + client: An instance of the `openai.AsyncOpenAI` client configured for + the OpenRouter API. + model: The specific model to be used for API calls, as specified in + the OpenRouter model registry. + """ def __init__(self, config: DocPixieConfig): + """ + Initializes the OpenRouterProvider. + + Args: + config: The DocPixie configuration object. + + Raises: + ValueError: If the OpenRouter API key is not provided. + ImportError: If the `openai` library is not installed. + """ super().__init__(config) if not config.openrouter_api_key: raise ValueError("OpenRouter API key is required") - # Import here to make it optional dependency try: from openai import AsyncOpenAI self.client = AsyncOpenAI( @@ -39,28 +65,38 @@ async def process_text_messages( max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process text-only messages through OpenRouter API""" + """ + Processes text-only messages through the OpenRouter API. + + Args: + messages: A list of message dictionaries. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. 
+ """ try: response = await self.client.chat.completions.create( model=self.config.model, messages=messages, max_tokens=max_tokens, temperature=temperature, - extra_body= { - "usage": { - "include": True, - }, + extra_body={ + "usage": {"include": True}, }, ) result = response.choices[0].message.content.strip() logger.debug(f"OpenRouter text response: {result[:50]}...") - # Track cost if available - if hasattr(response, 'usage') and hasattr(response.usage, 'cost'): - self.last_api_cost = response.usage.cost - self.total_cost += response.usage.cost - logger.debug(f"OpenRouter API cost: ${response.usage.cost}") + if hasattr(response, 'usage') and response.usage and hasattr(response.usage, 'total_cost'): + self.last_api_cost = response.usage.total_cost + self.total_cost += response.usage.total_cost + logger.debug(f"OpenRouter API cost: ${response.usage.total_cost}") else: self.last_api_cost = None @@ -76,31 +112,41 @@ async def process_multimodal_messages( max_tokens: int = 300, temperature: float = 0.3 ) -> str: - """Process multimodal messages (text + images) through OpenRouter API""" + """ + Processes multimodal messages through the OpenRouter API. + + Args: + messages: A list of message dictionaries, potentially including + image paths. + max_tokens: The maximum number of tokens to generate. + temperature: The sampling temperature. + + Returns: + The text response from the API. + + Raises: + ProviderError: If the API call fails. 
+ """ try: - # Process messages to convert image paths to data URLs processed_messages = self._prepare_openai_messages(messages) response = await self.client.chat.completions.create( - model=self.model, # Use vision model + model=self.model, messages=processed_messages, max_tokens=max_tokens, temperature=temperature, - extra_body= { - "usage": { - "include": True, - }, + extra_body={ + "usage": {"include": True}, }, ) result = response.choices[0].message.content.strip() logger.debug(f"OpenRouter multimodal response: {result[:50]}...") - # Track cost if available - if hasattr(response, 'usage') and hasattr(response.usage, 'cost'): - self.last_api_cost = response.usage.cost - self.total_cost += response.usage.cost - logger.debug(f"OpenRouter API cost: ${response.usage.cost}") + if hasattr(response, 'usage') and response.usage and hasattr(response.usage, 'total_cost'): + self.last_api_cost = response.usage.total_cost + self.total_cost += response.usage.total_cost + logger.debug(f"OpenRouter API cost: ${response.usage.total_cost}") else: self.last_api_cost = None @@ -111,22 +157,28 @@ async def process_multimodal_messages( raise ProviderError(f"Multimodal processing failed: {e}", "openrouter") def _prepare_openai_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Prepare messages for OpenRouter API by converting image paths to data URLs""" + """ + Prepares messages for the OpenRouter API. + + This method converts image paths to the data URL format, which is + compatible with OpenAI's API and, by extension, OpenRouter. + + Args: + messages: The list of message dictionaries. + + Returns: + A new list of messages formatted for the API. 
+ """ processed_messages = [] for message in messages: - if message["role"] == "system": - # System messages are text-only - processed_messages.append(message) - elif message["role"] == "user" and isinstance(message["content"], list): - # User message with multimodal content + if message["role"] == "user" and isinstance(message["content"], list): processed_content = [] for content_item in message["content"]: if content_item["type"] == "text": processed_content.append(content_item) elif content_item["type"] == "image_path": - # Convert image path to OpenRouter format (same as OpenAI) image_path = content_item["image_path"] if self._validate_image_path(image_path): image_data_url = self._create_image_data_url(image_path) @@ -140,7 +192,6 @@ def _prepare_openai_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[ else: logger.warning(f"Skipping invalid image path: {image_path}") else: - # Pass through other content types processed_content.append(content_item) processed_messages.append({ @@ -148,7 +199,6 @@ def _prepare_openai_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[ "content": processed_content }) else: - # Regular text message processed_messages.append(message) return processed_messages diff --git a/docpixie/storage/base.py b/docpixie/storage/base.py index 95af52b..e3c6940 100644 --- a/docpixie/storage/base.py +++ b/docpixie/storage/base.py @@ -1,5 +1,9 @@ """ -Base storage interface for documents +Base storage interface for documents. + +This module defines the abstract base class for storage backends in DocPixie. +It establishes a common interface for saving, retrieving, and managing +documents and their associated data. """ from abc import ABC, abstractmethod @@ -12,172 +16,182 @@ class BaseStorage(ABC): - """Base class for storage backends""" - + """ + An abstract base class for storage backends. + + This class defines the essential methods that any storage implementation + in DocPixie must provide. 
It ensures a consistent API for document + management, regardless of the underlying storage mechanism (e.g., local + filesystem, in-memory). + """ + @abstractmethod async def save_document(self, document: Document) -> str: """ - Save a processed document - + Saves a processed document to the storage. + Args: - document: Document to save - + document: The `Document` object to be saved. + Returns: - Document ID + The unique ID of the saved document. """ pass - + @abstractmethod async def get_document(self, document_id: str) -> Optional[Document]: """ - Retrieve a document by ID - + Retrieves a document by its unique ID. + Args: - document_id: ID of document to retrieve - + document_id: The ID of the document to retrieve. + Returns: - Document or None if not found + The `Document` object if found, otherwise `None`. """ pass - + @abstractmethod async def list_documents(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: """ - List all documents with metadata - + Lists all documents in the storage, returning their metadata. + Args: - limit: Maximum number of documents to return - + limit: The maximum number of documents to return. + Returns: - List of document metadata dicts + A list of dictionaries, where each dictionary represents the + metadata of a document. """ pass - + @abstractmethod async def delete_document(self, document_id: str) -> bool: """ - Delete a document and its associated files - + Deletes a document and its associated files from the storage. + Args: - document_id: ID of document to delete - + document_id: The ID of the document to delete. + Returns: - True if deletion was successful + `True` if the deletion was successful, `False` otherwise. """ pass - + @abstractmethod async def document_exists(self, document_id: str) -> bool: """ - Check if document exists - + Checks if a document with the given ID exists in the storage. + Args: - document_id: Document ID to check - + document_id: The ID of the document to check. 
+ Returns: - True if document exists + `True` if the document exists, `False` otherwise. """ pass - + @abstractmethod async def get_document_summary(self, document_id: str) -> Optional[str]: """ - Get document summary without loading full document - + Retrieves the summary of a document without loading the entire object. + Args: - document_id: Document ID - + document_id: The ID of the document. + Returns: - Document summary or None + The document summary as a string, or `None` if not found. """ pass - + @abstractmethod async def update_document_summary(self, document_id: str, summary: str) -> bool: """ - Update document summary - + Updates the summary of a specific document. + Args: - document_id: Document ID - summary: New summary text - + document_id: The ID of the document to update. + summary: The new summary text. + Returns: - True if update was successful + `True` if the update was successful, `False` otherwise. """ pass - + @abstractmethod async def get_all_documents(self) -> List[Document]: """ - Get all documents for agent processing - + Retrieves all documents from the storage. + + This method is primarily intended for use by the RAG agent for + processing. + Returns: - List of all documents in storage + A list of all `Document` objects in the storage. """ pass - + @abstractmethod async def get_all_pages(self) -> List[Page]: """ - Get all pages from all documents for agent processing - + Retrieves all pages from all documents in the storage. + + This method is primarily intended for use by the RAG agent. + Returns: - List of all pages across all documents + A list of all `Page` objects across all documents. """ pass - + async def get_documents_by_ids(self, document_ids: List[str]) -> List[Document]: """ - Get multiple documents by IDs - + Retrieves multiple documents based on a list of IDs. + Args: - document_ids: List of document IDs - + document_ids: A list of document IDs to retrieve. 
+ Returns: - List of documents (may be fewer than requested if some not found) + A list of `Document` objects. This list may be smaller than the + input list if some documents are not found. """ - documents = [] - for doc_id in document_ids: - doc = await self.get_document(doc_id) - if doc: - documents.append(doc) + documents = [doc for doc_id in document_ids if (doc := await self.get_document(doc_id))] return documents - + async def search_documents(self, query: str, limit: int = 10) -> List[Dict[str, Any]]: """ - Simple text search in document names and summaries - Default implementation - subclasses can override for better search - + Performs a simple text search on document names and summaries. + + This is a default implementation that can be overridden by subclasses + for more sophisticated search capabilities. + Args: - query: Search query - limit: Maximum results - + query: The search query string. + limit: The maximum number of results to return. + Returns: - List of matching document metadata + A list of matching document metadata dictionaries. """ all_docs = await self.list_documents() matching_docs = [] query_lower = query.lower() - + for doc_meta in all_docs: - name_match = query_lower in doc_meta.get('name', '').lower() - summary_match = query_lower in doc_meta.get('summary', '').lower() - - if name_match or summary_match: + if any(query_lower in str(doc_meta.get(field, '')).lower() for field in ['name', 'summary']): matching_docs.append(doc_meta) if len(matching_docs) >= limit: break - + return matching_docs - + def get_storage_stats(self) -> Dict[str, Any]: """ - Get storage statistics - Default implementation - subclasses can override - + Retrieves statistics about the storage backend. + + This is a default implementation that can be overridden by subclasses. + Returns: - Dictionary with storage statistics + A dictionary containing storage statistics. 
""" return { 'backend': self.__class__.__name__, @@ -186,8 +200,21 @@ def get_storage_stats(self) -> Dict[str, Any]: class StorageError(Exception): - """Exception raised by storage operations""" - + """ + A custom exception raised for errors related to storage operations. + + Attributes: + document_id: The ID of the document associated with the error, if + applicable. + """ + def __init__(self, message: str, document_id: Optional[str] = None): + """ + Initializes the StorageError. + + Args: + message: The error message. + document_id: The optional ID of the document being processed. + """ self.document_id = document_id super().__init__(message) \ No newline at end of file diff --git a/docpixie/storage/local.py b/docpixie/storage/local.py index ff4f7d0..1983dfb 100644 --- a/docpixie/storage/local.py +++ b/docpixie/storage/local.py @@ -1,6 +1,10 @@ """ -Local file system storage backend -Adapted from production LocalStorage but simplified for open-source version +Local file system storage backend. + +This module provides a concrete implementation of the `BaseStorage` interface +that uses the local file system to store documents, their pages, and metadata. +It is a simplified version of the `LocalStorage` used in the production +DocPixie application. """ import os @@ -20,345 +24,278 @@ class LocalStorage(BaseStorage): - """Local file system storage backend""" - + """ + A storage backend that uses the local file system. + + This class implements the `BaseStorage` interface, providing a way to + store and manage documents on the local disk. Each document is stored in + its own directory, containing a `metadata.json` file and a `pages` + subdirectory for the page images. + + Attributes: + config: The DocPixie configuration object. + base_path: The root directory for the local storage. + """ + def __init__(self, config: DocPixieConfig): + """ + Initializes the LocalStorage. + + Args: + config: The DocPixie configuration object. 
+ """ self.config = config self.base_path = Path(config.local_storage_path) self.base_path.mkdir(parents=True, exist_ok=True) logger.info(f"Initialized local storage at: {self.base_path}") - + def _doc_dir(self, document_id: str) -> Path: - """Get document directory path""" + """ + Returns the path to the directory for a specific document. + + Args: + document_id: The ID of the document. + + Returns: + The path to the document's directory. + """ return self.base_path / document_id - + def _metadata_path(self, document_id: str) -> Path: - """Get metadata file path""" + """ + Returns the path to the metadata file for a specific document. + + Args: + document_id: The ID of the document. + + Returns: + The path to the metadata JSON file. + """ return self._doc_dir(document_id) / "metadata.json" - + def _pages_dir(self, document_id: str) -> Path: - """Get pages directory path""" + """ + Returns the path to the pages directory for a specific document. + + Args: + document_id: The ID of the document. + + Returns: + The path to the directory containing the page images. 
+ """ return self._doc_dir(document_id) / "pages" - + async def save_document(self, document: Document) -> str: - """Save document to local storage""" + """Saves a document to the local storage.""" try: doc_dir = self._doc_dir(document.id) pages_dir = self._pages_dir(document.id) - - # Create directories doc_dir.mkdir(parents=True, exist_ok=True) pages_dir.mkdir(parents=True, exist_ok=True) - - # Copy page images to storage - stored_pages = [] - for page in document.pages: - if os.path.exists(page.image_path): - # Copy page image to storage - page_filename = f"page_{page.page_number:03d}.jpg" - dest_path = pages_dir / page_filename - - await asyncio.get_event_loop().run_in_executor( - None, shutil.copy2, page.image_path, dest_path - ) - - # Update page with new path - stored_page = Page( - page_number=page.page_number, - image_path=str(dest_path), - metadata=page.metadata, - document_name=page.document_name, - document_id=page.document_id - ) - stored_pages.append(stored_page) - else: - logger.warning(f"Page image not found: {page.image_path}") - - # Create metadata - metadata = { - "id": document.id, - "name": document.name, - "summary": document.summary, - "status": document.status.value, - "page_count": len(stored_pages), - "pages": [ - { - "page_number": page.page_number, - "image_path": page.image_path, - "metadata": page.metadata, - "document_name": page.document_name, - "document_id": page.document_id - } - for page in stored_pages - ], - "metadata": document.metadata, - "created_at": document.created_at.isoformat(), - "updated_at": datetime.now().isoformat() - } - - # Save metadata - metadata_path = self._metadata_path(document.id) - with open(metadata_path, 'w') as f: + + stored_pages = await self._store_page_images(document, pages_dir) + + metadata = self._create_document_metadata(document, stored_pages) + with open(self._metadata_path(document.id), 'w') as f: json.dump(metadata, f, indent=2) - + logger.info(f"Saved document {document.id} with 
{len(stored_pages)} pages") return document.id - + except Exception as e: logger.error(f"Failed to save document {document.id}: {e}") - # Clean up on error doc_dir = self._doc_dir(document.id) if doc_dir.exists(): shutil.rmtree(doc_dir, ignore_errors=True) raise StorageError(f"Failed to save document: {e}", document.id) - + async def get_document(self, document_id: str) -> Optional[Document]: - """Retrieve document from local storage""" + """Retrieves a document from the local storage.""" try: metadata_path = self._metadata_path(document_id) if not metadata_path.exists(): return None - - # Load metadata + with open(metadata_path, 'r') as f: metadata = json.load(f) - - # Reconstruct pages - pages = [] - for page_data in metadata.get('pages', []): - page = Page( - page_number=page_data['page_number'], - image_path=page_data['image_path'], - metadata=page_data.get('metadata', {}), - document_name=page_data.get('document_name'), - document_id=page_data.get('document_id') - ) - pages.append(page) - - # Reconstruct document - document = Document( - id=metadata['id'], - name=metadata['name'], - pages=pages, - summary=metadata.get('summary'), - metadata=metadata.get('metadata', {}), - created_at=datetime.fromisoformat(metadata['created_at']) - ) - - return document - + + return self._reconstruct_document(metadata) + except Exception as e: logger.error(f"Failed to load document {document_id}: {e}") raise StorageError(f"Failed to load document: {e}", document_id) - + async def list_documents(self, limit: Optional[int] = None) -> List[Dict[str, Any]]: - """List all documents""" + """Lists all documents in the local storage.""" try: documents = [] - if not self.base_path.exists(): return documents - - for doc_dir in self.base_path.iterdir(): - if not doc_dir.is_dir(): - continue - - metadata_path = doc_dir / "metadata.json" - if not metadata_path.exists(): - continue - + + doc_dirs = [d for d in self.base_path.iterdir() if d.is_dir()] + for doc_dir in doc_dirs: try: - with 
open(metadata_path, 'r') as f: - metadata = json.load(f) - - # Return summary info - doc_info = { - 'id': metadata['id'], - 'name': metadata['name'], - 'summary': metadata.get('summary'), - 'page_count': metadata.get('page_count', 0), - 'created_at': metadata['created_at'], - 'updated_at': metadata.get('updated_at'), - 'status': metadata.get('status', 'unknown') - } - documents.append(doc_info) - + metadata_path = doc_dir / "metadata.json" + if metadata_path.exists(): + with open(metadata_path, 'r') as f: + metadata = json.load(f) + documents.append(self._extract_document_info(metadata)) except Exception as e: logger.warning(f"Failed to read metadata for {doc_dir.name}: {e}") - continue if limit and len(documents) >= limit: break - - # Sort by creation time (newest first) + documents.sort(key=lambda x: x['created_at'], reverse=True) return documents - + except Exception as e: logger.error(f"Failed to list documents: {e}") raise StorageError(f"Failed to list documents: {e}") - + async def delete_document(self, document_id: str) -> bool: - """Delete document and all associated files""" + """Deletes a document from the local storage.""" try: doc_dir = self._doc_dir(document_id) if doc_dir.exists(): - await asyncio.get_event_loop().run_in_executor( - None, shutil.rmtree, doc_dir - ) + await asyncio.to_thread(shutil.rmtree, doc_dir) logger.info(f"Deleted document {document_id}") return True - else: - logger.warning(f"Document directory not found: {document_id}") - return False - + logger.warning(f"Document directory not found: {document_id}") + return False + except Exception as e: logger.error(f"Failed to delete document {document_id}: {e}") raise StorageError(f"Failed to delete document: {e}", document_id) - + async def document_exists(self, document_id: str) -> bool: - """Check if document exists""" - metadata_path = self._metadata_path(document_id) - return metadata_path.exists() - + """Checks if a document exists in the local storage.""" + return 
self._metadata_path(document_id).exists()
+
     async def get_document_summary(self, document_id: str) -> Optional[str]:
-        """Get document summary without loading full document"""
+        """Retrieves a document's summary from the local storage."""
         try:
             metadata_path = self._metadata_path(document_id)
             if not metadata_path.exists():
                 return None
-
             with open(metadata_path, 'r') as f:
                 metadata = json.load(f)
-
             return metadata.get('summary')
-
+
         except Exception as e:
             logger.error(f"Failed to get summary for {document_id}: {e}")
             return None
-
+
     async def update_document_summary(self, document_id: str, summary: str) -> bool:
-        """Update document summary"""
+        """Updates a document's summary in the local storage."""
         try:
             metadata_path = self._metadata_path(document_id)
             if not metadata_path.exists():
                 return False
-
-            # Load existing metadata
+
             with open(metadata_path, 'r') as f:
                 metadata = json.load(f)
-
-            # Update summary and timestamp
+
             metadata['summary'] = summary
             metadata['updated_at'] = datetime.now().isoformat()
-
-            # Save updated metadata
+
             with open(metadata_path, 'w') as f:
                 json.dump(metadata, f, indent=2)
-
+
             logger.info(f"Updated summary for document {document_id}")
             return True
-
+
         except Exception as e:
             logger.error(f"Failed to update summary for {document_id}: {e}")
             return False
-
-    def get_document_pages(self, document_id: str) -> List[str]:
-        """Get list of page image paths (synchronous helper method)"""
-        pages_dir = self._pages_dir(document_id)
-        if not pages_dir.exists():
-            return []
-
-        # Get all page files sorted by name
-        page_files = sorted([
-            f for f in pages_dir.iterdir()
-            if f.is_file() and f.name.lower().startswith("page_")
-        ])
-
-        return [str(f) for f in page_files]
-
+
     async def get_all_documents(self) -> List[Document]:
-        """Get all documents for agent processing"""
+        """Retrieves all documents from the local storage."""
         try:
-            documents = []
-
-            if not self.base_path.exists():
-                return documents
-
-            # Get all document directories
-            for doc_dir in
self.base_path.iterdir():
-                if doc_dir.is_dir():
-                    try:
-                        # Load the document
-                        document = await self.get_document(doc_dir.name)
-                        if document:
-                            documents.append(document)
-                    except Exception as e:
-                        logger.warning(f"Failed to load document {doc_dir.name}: {e}")
-                        continue
-
-            return documents
-
+            doc_ids = [d.name for d in self.base_path.iterdir() if d.is_dir()]
+            return [doc for doc_id in doc_ids if (doc := await self.get_document(doc_id))]
         except Exception as e:
             logger.error(f"Failed to get all documents: {e}")
             raise StorageError(f"Failed to get all documents: {e}")
-
+
     async def get_all_pages(self) -> List[Page]:
-        """Get all pages from all documents for agent processing"""
+        """Retrieves all pages from all documents in the local storage."""
         try:
-            all_pages = []
-
-            if not self.base_path.exists():
-                return all_pages
-
-            # Get all document directories
-            for doc_dir in self.base_path.iterdir():
-                if doc_dir.is_dir():
-                    try:
-                        # Load the document
-                        document = await self.get_document(doc_dir.name)
-                        if document and document.pages:
-                            all_pages.extend(document.pages)
-                    except Exception as e:
-                        logger.warning(f"Failed to load pages from document {doc_dir.name}: {e}")
-                        continue
-
-            return all_pages
-
+            all_documents = await self.get_all_documents()
+            return [page for doc in all_documents for page in doc.pages]
         except Exception as e:
             logger.error(f"Failed to get all pages: {e}")
             raise StorageError(f"Failed to get all pages: {e}")
-
+
     def get_storage_stats(self) -> Dict[str, Any]:
-        """Get storage statistics"""
+        """Retrieves statistics about the local storage."""
         try:
-            total_size = 0
-            total_documents = 0
-            total_pages = 0
-
-            if self.base_path.exists():
-                for doc_dir in self.base_path.iterdir():
-                    if doc_dir.is_dir():
-                        total_documents += 1
-                        for file_path in doc_dir.rglob('*'):
-                            if file_path.is_file():
-                                total_size += file_path.stat().st_size
-                                if file_path.name.startswith('page_'):
-                                    total_pages += 1
-
+            total_size = sum(f.stat().st_size for f in
self.base_path.glob('**/*') if f.is_file())
+            doc_dirs = [d for d in self.base_path.iterdir() if d.is_dir()]
+            total_pages = sum(len(list((d / 'pages').glob('page_*.jpg'))) for d in doc_dirs)
+
             return {
                 'backend': 'LocalStorage',
                 'base_path': str(self.base_path),
-                'total_documents': total_documents,
+                'total_documents': len(doc_dirs),
                 'total_pages': total_pages,
                 'total_size_bytes': total_size,
                 'total_size_mb': round(total_size / (1024 * 1024), 2),
                 'features': ['local_storage', 'metadata', 'page_images']
             }
-
         except Exception as e:
             logger.error(f"Failed to get storage stats: {e}")
-            return {
-                'backend': 'LocalStorage',
-                'error': str(e)
-            }
\ No newline at end of file
+            return {'backend': 'LocalStorage', 'error': str(e)}
+
+    # Helper methods
+    async def _store_page_images(self, document: Document, pages_dir: Path) -> List[Page]:
+        stored_pages = []
+        for page in document.pages:
+            if os.path.exists(page.image_path):
+                dest_path = pages_dir / f"page_{page.page_number:03d}.jpg"
+                await asyncio.to_thread(shutil.copy2, page.image_path, dest_path)
+                stored_pages.append(Page(
+                    page_number=page.page_number,
+                    image_path=str(dest_path),
+                    metadata=page.metadata,
+                    document_name=page.document_name,
+                    document_id=page.document_id
+                ))
+            else:
+                logger.warning(f"Page image not found: {page.image_path}")
+        return stored_pages
+
+    def _create_document_metadata(self, document: Document, stored_pages: List[Page]) -> Dict[str, Any]:
+        return {
+            "id": document.id,
+            "name": document.name,
+            "summary": document.summary,
+            "status": document.status.value,
+            "page_count": len(stored_pages),
+            "pages": [page.__dict__ for page in stored_pages],
+            "metadata": document.metadata,
+            "created_at": document.created_at.isoformat(),
+            "updated_at": datetime.now().isoformat()
+        }
+
+    def _reconstruct_document(self, metadata: Dict[str, Any]) -> Document:
+        pages = [Page(**page_data) for page_data in metadata.get('pages', [])]
+        return Document(
+            id=metadata['id'],
+            name=metadata['name'],
+            pages=pages,
+            summary=metadata.get('summary'),
+            metadata=metadata.get('metadata', {}),
+            created_at=datetime.fromisoformat(metadata['created_at'])
+        )
+
+    def _extract_document_info(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            'id': metadata['id'],
+            'name': metadata['name'],
+            'summary': metadata.get('summary'),
+            'page_count': metadata.get('page_count', 0),
+            'created_at': metadata['created_at'],
+            'updated_at': metadata.get('updated_at'),
+            'status': metadata.get('status', 'unknown')
+        }
\ No newline at end of file
diff --git a/docpixie/storage/memory.py b/docpixie/storage/memory.py
index cb19bb6..2e1b23c 100644
--- a/docpixie/storage/memory.py
+++ b/docpixie/storage/memory.py
@@ -1,5 +1,9 @@
 """
-In-memory storage backend for testing
+In-memory storage backend for testing and development.
+
+This module provides a non-persistent, in-memory implementation of the
+`BaseStorage` interface. It is primarily used for testing, development, and
+scenarios where data persistence is not required.
 """
 import asyncio
@@ -16,212 +20,173 @@
 class InMemoryStorage(BaseStorage):
-    """In-memory storage backend for testing and development"""
-
+    """
+    An in-memory storage backend for documents.
+
+    This class provides a simple, non-persistent storage solution that keeps
+    all documents and their data in memory. It is ideal for testing and
+    development purposes.
+
+    Attributes:
+        config: The DocPixie configuration object.
+    """
+
     def __init__(self, config: DocPixieConfig):
+        """
+        Initializes the InMemoryStorage.
+
+        Args:
+            config: The DocPixie configuration object.
+        """
         self.config = config
         self._documents: Dict[str, Document] = {}
         self._document_summaries: Dict[str, str] = {}
         self._created_at = datetime.now()
         logger.info("Initialized in-memory storage")
-
+
     async def save_document(self, document: Document) -> str:
-        """Save document to memory"""
+        """Saves a document to the in-memory store."""
         try:
-            # Deep copy to avoid external modifications
             stored_document = copy.deepcopy(document)
-
-            # Store document
             self._documents[document.id] = stored_document
-
-            # Store summary separately for quick access
             if document.summary:
                 self._document_summaries[document.id] = document.summary
-
             logger.info(f"Saved document {document.id} to memory ({len(document.pages)} pages)")
             return document.id
-
         except Exception as e:
             logger.error(f"Failed to save document {document.id} to memory: {e}")
             raise StorageError(f"Failed to save document: {e}", document.id)
-
+
     async def get_document(self, document_id: str) -> Optional[Document]:
-        """Retrieve document from memory"""
+        """Retrieves a document from the in-memory store."""
         try:
             document = self._documents.get(document_id)
-            if document:
-                # Return a deep copy to avoid external modifications
-                return copy.deepcopy(document)
-            return None
-
+            return copy.deepcopy(document) if document else None
         except Exception as e:
             logger.error(f"Failed to get document {document_id} from memory: {e}")
             raise StorageError(f"Failed to get document: {e}", document_id)
-
+
     async def list_documents(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
-        """List all documents in memory"""
+        """Lists all documents in the in-memory store."""
         try:
-            documents = []
-
-            for doc_id, document in self._documents.items():
-                doc_info = {
-                    'id': document.id,
-                    'name': document.name,
-                    'summary': self._document_summaries.get(doc_id),
-                    'page_count': len(document.pages),
-                    'created_at': document.created_at.isoformat(),
-                    'updated_at': document.created_at.isoformat(),  # No update tracking in memory
-                    'status':
document.status.value
-                }
-                documents.append(doc_info)
-
-                if limit and len(documents) >= limit:
-                    break
-
-            # Sort by creation time (newest first)
+            documents = [
+                self._create_doc_info(doc_id, doc)
+                for doc_id, doc in self._documents.items()
+            ]
             documents.sort(key=lambda x: x['created_at'], reverse=True)
-            return documents
-
+            return documents[:limit] if limit else documents
         except Exception as e:
             logger.error(f"Failed to list documents in memory: {e}")
             raise StorageError(f"Failed to list documents: {e}")
-
+
     async def delete_document(self, document_id: str) -> bool:
-        """Delete document from memory"""
+        """Deletes a document from the in-memory store."""
         try:
             if document_id in self._documents:
                 del self._documents[document_id]
                 self._document_summaries.pop(document_id, None)
                 logger.info(f"Deleted document {document_id} from memory")
                 return True
-            else:
-                logger.warning(f"Document {document_id} not found in memory")
-                return False
-
+            logger.warning(f"Document {document_id} not found in memory")
+            return False
         except Exception as e:
             logger.error(f"Failed to delete document {document_id} from memory: {e}")
             raise StorageError(f"Failed to delete document: {e}", document_id)
-
+
     async def document_exists(self, document_id: str) -> bool:
-        """Check if document exists in memory"""
+        """Checks if a document exists in the in-memory store."""
         return document_id in self._documents
-
+
     async def get_document_summary(self, document_id: str) -> Optional[str]:
-        """Get document summary from memory"""
+        """Retrieves a document's summary from the in-memory store."""
         return self._document_summaries.get(document_id)
-
+
     async def update_document_summary(self, document_id: str, summary: str) -> bool:
-        """Update document summary in memory"""
+        """Updates a document's summary in the in-memory store."""
         try:
             if document_id in self._documents:
-                # Update summary in both document and summary cache
                 self._documents[document_id].summary = summary
                 self._document_summaries[document_id] =
summary
                 logger.info(f"Updated summary for document {document_id} in memory")
                 return True
-            else:
-                logger.warning(f"Document {document_id} not found for summary update")
-                return False
-
+            logger.warning(f"Document {document_id} not found for summary update")
+            return False
         except Exception as e:
             logger.error(f"Failed to update summary for {document_id} in memory: {e}")
             return False
-
+
     async def search_documents(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
-        """Search documents in memory"""
+        """Performs a simple text search on documents in memory."""
         try:
             matching_docs = []
             query_lower = query.lower()
-
             for doc_id, document in self._documents.items():
-                # Check name match
-                name_match = query_lower in document.name.lower()
-
-                # Check summary match
                 summary = self._document_summaries.get(doc_id, '')
-                summary_match = query_lower in summary.lower()
-
-                if name_match or summary_match:
-                    doc_info = {
-                        'id': document.id,
-                        'name': document.name,
-                        'summary': summary,
-                        'page_count': len(document.pages),
-                        'created_at': document.created_at.isoformat(),
-                        'status': document.status.value,
-                        'relevance_score': self._calculate_relevance(
-                            query_lower, document, summary
-                        )
-                    }
+                if query_lower in document.name.lower() or query_lower in summary.lower():
+                    doc_info = self._create_doc_info(doc_id, document)
+                    doc_info['relevance_score'] = self._calculate_relevance(query_lower, document, summary)
                     matching_docs.append(doc_info)
-
-                    if len(matching_docs) >= limit:
-                        break
-
-            # Sort by relevance score
+
             matching_docs.sort(key=lambda x: x['relevance_score'], reverse=True)
-            return matching_docs
-
+            return matching_docs[:limit]
         except Exception as e:
             logger.error(f"Failed to search documents in memory: {e}")
             return []
-
-    def _calculate_relevance(self, query: str, document: Document, summary: str) -> float:
-        """Calculate simple relevance score for search results"""
-        score = 0.0
-
-        # Name matches are highly relevant
-        if query in document.name.lower():
-            score
+= 10.0
-
-        # Summary matches are relevant
-        summary_matches = summary.lower().count(query)
-        score += summary_matches * 2.0
-
-
-        return score
-
+
     async def get_all_documents(self) -> List[Document]:
-        """Get all documents for agent processing"""
-        return list(self._documents.values())
-
+        """Retrieves all documents from the in-memory store."""
+        return list(copy.deepcopy(list(self._documents.values())))
+
     async def get_all_pages(self) -> List[Page]:
-        """Get all pages from all documents for agent processing"""
-        all_pages = []
-        for document in self._documents.values():
-            if document.pages:
-                all_pages.extend(document.pages)
-        return all_pages
-
+        """Retrieves all pages from all documents in the in-memory store."""
+        return [
+            page
+            for document in self._documents.values()
+            for page in (document.pages or [])
+        ]
+
     def get_storage_stats(self) -> Dict[str, Any]:
-        """Get storage statistics"""
+        """Retrieves statistics about the in-memory storage."""
         try:
-            total_pages = sum(len(doc.pages) for doc in self._documents.values())
-
             return {
                 'backend': 'InMemoryStorage',
-                'total_documents': len(self._documents),
-                'total_pages': total_pages,
+                'total_documents': self.get_document_count(),
+                'total_pages': self.get_total_pages(),
                 'created_at': self._created_at.isoformat(),
                 'features': ['in_memory', 'fast_access', 'search', 'testing']
             }
-
         except Exception as e:
-            return {
-                'backend': 'InMemoryStorage',
-                'error': str(e)
-            }
-
+            return {'backend': 'InMemoryStorage', 'error': str(e)}
+
     def clear_all(self):
-        """Clear all documents (useful for testing)"""
+        """Clears all documents from the in-memory store."""
         self._documents.clear()
         self._document_summaries.clear()
         logger.info("Cleared all documents from memory")
-
+
     def get_document_count(self) -> int:
-        """Get total number of documents in memory"""
+        """Returns the total number of documents in memory."""
         return len(self._documents)
-
+
     def get_total_pages(self) -> int:
-        """Get total number of pages across all
documents"""
-        return sum(len(doc.pages) for doc in self._documents.values())
\ No newline at end of file
+        """Returns the total number of pages across all documents."""
+        return sum(len(doc.pages) for doc in self._documents.values())
+
+    # Helper methods
+    def _create_doc_info(self, doc_id: str, document: Document) -> Dict[str, Any]:
+        return {
+            'id': document.id,
+            'name': document.name,
+            'summary': self._document_summaries.get(doc_id),
+            'page_count': len(document.pages),
+            'created_at': document.created_at.isoformat(),
+            'updated_at': document.created_at.isoformat(),  # No update tracking in memory
+            'status': document.status.value
+        }
+
+    def _calculate_relevance(self, query: str, document: Document, summary: str) -> float:
+        score = 0.0
+        if query in document.name.lower():
+            score += 10.0
+        score += summary.lower().count(query) * 2.0
+        return score
\ No newline at end of file
diff --git a/docpixie/utils/async_helpers.py b/docpixie/utils/async_helpers.py
index 79c94e4..66f0b98 100644
--- a/docpixie/utils/async_helpers.py
+++ b/docpixie/utils/async_helpers.py
@@ -1,10 +1,13 @@
 """
-Async/sync compatibility helpers
+Utilities for managing asynchronous and synchronous code compatibility.
+
+This module provides helper functions and decorators to facilitate the use of
+asynchronous functions in synchronous contexts, and vice versa.
 """
 import asyncio
 import threading
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar, Callable
 from functools import wraps
 T = TypeVar('T')
@@ -12,26 +15,31 @@
 def sync_wrapper(coro: Awaitable[T]) -> T:
     """
-    Run async function in sync context
-    Handles both cases: existing event loop and no event loop
+    Runs an awaitable in a synchronous context.
+
+    This function handles the complexity of running an async function from a
+    sync function, including cases where an event loop is already running in
+    the current thread.
+
+    Args:
+        coro: The awaitable to run.
+    Returns:
+        The result of the awaitable.
     """
     try:
-        # Try to get the current event loop
-        loop = asyncio.get_running_loop()
-        # We're in an async context, need to run in a new thread
+        asyncio.get_running_loop()
         return _run_in_thread(coro)
     except RuntimeError:
-        # No running event loop, safe to use asyncio.run
         return asyncio.run(coro)
 def _run_in_thread(coro: Awaitable[T]) -> T:
-    """Run coroutine in a separate thread with its own event loop"""
-    result = {"value": None, "exception": None}
-
+    """Runs a coroutine in a new thread with its own event loop."""
+    result: Dict[str, Any] = {"value": None, "exception": None}
+
     def thread_target():
         try:
-            # Create new event loop for this thread
             new_loop = asyncio.new_event_loop()
             asyncio.set_event_loop(new_loop)
             result["value"] = new_loop.run_until_complete(coro)
@@ -39,40 +47,55 @@ def thread_target():
             result["exception"] = e
         finally:
             new_loop.close()
-
+
     thread = threading.Thread(target=thread_target)
     thread.start()
     thread.join()
-
+
     if result["exception"]:
         raise result["exception"]
-
+
     return result["value"]
-def ensure_async(func):
+def ensure_async(func: Callable) -> Callable[..., Awaitable[Any]]:
     """
-    Decorator to ensure function is async-compatible
-    If the function is sync, wrap it to run in thread pool
+    A decorator to ensure a function is awaitable.
+
+    If the decorated function is synchronous, it is wrapped to run in an
+    executor, making it awaitable. If it is already a coroutine function, it is
+    returned unchanged.
+
+    Args:
+        func: The function to decorate.
+
+    Returns:
+        An awaitable version of the function.
     """
     if asyncio.iscoroutinefunction(func):
         return func
-
+
     @wraps(func)
     async def async_wrapper(*args, **kwargs):
         loop = asyncio.get_event_loop()
         return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
-
+
     return async_wrapper
-def make_sync_version(async_func):
+def make_sync_version(async_func: Callable[..., Awaitable[T]]) -> Callable[..., T]:
     """
-    Create a synchronous version of an async function
+    Creates a synchronous version of an asynchronous function.
+
+    Args:
+        async_func: The asynchronous function to wrap.
+
+    Returns:
+        A new function that is a synchronous version of the input function.
     """
     @wraps(async_func)
-    def sync_version(*args, **kwargs):
+    def sync_version(*args, **kwargs) -> T:
         coro = async_func(*args, **kwargs)
         return sync_wrapper(coro)
-
+
     return sync_version
\ No newline at end of file