diff --git a/internal/server/passthrough_support.go b/internal/server/passthrough_support.go index c91ad322..32ee63dd 100644 --- a/internal/server/passthrough_support.go +++ b/internal/server/passthrough_support.go @@ -311,8 +311,25 @@ func (s *passthroughService) proxyPassthroughResponse(c *echo.Context, providerT return nil } + body, err := io.ReadAll(resp.Body) + if err != nil { + return handleError(c, core.NewProviderError(providerType, http.StatusBadGateway, "failed to read provider passthrough response body", err)) + } + + workflow := core.GetWorkflow(c.Request().Context()) + if s.usageLogger != nil && s.usageLogger.Config().Enabled && (workflow == nil || workflow.UsageEnabled()) { + model := "" + if info != nil { + model = strings.TrimSpace(info.Model) + } + model = resolvedModelFromWorkflow(workflow, model) + requestID := requestIDFromContextOrHeader(c.Request()) + usagePath := strings.TrimSpace(c.Request().URL.Path) + s.logPassthroughNonStreamUsage(body, model, providerType, providerName, requestID, usagePath, c.Request().Context()) + } + c.Response().WriteHeader(resp.StatusCode) - if _, err := io.Copy(c.Response(), resp.Body); err != nil { + if _, err := c.Response().Write(body); err != nil { return err } if f, ok := c.Response().(http.Flusher); ok { @@ -320,3 +337,29 @@ func (s *passthroughService) proxyPassthroughResponse(c *echo.Context, providerT } return nil } + +func (s *passthroughService) logPassthroughNonStreamUsage(body []byte, model, providerType, providerName, requestID, endpoint string, ctx context.Context) { + if len(body) == 0 { + return + } + + auditPath := passthroughStreamAuditPath(endpoint, providerType, endpoint) + var pricingArgs []*core.ModelPricing + if s.pricingResolver != nil { + pricingProvider := strings.TrimSpace(providerName) + if pricingProvider == "" { + pricingProvider = strings.TrimSpace(providerType) + } + if p := s.pricingResolver.ResolvePricing(model, pricingProvider); p != nil { + pricingArgs = append(pricingArgs, p) + } + } + + entry := usage.ExtractFromCachedResponseBody(body, requestID, model, providerType, auditPath, "", pricingArgs...) + if entry == nil { + return + } + entry.ProviderName = strings.TrimSpace(providerName) + entry.UserPath = core.UserPathFromContext(ctx) + s.usageLogger.Write(entry) +} diff --git a/internal/server/request_selector_peek.go b/internal/server/request_selector_peek.go index c038b4c4..2c0bfe88 100644 --- a/internal/server/request_selector_peek.go +++ b/internal/server/request_selector_peek.go @@ -26,7 +26,10 @@ func seedRequestBodySelectorHints(req *http.Request, bodyMode core.BodyMode, env } hints := peekRequestBodySelectorHints(req, requestSelectorPeekLimit) - if !hints.parsed || !hints.complete { + if !hints.parsed && !hints.complete { + if bodyMode == core.BodyModeOpaque && hints.model != "" { + core.ApplyBodySelectorHints(env, hints.model, hints.provider, hints.stream) + } return } core.ApplyBodySelectorHints(env, hints.model, hints.provider, hints.stream)