diff --git a/.env.gui b/.env.gui new file mode 100644 index 0000000..c488b0a --- /dev/null +++ b/.env.gui @@ -0,0 +1,4 @@ +RELEASE=test +VERSION=1 +BUILD=1 +FIX=0 diff --git a/.github/workflows/ci-build-image.yml b/.github/workflows/ci-build-image.yml new file mode 100644 index 0000000..b8588cb --- /dev/null +++ b/.github/workflows/ci-build-image.yml @@ -0,0 +1,43 @@ +name: Build and publish GUI + +on: + push: + branches: + - wip + paths: + - '.env.gui' + +jobs: + PackageDeploy: + runs-on: ubuntu-22.04 + + steps: + - uses: actions/checkout@v4 + + - name: Docker Setup BuildX + uses: docker/setup-buildx-action@v2 + + - name: Load environment variables and set them + run: | + if [ -f .env.gui ]; then + export $(cat .env.gui | grep -v '^#' | xargs) + fi + echo "RELEASE=$RELEASE" >> $GITHUB_ENV + echo "VERSION=$VERSION" >> $GITHUB_ENV + echo "BUILD=$BUILD" >> $GITHUB_ENV + echo "FIX=$FIX" >> $GITHUB_ENV + - name: Set repo + run: | + LOWER_CASE_GITHUB_REPOSITORY=$(echo $GITHUB_REPOSITORY | tr '[:upper:]' '[:lower:]') + echo "DOCKER_TAG_CUSTOM=ghcr.io/${LOWER_CASE_GITHUB_REPOSITORY}:$RELEASE-$VERSION.$BUILD.$FIX" >> $GITHUB_ENV + echo "$GITHUB_ENV" + - name: Docker Build + run: | + cd GUI + docker image build --tag $DOCKER_TAG_CUSTOM -f Dockerfile.dev . + + - name: Log in to GitHub container registry + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin + + - name: Push Docker image to ghcr + run: docker push $DOCKER_TAG_CUSTOM diff --git a/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md b/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md index 15669e4..398299a 100644 --- a/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md +++ b/docs/TOOL_CLASSIFIER_AND_SERVICE_WORKFLOW.md @@ -244,9 +244,9 @@ intent_result = intent_module.forward(...) 
# After LLM call usage_info = get_lm_usage_since(history_length_before) -costs_dict["intent_detection"] = usage_info +costs_metric["intent_detection"] = usage_info -# Later: orchestration_service.log_costs(costs_dict) +# Later: orchestration_service.log_costs(costs_metric) ``` --- @@ -557,14 +557,14 @@ Service workflow tracks LLM costs following the RAG workflow pattern: ```python # Create costs dict at workflow level -costs_dict: Dict[str, Dict[str, Any]] = {} +costs_metric: Dict[str, Dict[str, Any]] = {} # Intent detection captures costs intent_result, intent_usage = await _detect_service_intent(...) -costs_dict["intent_detection"] = intent_usage +costs_metric["intent_detection"] = intent_usage # Log costs after workflow completes -orchestration_service.log_costs(costs_dict) +orchestration_service.log_costs(costs_metric) ``` **Cost Breakdown Logged:** diff --git a/docs/TOOL_CLASSIFIER_EXTENSION_SPEC.md b/docs/TOOL_CLASSIFIER_EXTENSION_SPEC.md index 469f809..38d8189 100644 --- a/docs/TOOL_CLASSIFIER_EXTENSION_SPEC.md +++ b/docs/TOOL_CLASSIFIER_EXTENSION_SPEC.md @@ -425,7 +425,7 @@ formatted_content = format_service_response(service_response) # Apply output guardrails if guardrails_adapter: output_check = await guardrails_adapter.check_output_async(formatted_content) - costs_dict["output_guardrails"] = output_check.usage + costs_metric["output_guardrails"] = output_check.usage if not output_check.allowed: logger.warning(f"Service response blocked by guardrails: {output_check.reason}") @@ -449,7 +449,7 @@ formatted_content = format_service_response(service_response) # Apply output guardrails validation if guardrails_adapter: output_check = await guardrails_adapter.check_output_async(formatted_content) - costs_dict["output_guardrails"] = output_check.usage + costs_metric["output_guardrails"] = output_check.usage if not output_check.allowed: logger.warning(f"Service response blocked by guardrails") @@ -791,7 +791,7 @@ async def execute_context_workflow( request: 
OrchestrationRequest, llm_manager: LLMManager, guardrails_adapter: Optional[NeMoRailsAdapter], - costs_dict: Dict + costs_metric: Dict ) -> Optional[OrchestrationResponse]: """ Execute context-based response workflow with output guardrails. @@ -807,7 +807,7 @@ async def execute_context_workflow( ) # Track costs - costs_dict["context_check"] = get_lm_usage_since(history_before) + costs_metric["context_check"] = get_lm_usage_since(history_before) if (context_result.is_greeting or context_result.can_answer_from_context) and context_result.answer: logger.info( @@ -820,7 +820,7 @@ async def execute_context_workflow( output_check = await guardrails_adapter.check_output_async( context_result.answer ) - costs_dict["output_guardrails"] = output_check.usage + costs_metric["output_guardrails"] = output_check.usage if not output_check.allowed: logger.warning( @@ -852,7 +852,7 @@ async def execute_context_workflow_streaming( request: OrchestrationRequest, llm_manager: LLMManager, guardrails_adapter: Optional[NeMoRailsAdapter], - costs_dict: Dict + costs_metric: Dict ) -> Optional[AsyncIterator[str]]: """ Execute context workflow with streaming support and output guardrails. 
@@ -871,7 +871,7 @@ async def execute_context_workflow_streaming( ) # Track costs - costs_dict["context_check"] = get_lm_usage_since(history_before) + costs_metric["context_check"] = get_lm_usage_since(history_before) if (context_result.is_greeting or context_result.can_answer_from_context) and context_result.answer: logger.info( @@ -884,7 +884,7 @@ async def execute_context_workflow_streaming( output_check = await guardrails_adapter.check_output_async( context_result.answer ) - costs_dict["output_guardrails"] = output_check.usage + costs_metric["output_guardrails"] = output_check.usage if not output_check.allowed: logger.warning( @@ -941,17 +941,17 @@ def split_into_tokens(text: str, chunk_size: int = 5) -> List[str]: ```python try: result = await execute_context_workflow( - request, llm_manager, guardrails_adapter, costs_dict + request, llm_manager, guardrails_adapter, costs_metric ) if result: return result # Context-based answer (validated) else: # Move to Layer 3 (RAG) - return await execute_rag_workflow(request, components, costs_dict) + return await execute_rag_workflow(request, components, costs_metric) except Exception as e: logger.error(f"Context workflow failed: {e}") # Fallback to RAG workflow - return await execute_rag_workflow(request, components, costs_dict) + return await execute_rag_workflow(request, components, costs_metric) ``` **Guardrail Violation Fallback:** @@ -963,7 +963,7 @@ if not output_check.allowed: # Option 2: Fallback to RAG (alternative approach) if not output_check.allowed: logger.warning("Context response blocked, trying RAG workflow") - return await execute_rag_workflow(request, components, costs_dict) + return await execute_rag_workflow(request, components, costs_metric) ``` --- @@ -978,7 +978,7 @@ if not output_check.allowed: ```python # Reuse existing RAG pipeline return self._execute_orchestration_pipeline( - request, components, costs_dict, timing_dict + request, components, costs_metric, time_metric ) ``` @@ -1121,7 +1121,7 
@@ if context_result.can_answer_from_context: - **Pre-validation**: Get complete response → Validate → Stream to client - **Complete response**: Already have full text before streaming starts - **Uni-directional**: Simply chunk and send validated response -- **Cost**: Separate validation call tracked in `costs_dict["output_guardrails"]` +- **Cost**: Separate validation call tracked in `costs_metric["output_guardrails"]` - **UX Consistency**: Simulates streaming to match RAG workflow behavior ### Why Different Approaches? @@ -1601,15 +1601,15 @@ CREATE INDEX idx_classifier_decisions_workflow **Add tracking for new LLM calls:** # Service workflow - intent detection -costs_dict["intent_detection"] = { +costs_metric["intent_detection"] = { "total_prompt_tokens": usage.prompt_tokens, "total_completion_tokens": usage.completion_tokens, "total_cost": calculate_cost(usage) } # Context workflow - context availability check -costs_dict["context_check -costs_dict["intent_detection"] = { +costs_metric["context_check +costs_metric["intent_detection"] = { "total_prompt_tokens": usage.prompt_tokens, "total_completion_tokens": usage.completion_tokens, "total_cost": calculate_cost(usage) @@ -1663,7 +1663,7 @@ async def stream_validated_response( response_text: str, guardrails_adapter: NeMoRailsAdapter, request: OrchestrationRequest, - costs_dict: Dict + costs_metric: Dict ) -> AsyncIterator[str]: """ Apply output guardrails and stream validated response. 
@@ -1677,7 +1677,7 @@ async def stream_validated_response( output_check = await guardrails_adapter.check_output_async(response_text) # Track costs - costs_dict["output_guardrails"] = output_check.usage + costs_metric["output_guardrails"] = output_check.usage if not output_check.allowed: logger.warning(f"[{request.chatId}] Output blocked by guardrails") diff --git a/src/contextual_retrieval/contextual_retrieval.md b/src/contextual_retrieval/contextual_retrieval.md index f80d6aa..ce3446c 100644 --- a/src/contextual_retrieval/contextual_retrieval.md +++ b/src/contextual_retrieval/contextual_retrieval.md @@ -788,7 +788,7 @@ def _initialize_contextual_retriever( #### 2. Request Processing ```python # Main orchestration pipeline -def _execute_orchestration_pipeline(self, request, components, costs_dict): +def _execute_orchestration_pipeline(self, request, components, costs_metric): # Step 1: Refine user prompt refined_output = self._refine_user_prompt(...) diff --git a/src/guardrails/readme.md b/src/guardrails/readme.md index 0a51315..7a69e93 100644 --- a/src/guardrails/readme.md +++ b/src/guardrails/readme.md @@ -180,7 +180,7 @@ result.usage = usage_info # Contains: total_cost, tokens, num_calls ### Modified Pipeline in `llm_orchestration_service.py` ```python -costs_dict = { +costs_metric = { "input_guardrails": {...}, # Step 1 "prompt_refiner": {...}, # Step 2 "response_generator": {...}, # Step 4 diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index e2eb0c9..0d32941 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -26,7 +26,6 @@ from src.response_generator.response_generate import ResponseGeneratorAgent from src.response_generator.response_generate import stream_response_native from src.llm_orchestrator_config.llm_ochestrator_constants import ( - OUT_OF_SCOPE_MESSAGE, OUT_OF_SCOPE_MESSAGES, TECHNICAL_ISSUE_MESSAGE, TECHNICAL_ISSUE_MESSAGES, @@ -67,7 +66,7 @@ class LangfuseConfig: """Configuration 
for Langfuse integration.""" - def __init__(self): + def __init__(self) -> None: self.langfuse_client: Optional[Langfuse] = None self._initialize_langfuse() @@ -134,9 +133,70 @@ def __init__(self) -> None: # This allows components to be initialized per-request with proper context self.tool_classifier = None + # Initialize shared guardrails adapters at startup (production and testing) + self.shared_guardrails_adapters = ( + self._initialize_shared_guardrails_at_startup() + ) + # Log feature flag configuration FeatureFlags.log_configuration() + def _initialize_shared_guardrails_at_startup(self) -> Dict[str, NeMoRailsAdapter]: + """ + Initialize shared guardrails adapters at startup for production and testing environments. + + Returns: + Dictionary mapping environment names to NeMoRailsAdapter instances. + Empty dict on failure (graceful degradation). + """ + adapters: Dict[str, NeMoRailsAdapter] = {} + + # Initialize adapters for commonly-used environments + environments_to_initialize = ["production", "testing"] + + logger.info(" Initializing shared guardrails at startup...") + total_start_time = time.time() + + for env in environments_to_initialize: + try: + logger.info(f" Initializing guardrails for environment: {env}") + start_time = time.time() + + # Initialize with specific environment and no connection (shared config) + guardrails_adapter = self._initialize_guardrails( + environment=env, + connection_id=None, # Shared configuration, not user-specific + ) + + elapsed_time = time.time() - start_time + adapters[env] = guardrails_adapter + logger.info( + f" Guardrails for '{env}' initialized successfully in {elapsed_time:.3f}s" + ) + + except Exception as e: + logger.error(f" Failed to initialize guardrails for '{env}': {e}") + logger.warning( + f" Service will fall back to per-request initialization for '{env}' environment" + ) + # Continue with other environments - partial success is acceptable + continue + + total_elapsed = time.time() - total_start_time + + if 
adapters: + logger.info( + f" Shared guardrails initialized for {len(adapters)} environment(s) " + f"in {total_elapsed:.3f}s total" + ) + else: + logger.error( + " Failed to initialize any shared guardrails - " + "service will use per-request initialization (slower)" + ) + + return adapters + @observe(name="orchestration_request", as_type="agent") async def process_orchestration_request( self, request: OrchestrationRequest @@ -161,8 +221,8 @@ async def process_orchestration_request( Raises: Exception: For any processing errors """ - costs_dict: Dict[str, Dict[str, Any]] = {} - timing_dict: Dict[str, float] = {} + costs_metric: Dict[str, Dict[str, Any]] = {} + time_metric: Dict[str, float] = {} try: logger.info( @@ -170,9 +230,11 @@ async def process_orchestration_request( f"authorId: {request.authorId}, environment: {request.environment}" ) - # STEP 0: Detect language from user message + # STEP 0: Detect language from user message (with timing) + start_time = time.time() detected_language = detect_language(request.message) language_name = get_language_name(detected_language) + time_metric["language_detection"] = time.time() - start_time logger.info( f"[{request.chatId}] Detected language: {language_name} ({detected_language})" ) @@ -182,7 +244,9 @@ async def process_orchestration_request( setattr(request, "_detected_language", detected_language) # STEP 0.5: Basic Query Validation (before expensive component initialization) + start_time = time.time() validation_result = validate_query_basic(request.message) + time_metric["query_validation"] = time.time() - start_time if not validation_result.is_valid: logger.info( f"[{request.chatId}] Query validation failed: {validation_result.rejection_reason}" @@ -210,8 +274,30 @@ async def process_orchestration_request( content=validation_msg, ) - # Initialize all service components (only for valid queries) + # Initialize all service components (only for valid queries, with timing) + start_time = time.time() components = 
self._initialize_service_components(request) + time_metric["initialization"] = time.time() - start_time + + if components["guardrails_adapter"]: + start_time = time.time() + input_blocked_response = await self.handle_input_guardrails( + components["guardrails_adapter"], request, {} + ) + time_metric["input_guardrails_check"] = time.time() - start_time + + if input_blocked_response: + logger.warning( + f"[{request.chatId}] Input blocked before classifier - " + f"saved expensive service discovery" + ) + log_step_timings(time_metric, request.chatId) + return input_blocked_response + else: + logger.info( + f"[{request.chatId}] Guardrails not available - " + f"proceeding without input validation" + ) # TOOL CLASSIFIER INTEGRATION # Route through tool classifier if enabled, otherwise use existing RAG pipeline @@ -229,24 +315,29 @@ async def process_orchestration_request( ) logger.info("Tool classifier initialized") - # Classify query to determine workflow + # Classify query to determine workflow (with timing) + start_time = time.time() classification = await self.tool_classifier.classify( query=request.message, conversation_history=request.conversationHistory, language=detected_language, ) + time_metric["classifier.classify"] = time.time() - start_time logger.info( f"[{request.chatId}] Classification: {classification.workflow.value} " f"(confidence: {classification.confidence:.2f})" ) - # Route to appropriate workflow + # Route to appropriate workflow (with timing) + start_time = time.time() response = await self.tool_classifier.route_to_workflow( classification=classification, request=request, is_streaming=False, + time_metric=time_metric, ) + time_metric["classifier.route"] = time.time() - start_time except Exception as classifier_error: logger.error( @@ -260,7 +351,7 @@ async def process_orchestration_request( ) # Execute existing RAG pipeline as fallback response = await self._execute_orchestration_pipeline( - request, components, costs_dict, timing_dict + request, 
components, costs_metric, time_metric ) else: raise @@ -270,27 +361,27 @@ async def process_orchestration_request( f"[{request.chatId}] Tool classifier disabled - using RAG pipeline" ) response = await self._execute_orchestration_pipeline( - request, components, costs_dict, timing_dict + request, components, costs_metric, time_metric ) # Log final costs and return response - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget for the LLM connection self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client - total_costs = calculate_total_costs(costs_dict) + total_costs = calculate_total_costs(costs_metric) total_input_tokens = sum( - c.get("total_prompt_tokens", 0) for c in costs_dict.values() + c.get("total_prompt_tokens", 0) for c in costs_metric.values() ) total_output_tokens = sum( - c.get("total_completion_tokens", 0) for c in costs_dict.values() + c.get("total_completion_tokens", 0) for c in costs_metric.values() ) langfuse.update_current_generation( @@ -307,7 +398,7 @@ async def process_orchestration_request( }, metadata={ "total_calls": total_costs.get("total_calls", 0), - "cost_breakdown": costs_dict, + "cost_breakdown": costs_metric, "chat_id": request.chatId, "author_id": request.authorId, "environment": request.environment, @@ -331,12 +422,12 @@ async def process_orchestration_request( } ) langfuse.flush() - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget even on error self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) return 
self._create_error_response(request) @@ -379,12 +470,14 @@ async def stream_orchestration_response( """ # Track costs after streaming completes - costs_dict: Dict[str, Dict[str, Any]] = {} - timing_dict: Dict[str, float] = {} + costs_metric: Dict[str, Dict[str, Any]] = {} + time_metric: Dict[str, float] = {} - # STEP 0: Detect language from user message + # STEP 0: Detect language from user message (with timing) + start_time = time.time() detected_language = detect_language(request.message) language_name = get_language_name(detected_language) + time_metric["language_detection"] = time.time() - start_time logger.info( f"[{request.chatId}] Streaming request - Detected language: {language_name} ({detected_language})" ) @@ -393,8 +486,10 @@ async def stream_orchestration_response( # Using setattr for type safety - adds dynamic attribute to Pydantic model instance setattr(request, "_detected_language", detected_language) - # Step 0.5: Basic Query Validation (before guardrails) + # Step 0.5: Basic Query Validation (before guardrails, with timing) + start_time = time.time() validation_result = validate_query_basic(request.message) + time_metric["query_validation"] = time.time() - start_time if not validation_result.is_valid: logger.info( f"[{request.chatId}] Streaming - Query validation failed: {validation_result.rejection_reason}" @@ -419,12 +514,15 @@ async def stream_orchestration_response( f"(environment: {request.environment})" ) - # Initialize all service components + # Initialize all service components (with timing) + start_time = time.time() components = self._initialize_service_components(request) + time_metric["initialization"] = time.time() - start_time - # STEP 1: CHECK INPUT GUARDRAILS (blocking) + # This implements fail-fast principle - block malicious/policy-violating inputs + # before expensive operations (service discovery, LLM calls, streaming setup) logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Step 1: Checking input guardrails" + 
f"[{request.chatId}] [{stream_ctx.stream_id}] Checking input guardrails (before classifier)" ) if components["guardrails_adapter"]: @@ -432,25 +530,32 @@ async def stream_orchestration_response( input_check_result = await self._check_input_guardrails_async( guardrails_adapter=components["guardrails_adapter"], user_message=request.message, - costs_dict=costs_dict, + costs_metric=costs_metric, ) - timing_dict["input_guardrails_check"] = time.time() - start_time + time_metric["input_guardrails_check"] = time.time() - start_time if not input_check_result.allowed: logger.warning( - f"[{request.chatId}] [{stream_ctx.stream_id}] Input blocked by guardrails: " - f"{input_check_result.reason}" + f"[{request.chatId}] [{stream_ctx.stream_id}] Input blocked before classifier - " + f"saved expensive service discovery. Reason: {input_check_result.reason}" ) yield self.format_sse( request.chatId, INPUT_GUARDRAIL_VIOLATION_MESSAGE ) yield self.format_sse(request.chatId, "END") - self.log_costs(costs_dict) + self.log_costs(costs_metric) + # Log timings before returning (for visibility) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return + else: + logger.info( + f"[{request.chatId}] [{stream_ctx.stream_id}] Guardrails not available - " + f"proceeding without input validation" + ) logger.info( - f"[{request.chatId}] [{stream_ctx.stream_id}] Input guardrails passed " + f"[{request.chatId}] [{stream_ctx.stream_id}] Input guardrails passed" ) # TOOL CLASSIFIER INTEGRATION (STREAMING) @@ -500,8 +605,8 @@ async def stream_orchestration_response( ) # Log costs and timings - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return # Exit after successful classifier routing @@ -531,8 +636,8 @@ async def stream_orchestration_response( request=request, components=components, stream_ctx=stream_ctx, - costs_dict=costs_dict, - 
timing_dict=timing_dict, + costs_metric=costs_metric, + time_metric=time_metric, ): yield sse_chunk @@ -549,12 +654,12 @@ async def stream_orchestration_response( yield self.format_sse(request.chatId, TECHNICAL_ISSUE_MESSAGE) yield self.format_sse(request.chatId, "END") - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget even on outer exception self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) if self.langfuse_config.langfuse_client: @@ -575,8 +680,8 @@ async def _stream_rag_pipeline( request: OrchestrationRequest, components: Dict[str, Any], stream_ctx: Any, - costs_dict: Dict[str, Dict[str, Any]], - timing_dict: Dict[str, float], + costs_metric: Dict[str, Dict[str, Any]], + time_metric: Dict[str, float], ) -> AsyncIterator[str]: """ Core RAG streaming pipeline without classifier routing. 
@@ -594,8 +699,8 @@ async def _stream_rag_pipeline( request: Orchestration request components: Initialized service components (LLM, retriever, generator, guardrails) stream_ctx: Stream context for tracking - costs_dict: Dictionary to accumulate costs - timing_dict: Dictionary to accumulate timings + costs_metric: Dictionary to accumulate costs + time_metric: Dictionary to accumulate timings Yields: SSE-formatted strings @@ -614,8 +719,8 @@ async def _stream_rag_pipeline( original_message=request.message, conversation_history=request.conversationHistory, ) - timing_dict["prompt_refiner"] = time.time() - start_time - costs_dict["prompt_refiner"] = refiner_usage + time_metric["prompt_refiner"] = time.time() - start_time + costs_metric["prompt_refiner"] = refiner_usage logger.info( f"[{request.chatId}] [{stream_ctx.stream_id}] Prompt refinement complete" @@ -631,7 +736,7 @@ async def _stream_rag_pipeline( relevant_chunks = await self._safe_retrieve_contextual_chunks( components["contextual_retriever"], refined_output, request ) - timing_dict["contextual_retrieval"] = time.time() - start_time + time_metric["contextual_retrieval"] = time.time() - start_time except ( ContextualRetrieverInitializationError, ContextualRetrievalFailureError, @@ -647,8 +752,8 @@ async def _stream_rag_pipeline( ) yield self.format_sse(request.chatId, localized_msg) yield self.format_sse(request.chatId, "END") - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return @@ -661,8 +766,8 @@ async def _stream_rag_pipeline( ) yield self.format_sse(request.chatId, localized_msg) yield self.format_sse(request.chatId, "END") - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return @@ -681,7 +786,7 @@ async def _stream_rag_pipeline( 
chunks=relevant_chunks, max_blocks=ResponseGenerationConstants.DEFAULT_MAX_BLOCKS, ) - timing_dict["scope_check"] = time.time() - start_time + time_metric["scope_check"] = time.time() - start_time if is_out_of_scope: logger.info( @@ -692,8 +797,8 @@ async def _stream_rag_pipeline( ) yield self.format_sse(request.chatId, localized_msg) yield self.format_sse(request.chatId, "END") - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return @@ -761,9 +866,9 @@ async def bot_response_generator() -> AsyncIterator[str]: yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + costs_metric["streaming_generation"] = usage_info + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return @@ -790,9 +895,9 @@ async def bot_response_generator() -> AsyncIterator[str]: yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + costs_metric["streaming_generation"] = usage_info + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) stream_ctx.mark_completed() return @@ -859,11 +964,11 @@ async def bot_response_generator() -> AsyncIterator[str]: # Extract usage information after streaming completes usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info + costs_metric["streaming_generation"] = usage_info # Record timings - timing_dict["streaming_generation"] = time.time() - streaming_step_start - timing_dict["output_guardrails"] = 0.0 # Inline during streaming + 
time_metric["streaming_generation"] = time.time() - streaming_step_start + time_metric["output_guardrails"] = 0.0 # Inline during streaming # Calculate streaming duration streaming_duration = (datetime.now() - streaming_start_time).total_seconds() @@ -872,18 +977,18 @@ async def bot_response_generator() -> AsyncIterator[str]: ) # Log costs and trace - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) # Langfuse tracking if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client - total_costs = calculate_total_costs(costs_dict) + total_costs = calculate_total_costs(costs_metric) langfuse.update_current_generation( model=components["llm_manager"] @@ -899,7 +1004,7 @@ async def bot_response_generator() -> AsyncIterator[str]: "streaming": True, "streaming_duration_seconds": streaming_duration, "chunks_streamed": chunk_count, - "cost_breakdown": costs_dict, + "cost_breakdown": costs_metric, "chat_id": request.chatId, "environment": request.environment, "stream_id": stream_ctx.stream_id, @@ -934,13 +1039,13 @@ async def bot_response_generator() -> AsyncIterator[str]: f"[{request.chatId}] [{stream_ctx.stream_id}] Client disconnected" ) usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + costs_metric["streaming_generation"] = usage_info + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget even on client disconnect self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) raise except Exception as stream_error: @@ 
-957,13 +1062,13 @@ async def bot_response_generator() -> AsyncIterator[str]: yield self.format_sse(request.chatId, "END") usage_info = get_lm_usage_since(history_length_before) - costs_dict["streaming_generation"] = usage_info - self.log_costs(costs_dict) - log_step_timings(timing_dict, request.chatId) + costs_metric["streaming_generation"] = usage_info + self.log_costs(costs_metric) + log_step_timings(time_metric, request.chatId) # Update budget even on streaming error self._update_connection_budget( - request.connection_id, costs_dict, request.environment + request.connection_id, costs_metric, request.environment ) def format_sse(self, chat_id: str, content: str) -> str: @@ -998,10 +1103,22 @@ def _initialize_service_components( environment=request.environment, connection_id=request.connection_id ) - # Initialize Guardrails Adapter (optional) - components["guardrails_adapter"] = self._safe_initialize_guardrails( - request.environment, request.connection_id - ) + if request.environment in self.shared_guardrails_adapters: + logger.info( + f" Using shared guardrails adapter for environment='{request.environment}' " + f"(startup-initialized, zero overhead)" + ) + components["guardrails_adapter"] = self.shared_guardrails_adapters[ + request.environment + ] + else: + logger.warning( + f" Shared guardrails unavailable for environment='{request.environment}', " + f"initializing per-request (slower)" + ) + components["guardrails_adapter"] = self._safe_initialize_guardrails( + request.environment, request.connection_id + ) # Initialize Contextual Retriever (replaces hybrid retriever) components["contextual_retriever"] = self._safe_initialize_contextual_retriever( @@ -1112,40 +1229,44 @@ async def _execute_orchestration_pipeline( self, request: OrchestrationRequest, components: Dict[str, Any], - costs_dict: Dict[str, Dict[str, Any]], - timing_dict: Dict[str, float], + costs_metric: Dict[str, Dict[str, Any]], + time_metric: Dict[str, float], + prefix: str = "", ) -> 
Union[OrchestrationResponse, TestOrchestrationResponse]: - """Execute the main orchestration pipeline with all components.""" - # Note: Query validation now happens in process_orchestration_request() - # before component initialization for true early rejection + """Execute the main orchestration pipeline with all components. - # Step 1: Input Guardrails Check - if components["guardrails_adapter"]: - start_time = time.time() - input_blocked_response = await self.handle_input_guardrails( - components["guardrails_adapter"], request, costs_dict - ) - timing_dict["input_guardrails_check"] = time.time() - start_time - if input_blocked_response: - return input_blocked_response + Args: + request: Orchestration request + components: Initialized service components + costs_metric: Dictionary for cost tracking + time_metric: Dictionary for timing tracking + prefix: Optional prefix for timing keys (e.g., "rag" for workflow namespacing) + """ + # Note: Query validation AND input guardrails check now happen at orchestration level + # (in process_orchestration_request) BEFORE classifier routing for true early rejection. + # This saves ~3.5s on blocked requests by failing fast before expensive workflow operations. 
- # Step 2: Refine user prompt + # Step 1: Refine user prompt start_time = time.time() refined_output, refiner_usage = self._refine_user_prompt( llm_manager=components["llm_manager"], original_message=request.message, conversation_history=request.conversationHistory, ) - timing_dict["prompt_refiner"] = time.time() - start_time - costs_dict["prompt_refiner"] = refiner_usage + timing_key = f"{prefix}.prompt_refiner" if prefix else "prompt_refiner" + time_metric[timing_key] = time.time() - start_time + costs_metric["prompt_refiner"] = refiner_usage - # Step 3: Retrieve relevant chunks using contextual retrieval + # Step 2: Retrieve relevant chunks using contextual retrieval try: start_time = time.time() relevant_chunks = await self._safe_retrieve_contextual_chunks( components["contextual_retriever"], refined_output, request ) - timing_dict["contextual_retrieval"] = time.time() - start_time + timing_key = ( + f"{prefix}.contextual_retrieval" if prefix else "contextual_retrieval" + ) + time_metric[timing_key] = time.time() - start_time except ( ContextualRetrieverInitializationError, ContextualRetrievalFailureError, @@ -1158,7 +1279,7 @@ async def _execute_orchestration_pipeline( logger.info("No relevant chunks found - returning out-of-scope response") return self._create_out_of_scope_response(request) - # Step 4: Generate response + # Step 3: Generate response start_time = time.time() generated_response = self._generate_rag_response( llm_manager=components["llm_manager"], @@ -1166,22 +1287,28 @@ async def _execute_orchestration_pipeline( refined_output=refined_output, relevant_chunks=relevant_chunks, response_generator=components["response_generator"], - costs_dict=costs_dict, + costs_metric=costs_metric, + ) + timing_key = ( + f"{prefix}.response_generation" if prefix else "response_generation" ) - timing_dict["response_generation"] = time.time() - start_time + time_metric[timing_key] = time.time() - start_time - # Step 5: Output Guardrails Check + # Step 4: Output 
Guardrails Check # Apply guardrails to all response types for consistent safety across all environments start_time = time.time() output_guardrails_response = await self.handle_output_guardrails( components["guardrails_adapter"], generated_response, request, - costs_dict, + costs_metric, ) - timing_dict["output_guardrails_check"] = time.time() - start_time + timing_key = ( + f"{prefix}.output_guardrails_check" if prefix else "output_guardrails_check" + ) + time_metric[timing_key] = time.time() - start_time - # Step 6: Store inference data (for production and testing environments) + # Step 5: Store inference data (for production and testing environments) # Only store OrchestrationResponse (has chatId), not TestOrchestrationResponse if request.environment in [ PRODUCTION_DEPLOYMENT_ENVIRONMENT, @@ -1252,13 +1379,13 @@ async def handle_input_guardrails( self, guardrails_adapter: NeMoRailsAdapter, request: OrchestrationRequest, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> Union[OrchestrationResponse, TestOrchestrationResponse, None]: """Check input guardrails and return blocked response if needed.""" input_check_result = await self._check_input_guardrails_async( guardrails_adapter=guardrails_adapter, user_message=request.message, - costs_dict=costs_dict, + costs_metric=costs_metric, ) if not input_check_result.allowed: @@ -1378,7 +1505,7 @@ async def handle_output_guardrails( guardrails_adapter: Optional[NeMoRailsAdapter], generated_response: Union[OrchestrationResponse, TestOrchestrationResponse], request: OrchestrationRequest, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> Union[OrchestrationResponse, TestOrchestrationResponse]: """Check output guardrails and handle blocked responses for both response types.""" # Determine if we should run guardrails (same logic for both response types) @@ -1394,7 +1521,7 @@ async def handle_output_guardrails( output_check_result = await 
self._check_output_guardrails( guardrails_adapter=guardrails_adapter, assistant_message=generated_response.content, - costs_dict=costs_dict, + costs_metric=costs_metric, ) if not output_check_result.allowed: @@ -1671,7 +1798,7 @@ async def _check_input_guardrails_async( self, guardrails_adapter: NeMoRailsAdapter, user_message: str, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> GuardrailCheckResult: """ Check user input against guardrails and track costs (async version). @@ -1679,7 +1806,7 @@ async def _check_input_guardrails_async( Args: guardrails_adapter: The guardrails adapter instance user_message: The user message to check - costs_dict: Dictionary to store cost information + costs_metric: Dictionary to store cost information Returns: GuardrailCheckResult: Result of the guardrail check @@ -1691,7 +1818,7 @@ async def _check_input_guardrails_async( result = await guardrails_adapter.check_input_async(user_message) # Store guardrail costs - costs_dict["input_guardrails"] = result.usage + costs_metric["input_guardrails"] = result.usage if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client langfuse.update_current_generation( @@ -1744,7 +1871,7 @@ def _check_input_guardrails( self, guardrails_adapter: NeMoRailsAdapter, user_message: str, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> GuardrailCheckResult: """ Check user input against guardrails and track costs (sync version for non-streaming). 
@@ -1752,7 +1879,7 @@ def _check_input_guardrails( Args: guardrails_adapter: The guardrails adapter instance user_message: The user message to check - costs_dict: Dictionary to store cost information + costs_metric: Dictionary to store cost information Returns: GuardrailCheckResult: Result of the guardrail check @@ -1763,7 +1890,7 @@ def _check_input_guardrails( result = guardrails_adapter.check_input(user_message) # Store guardrail costs - costs_dict["input_guardrails"] = result.usage + costs_metric["input_guardrails"] = result.usage if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client langfuse.update_current_generation( @@ -1816,7 +1943,7 @@ async def _check_output_guardrails( self, guardrails_adapter: NeMoRailsAdapter, assistant_message: str, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> GuardrailCheckResult: """ Check assistant output against guardrails and track costs. @@ -1824,7 +1951,7 @@ async def _check_output_guardrails( Args: guardrails_adapter: The guardrails adapter instance assistant_message: The assistant message to check - costs_dict: Dictionary to store cost information + costs_metric: Dictionary to store cost information Returns: GuardrailCheckResult: Result of the guardrail check @@ -1835,7 +1962,7 @@ async def _check_output_guardrails( result = await guardrails_adapter.check_output_async(assistant_message) # Store guardrail costs - costs_dict["output_guardrails"] = result.usage + costs_metric["output_guardrails"] = result.usage if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client langfuse.update_current_generation( @@ -1885,22 +2012,22 @@ async def _check_output_guardrails( usage={}, ) - def log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: + def log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None: """ Log cost information for tracking. 
Args: - costs_dict: Dictionary of costs per component + costs_metric: Dictionary of costs per component """ try: - if not costs_dict: + if not costs_metric: return - total_costs = calculate_total_costs(costs_dict) + total_costs = calculate_total_costs(costs_metric) logger.info("LLM USAGE COSTS BREAKDOWN:") - for component, costs in costs_dict.items(): + for component, costs in costs_metric.items(): logger.info( f" {component:20s}: ${costs.get('total_cost', 0):.6f} " f"({costs.get('num_calls', 0)} calls, " @@ -1954,7 +2081,7 @@ def log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: def _update_connection_budget( self, connection_id: Optional[str], - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], environment: str = "development", ) -> None: """ @@ -1963,7 +2090,7 @@ def _update_connection_budget( Args: connection_id: The LLM connection ID (optional) - costs_dict: Dictionary of costs per component + costs_metric: Dictionary of costs per component environment: The deployment environment (production/testing/development) """ try: @@ -1991,7 +2118,9 @@ def _update_connection_budget( f"Error fetching production connection ID: {str(fetch_error)}" ) - result = budget_tracker.update_budget_from_costs(connection_id, costs_dict) + result = budget_tracker.update_budget_from_costs( + connection_id, costs_metric + ) if result.get("success"): if result.get("budget_exceeded"): @@ -2346,7 +2475,7 @@ def _generate_rag_response( refined_output: PromptRefinerOutput, relevant_chunks: List[Dict[str, Union[str, float, Dict[str, Any]]]], response_generator: Optional[ResponseGeneratorAgent] = None, - costs_dict: Optional[Dict[str, Dict[str, Any]]] = None, + costs_metric: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Union[OrchestrationResponse, TestOrchestrationResponse]: """ Generate response using retrieved chunks and ResponseGeneratorAgent only. 
@@ -2354,8 +2483,8 @@ def _generate_rag_response( """ logger.info("Starting RAG response generation") - if costs_dict is None: - costs_dict = {} + if costs_metric is None: + costs_metric = {} # If response generator is not available -> standardized technical issue if response_generator is None: @@ -2413,7 +2542,7 @@ def _generate_rag_response( "num_calls": 0, }, ) - costs_dict["response_generator"] = generator_usage + costs_metric["response_generator"] = generator_usage if self.langfuse_config.langfuse_client: langfuse = self.langfuse_config.langfuse_client langfuse.update_current_generation( diff --git a/src/tool_classifier/base_workflow.py b/src/tool_classifier/base_workflow.py index 50faf7a..3f5835c 100644 --- a/src/tool_classifier/base_workflow.py +++ b/src/tool_classifier/base_workflow.py @@ -33,6 +33,7 @@ async def execute_async( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: """ Execute workflow in non-streaming mode. @@ -43,6 +44,7 @@ async def execute_async( Args: request: The orchestration request containing user query and context context: Workflow-specific metadata from ClassificationResult.metadata + time_metric: Optional dictionary for tracking step execution times Returns: OrchestrationResponse if workflow can handle this query @@ -68,6 +70,7 @@ async def execute_streaming( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: """ Execute workflow in streaming mode (Server-Sent Events). 
@@ -78,6 +81,7 @@ async def execute_streaming( Args: request: The orchestration request containing user query and context context: Workflow-specific metadata from ClassificationResult.metadata + time_metric: Optional dictionary for tracking step execution times Returns: AsyncIterator[str] yielding SSE-formatted strings if workflow can handle diff --git a/src/tool_classifier/classifier.py b/src/tool_classifier/classifier.py index c8bef8a..4455f8c 100644 --- a/src/tool_classifier/classifier.py +++ b/src/tool_classifier/classifier.py @@ -1,6 +1,6 @@ """Main tool classifier for workflow routing.""" -from typing import Any, AsyncIterator, Dict, List, Literal, Union, overload +from typing import Any, AsyncIterator, Dict, List, Literal, Optional, Union, overload from loguru import logger from models.request_models import ( @@ -106,6 +106,7 @@ async def route_to_workflow( classification: ClassificationResult, request: OrchestrationRequest, is_streaming: Literal[False] = False, + time_metric: Optional[Dict[str, float]] = None, ) -> OrchestrationResponse: ... @overload @@ -114,6 +115,7 @@ async def route_to_workflow( classification: ClassificationResult, request: OrchestrationRequest, is_streaming: Literal[True], + time_metric: Optional[Dict[str, float]] = None, ) -> AsyncIterator[str]: ... async def route_to_workflow( @@ -121,6 +123,7 @@ async def route_to_workflow( classification: ClassificationResult, request: OrchestrationRequest, is_streaming: bool = False, + time_metric: Optional[Dict[str, float]] = None, ) -> Union[OrchestrationResponse, AsyncIterator[str]]: """ Route request to appropriate workflow based on classification. 
@@ -132,6 +135,7 @@ async def route_to_workflow( classification: Classification result from classify() request: Original orchestration request is_streaming: Whether to use streaming mode (for /orchestrate/stream) + time_metric: Optional timing dictionary for workflow step tracking Returns: OrchestrationResponse for non-streaming mode @@ -162,6 +166,7 @@ async def route_to_workflow( request=request, context=classification.metadata, start_layer=classification.workflow, + time_metric=time_metric, ) else: # NON-STREAMING MODE: For /orchestrate and /orchestrate/test endpoints @@ -170,6 +175,7 @@ async def route_to_workflow( request=request, context=classification.metadata, start_layer=classification.workflow, + time_metric=time_metric, ) def _get_workflow_executor(self, workflow_type: WorkflowType) -> Any: @@ -188,6 +194,7 @@ async def _execute_with_fallback_async( request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, + time_metric: Optional[Dict[str, float]] = None, ) -> OrchestrationResponse: """ Execute workflow with fallback to subsequent layers (non-streaming). @@ -197,6 +204,13 @@ async def _execute_with_fallback_async( 2. If returns None, try next layer in WORKFLOW_LAYER_ORDER 3. Continue until workflow returns non-None result 4. 
OOD workflow always returns result (never None) + + Args: + workflow: Primary workflow executor + request: Orchestration request + context: Workflow context/metadata + start_layer: Starting workflow type + time_metric: Optional timing dictionary for tracking """ chat_id = request.chatId workflow_name = WORKFLOW_DISPLAY_NAMES.get(start_layer, start_layer.value) @@ -204,7 +218,7 @@ async def _execute_with_fallback_async( logger.info(f"[{chat_id}] Executing {workflow_name} (non-streaming)") try: - result = await workflow.execute_async(request, context) + result = await workflow.execute_async(request, context, time_metric) if result is not None: logger.info(f"[{chat_id}] {workflow_name} handled successfully") @@ -232,7 +246,7 @@ async def _execute_with_fallback_async( f"(Layer {WORKFLOW_LAYER_ORDER.index(next_layer) + 1})" ) - result = await next_workflow.execute_async(request, {}) + result = await next_workflow.execute_async(request, {}, time_metric) if result is not None: logger.info(f"[{chat_id}] {next_name} handled successfully") @@ -248,7 +262,7 @@ async def _execute_with_fallback_async( logger.error(f"[{chat_id}] Error executing {workflow_name}: {e}") # Fallback to RAG on error logger.info(f"[{chat_id}] Falling back to RAG due to error") - rag_result = await self.rag_workflow.execute_async(request, {}) + rag_result = await self.rag_workflow.execute_async(request, {}, time_metric) if rag_result is not None: return rag_result else: @@ -260,6 +274,7 @@ async def _execute_with_fallback_streaming( request: OrchestrationRequest, context: Dict[str, Any], start_layer: WorkflowType, + time_metric: Optional[Dict[str, float]] = None, ) -> AsyncIterator[str]: """ Execute workflow with fallback to subsequent layers (streaming). @@ -269,6 +284,13 @@ async def _execute_with_fallback_streaming( 2. If returns None, try next layer in WORKFLOW_LAYER_ORDER 3. Stream from the first workflow that returns non-None 4. 
OOD workflow always returns result (never None) + + Args: + workflow: Primary workflow executor + request: Orchestration request + context: Workflow context/metadata + start_layer: Starting workflow type + time_metric: Optional timing dictionary for tracking """ chat_id = request.chatId workflow_name = WORKFLOW_DISPLAY_NAMES.get(start_layer, start_layer.value) @@ -276,7 +298,7 @@ async def _execute_with_fallback_streaming( logger.info(f"[{chat_id}] Executing {workflow_name} (streaming)") try: - result = await workflow.execute_streaming(request, context) + result = await workflow.execute_streaming(request, context, time_metric) if result is not None: logger.info(f"[{chat_id}] {workflow_name} streaming started") @@ -307,7 +329,7 @@ async def _execute_with_fallback_streaming( f"(Layer {layer_number})" ) - result = await next_workflow.execute_streaming(request, {}) + result = await next_workflow.execute_streaming(request, {}, time_metric) if result is not None: logger.info(f"[{chat_id}] {next_name} streaming started") @@ -325,7 +347,9 @@ async def _execute_with_fallback_streaming( logger.error(f"[{chat_id}] Error executing {workflow_name} streaming: {e}") # Fallback to RAG on error logger.info(f"[{chat_id}] Falling back to RAG streaming due to error") - streaming_result = await self.rag_workflow.execute_streaming(request, {}) + streaming_result = await self.rag_workflow.execute_streaming( + request, {}, time_metric + ) if streaming_result is not None: async for chunk in streaming_result: yield chunk diff --git a/src/tool_classifier/intent_detector.py b/src/tool_classifier/intent_detector.py index 24c1538..a2abb74 100644 --- a/src/tool_classifier/intent_detector.py +++ b/src/tool_classifier/intent_detector.py @@ -42,9 +42,9 @@ class IntentDetectionModule(dspy.Module): """DSPy Module for service intent detection.""" def __init__(self) -> None: - """Initialize intent detection module with ChainOfThought.""" + """Initialize intent detection module with Predict (direct 
prediction).""" super().__init__() - self.detector = dspy.ChainOfThought(ServiceIntentDetector) + self.detector = dspy.Predict(ServiceIntentDetector) def forward( self, diff --git a/src/tool_classifier/workflows/context_workflow.py b/src/tool_classifier/workflows/context_workflow.py index 88212ef..dc23e8b 100644 --- a/src/tool_classifier/workflows/context_workflow.py +++ b/src/tool_classifier/workflows/context_workflow.py @@ -35,6 +35,7 @@ async def execute_async( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: """ Execute context workflow in non-streaming mode. @@ -45,6 +46,7 @@ async def execute_async( Args: request: Orchestration request with user query and history context: Metadata with is_greeting, can_answer_from_history flags + time_metric: Optional timing dictionary for future timing tracking Returns: OrchestrationResponse with context-based answer or None to fallback @@ -62,6 +64,7 @@ async def execute_streaming( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: """ Execute context workflow in streaming mode. 
@@ -72,6 +75,7 @@ async def execute_streaming( Args: request: Orchestration request with user query and history context: Metadata with is_greeting, can_answer_from_history flags + time_metric: Optional timing dictionary for future timing tracking Returns: AsyncIterator yielding SSE strings or None to fallback diff --git a/src/tool_classifier/workflows/ood_workflow.py b/src/tool_classifier/workflows/ood_workflow.py index cd114f7..35a1682 100644 --- a/src/tool_classifier/workflows/ood_workflow.py +++ b/src/tool_classifier/workflows/ood_workflow.py @@ -39,6 +39,7 @@ async def execute_async( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: """ Execute OOD workflow in non-streaming mode. @@ -68,6 +69,7 @@ async def execute_async( Args: request: Orchestration request with user query context: Unused (OOD doesn't need metadata) + time_metric: Optional timing dictionary for future timing tracking Returns: OrchestrationResponse with OOD message @@ -86,6 +88,7 @@ async def execute_streaming( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: """ Execute OOD workflow in streaming mode. diff --git a/src/tool_classifier/workflows/rag_workflow.py b/src/tool_classifier/workflows/rag_workflow.py index 6c58648..b5da35b 100644 --- a/src/tool_classifier/workflows/rag_workflow.py +++ b/src/tool_classifier/workflows/rag_workflow.py @@ -50,6 +50,7 @@ async def execute_async( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: """ Execute RAG workflow in non-streaming mode. 
@@ -64,6 +65,7 @@ async def execute_async( Args: request: Orchestration request with user query context: Unused (RAG doesn't need classification metadata) + time_metric: Optional timing dictionary from parent (for unified tracking) Returns: OrchestrationResponse with RAG-generated answer @@ -72,25 +74,25 @@ async def execute_async( logger.info(f"[{request.chatId}] Executing RAG workflow (non-streaming)") # Initialize components needed for RAG pipeline - costs_dict: Dict[str, Any] = {} - timing_dict: Dict[str, float] = {} + costs_metric: Dict[str, Any] = {} + # Use parent time_metric or create new one + if time_metric is None: + time_metric = {} # Initialize service components components = self.orchestration_service._initialize_service_components(request) - # Call existing RAG pipeline + # Call existing RAG pipeline with "rag" prefix for namespacing response = await self.orchestration_service._execute_orchestration_pipeline( request=request, components=components, - costs_dict=costs_dict, - timing_dict=timing_dict, + costs_metric=costs_metric, + time_metric=time_metric, + prefix="rag", ) - # Log costs and timings - self.orchestration_service.log_costs(costs_dict) - from src.utils.time_tracker import log_step_timings - - log_step_timings(timing_dict, request.chatId) + # Log costs (timing is logged by parent orchestration service) + self.orchestration_service.log_costs(costs_metric) return response @@ -98,6 +100,7 @@ async def execute_streaming( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: """ Execute RAG workflow in streaming mode. 
@@ -116,6 +119,7 @@ async def execute_streaming( Args: request: Orchestration request with user query context: Unused (RAG doesn't need classification metadata) + time_metric: Optional timing dictionary from parent (for unified tracking) Returns: AsyncIterator yielding SSE-formatted strings @@ -124,8 +128,10 @@ async def execute_streaming( logger.info(f"[{request.chatId}] Executing RAG workflow (streaming)") # Initialize tracking dictionaries - costs_dict: Dict[str, Any] = {} - timing_dict: Dict[str, float] = {} + costs_metric: Dict[str, Any] = {} + # Use parent time_metric or create new one + if time_metric is None: + time_metric = {} # Get components from context if provided, otherwise initialize components = context.get("components") @@ -166,7 +172,7 @@ def mark_error(self, error_id: str) -> None: request=request, components=components, stream_ctx=stream_ctx, - costs_dict=costs_dict, - timing_dict=timing_dict, + costs_metric=costs_metric, + time_metric=time_metric, ): yield sse_chunk diff --git a/src/tool_classifier/workflows/service_workflow.py b/src/tool_classifier/workflows/service_workflow.py index d71e2d9..b432c62 100644 --- a/src/tool_classifier/workflows/service_workflow.py +++ b/src/tool_classifier/workflows/service_workflow.py @@ -27,6 +27,7 @@ SERVICE_DISCOVERY_TIMEOUT, ) from tool_classifier.intent_detector import IntentDetectionModule +import time class LLMServiceProtocol(Protocol): @@ -64,11 +65,11 @@ def format_sse(self, chat_id: str, content: str) -> str: """ ... - def log_costs(self, costs_dict: Dict[str, Dict[str, Any]]) -> None: + def log_costs(self, costs_metric: Dict[str, Dict[str, Any]]) -> None: """Log cost information for tracking. Args: - costs_dict: Dictionary of costs per component + costs_metric: Dictionary of costs per component """ ... 
@@ -296,7 +297,7 @@ async def _process_intent_detection( request: OrchestrationRequest, chat_id: str, context: Dict[str, Any], - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> None: """Detect intent, validate service, and populate context. @@ -311,7 +312,7 @@ async def _process_intent_detection( request: Orchestration request chat_id: Chat ID for logging context: Context dict to populate with results - costs_dict: Dictionary to track LLM costs + costs_metric: Dictionary to track LLM costs """ intent_result, intent_usage = await self._detect_service_intent( user_query=request.message, @@ -319,7 +320,7 @@ async def _process_intent_detection( conversation_history=request.conversationHistory, chat_id=chat_id, ) - costs_dict["intent_detection"] = intent_usage + costs_metric["intent_detection"] = intent_usage if intent_result and intent_result.get("matched_service_id"): service_id = intent_result["matched_service_id"] @@ -463,7 +464,7 @@ async def _log_request_details( request: OrchestrationRequest, context: Dict[str, Any], mode: str, - costs_dict: Dict[str, Dict[str, Any]], + costs_metric: Dict[str, Dict[str, Any]], ) -> None: """Log request details and perform service discovery. 
@@ -471,7 +472,7 @@ async def _log_request_details( request: The orchestration request context: Workflow context dictionary mode: Execution mode ("streaming" or "non-streaming") - costs_dict: Dictionary to accumulate cost tracking information + costs_metric: Dictionary to accumulate cost tracking information """ chat_id = request.chatId logger.info(f"[{chat_id}] SERVICE WORKFLOW ({mode}): {request.message}") @@ -529,7 +530,7 @@ async def _log_request_details( request=request, chat_id=chat_id, context=context, - costs_dict=costs_dict, + costs_metric=costs_metric, ) else: services = response_data.get("services", []) @@ -540,7 +541,7 @@ async def _log_request_details( request=request, chat_id=chat_id, context=context, - costs_dict=costs_dict, + costs_metric=costs_metric, ) else: logger.warning(f"[{chat_id}] Service discovery failed") @@ -549,17 +550,30 @@ async def execute_async( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[OrchestrationResponse]: - """Execute service workflow in non-streaming mode.""" + """Execute service workflow in non-streaming mode. 
+ + Args: + request: Orchestration request + context: Workflow context + time_metric: Optional timing dictionary for unified tracking + """ + chat_id = request.chatId # Create costs tracking dictionary (follows RAG workflow pattern) - costs_dict: Dict[str, Dict[str, Any]] = {} + costs_metric: Dict[str, Dict[str, Any]] = {} + # Use parent time_metric or create new one + if time_metric is None: + time_metric = {} - # Log comprehensive request details and perform service discovery + # Service discovery with timing + start_time = time.time() await self._log_request_details( - request, context, mode="non-streaming", costs_dict=costs_dict + request, context, mode="non-streaming", costs_metric=costs_metric ) + time_metric["service.discovery"] = time.time() - start_time # Check if service was detected and validated if not context.get("service_id"): @@ -573,6 +587,7 @@ async def execute_async( logger.info(f"[{chat_id}] Entity Transformation:") # Step 1: Extract service metadata from context + start_time = time.time() service_metadata = self._extract_service_metadata(context, chat_id) if not service_metadata: logger.error( @@ -596,6 +611,7 @@ async def execute_async( service_name=service_metadata["service_name"], chat_id=chat_id, ) + time_metric["service.entity_validation"] = time.time() - start_time logger.info( f"[{chat_id}] - Validation status: " @@ -657,7 +673,7 @@ async def execute_async( # Log costs after service workflow completes (follows RAG workflow pattern) if self.orchestration_service: - self.orchestration_service.log_costs(costs_dict) + self.orchestration_service.log_costs(costs_metric) return OrchestrationResponse( chatId=request.chatId, @@ -672,17 +688,30 @@ async def execute_streaming( self, request: OrchestrationRequest, context: Dict[str, Any], + time_metric: Optional[Dict[str, float]] = None, ) -> Optional[AsyncIterator[str]]: - """Execute service workflow in streaming mode.""" + """Execute service workflow in streaming mode. 
+ + Args: + request: Orchestration request + context: Workflow context + time_metric: Optional timing dictionary for unified tracking + """ + chat_id = request.chatId # Create costs tracking dictionary (follows RAG workflow pattern) - costs_dict: Dict[str, Dict[str, Any]] = {} + costs_metric: Dict[str, Dict[str, Any]] = {} + # Use parent time_metric or create new one + if time_metric is None: + time_metric = {} - # Log comprehensive request details and perform service discovery + # Service discovery with timing + start_time = time.time() await self._log_request_details( - request, context, mode="streaming", costs_dict=costs_dict + request, context, mode="streaming", costs_metric=costs_metric ) + time_metric["service.discovery"] = time.time() - start_time # Check if service was detected and validated if not context.get("service_id"): @@ -790,7 +819,7 @@ async def debug_stream() -> AsyncIterator[str]: # Log costs after streaming completes (follows RAG workflow pattern) # Must be inside generator because costs are accumulated during streaming - orchestration_service.log_costs(costs_dict) + orchestration_service.log_costs(costs_metric) return debug_stream() # REMOVE THIS BLOCK AFTER STEP 7 IMPLEMENTATION (END) diff --git a/src/utils/budget_tracker.py b/src/utils/budget_tracker.py index 134b034..aaa3b15 100644 --- a/src/utils/budget_tracker.py +++ b/src/utils/budget_tracker.py @@ -186,26 +186,26 @@ def update_budget( return {"success": False, "reason": "unexpected_error", "error": str(e)} def update_budget_from_costs( - self, connection_id: Optional[str], costs_dict: Dict[str, Dict[str, Any]] + self, connection_id: Optional[str], costs_metric: Dict[str, Dict[str, Any]] ) -> Dict[str, Any]: """ Update budget from a costs dictionary containing component costs. 
Args: connection_id: The LLM connection ID (optional) - costs_dict: Dictionary of component costs with total_cost values + costs_metric: Dictionary of component costs with total_cost values Returns: Dictionary containing the response from the update endpoint """ # Calculate total cost from all components total_cost = 0.0 - for component_costs in costs_dict.values(): + for component_costs in costs_metric.values(): total_cost += component_costs.get("total_cost", 0.0) logger.debug( f"Total cost calculated from components: ${total_cost:.6f} " - f"(components: {list(costs_dict.keys())})" + f"(components: {list(costs_metric.keys())})" ) return self.update_budget(connection_id, total_cost) diff --git a/src/utils/time_tracker.py b/src/utils/time_tracker.py index 5b6d8de..fce45f4 100644 --- a/src/utils/time_tracker.py +++ b/src/utils/time_tracker.py @@ -5,23 +5,31 @@ def log_step_timings( - timing_dict: Dict[str, float], chat_id: Optional[str] = None + time_metric: Dict[str, float], chat_id: Optional[str] = None ) -> None: """ Log all step timings in a clean format. Args: - timing_dict: Dictionary containing step names and their execution times + time_metric: Dictionary containing step names and their execution times chat_id: Optional chat ID for context """ - if not timing_dict: + if not time_metric: return + # Parent/composite timings that should be hidden from logs + # These are aggregate timings that already include their sub-steps + PARENT_TIMINGS = {"classifier.route"} + prefix = f"[{chat_id}] " if chat_id else "" logger.info(f"{prefix}STEP EXECUTION TIMES:") total_time = 0.0 - for step_name, elapsed_time in timing_dict.items(): + for step_name, elapsed_time in time_metric.items(): + # Skip parent/composite timings entirely + if step_name in PARENT_TIMINGS: + continue + # Special handling for inline streaming guardrails if step_name == "output_guardrails" and elapsed_time < 0.001: logger.info(f" {step_name:25s}: (inline during streaming)")