diff --git a/config/config.example.yaml b/config/config.example.yaml index 42575813..a95e6910 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -198,6 +198,14 @@ providers: type: anthropic api_key: "sk-ant-..." + # OAuth-authenticated provider example. + # Set api_key to "oauth" to enable OAuth authentication via the dashboard. + # After configuring, visit /admin/dashboard/oauth to authenticate. + # Requests use the passthrough route: POST /p/anthropic_oauth/v1/chat/completions + # anthropic_oauth: + # type: anthropic + # api_key: "oauth" + gemini: type: gemini api_key: "..." diff --git a/docs/dev/oauth-implementation-guide.md b/docs/dev/oauth-implementation-guide.md new file mode 100644 index 00000000..54fa9e4e --- /dev/null +++ b/docs/dev/oauth-implementation-guide.md @@ -0,0 +1,367 @@ +# OAuth Provider Implementation Guide + +Reference implementation: Anthropic OAuth (branch `feat/anthropic-oauth-pkce`) + +This document captures every decision, file, and fix made when adding OAuth 2.0 + PKCE support for the Anthropic provider. Use it as the blueprint when adding OAuth for other providers (e.g. OpenAI Codex). + +--- + +## Overview + +OAuth providers are configured with `api_key: "oauth"` in `config.yaml`. Once the user authenticates via the admin dashboard, the provider behaves identically to a static API key provider — tokens are stored, refreshed automatically, and injected into upstream requests. + +```yaml +providers: + my_claude: + type: anthropic + api_key: "oauth" +``` + +Requests are sent via the passthrough route: + +``` +POST /p/{provider_name}/v1/chat/completions +``` + +--- + +## Dual-mode callback — no configuration required + +GoModel supports both local and remote OAuth flows without any extra configuration: + +| Mode | When to use | How it works | +|---|---|---| +| **Local** (Authenticate button) | GoModel and browser on the same machine | Popup redirects to `http://localhost:54545/callback` — GoModel receives the code automatically | +| **Remote** (Remote button) | GoModel on a remote server | Popup redirects to `https://platform.claude.com/oauth/code/callback` — user copies the URL and pastes it in the dashboard | + +Both modes are always available. No `GOMODEL_PUBLIC_URL` or any other config needed. + +--- + +## Architecture + +``` +config.yaml (api_key: "oauth") + │ + ▼ +ProviderFactory.SetOAuthStore(store) ← called in app.go before providers.Init() + │ + ▼ +AnthropicProvider detects "oauth" sentinel + │ + ├── on request: load token from store, inject as Bearer + ├── on expiry: call RefreshToken(), persist new token + └── on missing token: cancel request context → upstream call aborted +``` + +### New packages + +| Package | Path | Purpose | +|---|---|---| +| `oauth` | `internal/oauth/` | OAuth 2.0 + PKCE primitives, provider interface, Anthropic implementation | +| `oauthstore` | `internal/oauthstore/` | Token persistence (SQLite, PostgreSQL, MongoDB) | +| `oauthusage` | `internal/oauthusage/` | Fetch and cache Anthropic rate-limit usage windows | + +--- + +## Package: `internal/oauth` + +### `oauth.go` — core types and PKCE helpers + +```go +type Provider interface { + // redirectURI is the full callback URI — either LocalCallbackURI(port) + // or a provider-hosted URI like platform.claude.com/oauth/code/callback. + AuthorizationURL(state, verifier, redirectURI string) string + ExchangeCode(ctx context.Context, code, verifier, state, redirectURI string) (*TokenResponse, error) + RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) + FetchProfile(ctx context.Context, accessToken string) (*Profile, error) +} +``` + +Key functions: +- `NewPKCEPair()` — generates verifier + S256 challenge +- `NewState()` — generates hex-encoded CSRF state (16 bytes → 32 hex chars) +- `LocalCallbackURI(port)` — builds `http://localhost:{port}/callback` + +**Critical**: State must be hex-encoded (`hex.EncodeToString`), not base64url. Anthropic rejects base64url states with "Invalid request format". + +### `anthropic.go` — Anthropic provider + +Constants: +```go +AnthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" // same as Claude Code +anthropicAuthURL = "https://claude.ai/oauth/authorize" +anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" +anthropicProfileURL = "https://api.anthropic.com/api/oauth/profile" +DefaultCallbackPort = 54545 +anthropicDefaultScopes = "org:create_api_key user:profile user:inference" +``` + +**Critical quirks**: +1. `code=true` must be the **first** parameter in the authorization URL +2. Query string must be built **manually** (not via `url.Values.Encode()` which sorts alphabetically) +3. Redirect URI must use `http://localhost:{port}/callback` (not `127.0.0.1`) +4. `state` must be included in the token exchange body (non-standard) +5. For remote/manual flow, use `https://platform.claude.com/oauth/code/callback` as redirect URI — `console.anthropic.com/oauth/code/callback` redirects there, and the token exchange must use the **final** URI + +Authorization URL construction: +```go +query := "code=true" + + "&client_id=" + url.QueryEscape(AnthropicClientID) + + "&redirect_uri=" + url.QueryEscape(redirectURI) + + "&response_type=code" + + "&scope=" + url.QueryEscape(anthropicDefaultScopes) + + "&state=" + url.QueryEscape(state) + + "&code_challenge=" + url.QueryEscape(challenge) + + "&code_challenge_method=S256" +``` + +Reference implementation: `/Users/vfeitoza/Projetos/cligate/src/claude-oauth.js` + +--- + +## Package: `internal/oauthstore` + +### Store interface + +```go +type Store interface { + Save(ctx context.Context, token *Token) error + Get(ctx context.Context, providerName string) (*Token, error) + Delete(ctx context.Context, providerName string) error + List(ctx context.Context) ([]*Token, error) + Close() error +} +``` + +### Token struct + +```go +type Token struct { + ProviderName string + AccessToken string + RefreshToken string + ExpiresAt time.Time + Scopes []string + AccountID string + AccountEmail string + DisplayName string + SubscriptionType string +} +``` + +### Factory + +```go +// internal/oauthstore/factory.go +func NewFromStorage(ctx context.Context, shared storage.Storage) (Store, error) { + return storage.ResolveBackend[Store]( + shared, + func(db *sql.DB) (Store, error) { return NewSQLiteStore(db) }, + func(pool *pgxpool.Pool) (Store, error) { return NewPostgreSQLStore(ctx, pool) }, + func(db *mongo.Database) (Store, error) { return NewMongoDBStore(db) }, + ) +} +``` + +Follows the same pattern as `internal/authkeys/factory.go`. + +--- + +## Admin API endpoints + +Registered under `/admin/api/v1/oauth` via `admin.RegisterOAuthRoutes(group, handler)`. + +| Method | Path | Description | +|---|---|---| +| GET | `/oauth/providers` | List all OAuth-configured providers with status | +| POST | `/oauth/start` | Start PKCE flow, returns `auth_url`, `manual_auth_url`, `state` | +| GET | `/oauth/callback` | Receive authorization code from local callback server | +| POST | `/oauth/callback-manual` | Receive pasted callback URL or raw code from dashboard | +| POST | `/oauth/revoke` | Delete stored token | +| GET | `/oauth/usage/:name` | Fetch usage windows for a provider | +| GET | `/oauth/status/:name` | Token status for a single provider | + +### `StartOAuth` response + +```json +{ + "auth_url": "https://claude.ai/oauth/authorize?...&redirect_uri=http%3A%2F%2Flocalhost%3A54545%2Fcallback...", + "manual_auth_url": "https://claude.ai/oauth/authorize?...&redirect_uri=https%3A%2F%2Fplatform.claude.com%2Foauth%2Fcode%2Fcallback...", + "manual_uri": "https://platform.claude.com/oauth/code/callback", + "state": "4b477a04aea23843fb82c61e0872cb31", + "callback_port": 54545 +} +``` + +### `oauthFlowState` — dual redirect URI + +```go +type oauthFlowState struct { + verifier string + state string + providerName string + providerType string + redirectURI string // used in AuthorizationURL (local: localhost, manual: platform.claude.com) + exchangeURI string // used in token exchange (same as redirectURI for local; platform.claude.com for manual) + callbackPort int + server *oauth.CallbackServer + createdAt time.Time +} +``` + +### Handler wiring (`internal/app/app.go`) + +```go +// 1. Create store before providers.Init() +oauthStore, err := oauthstore.NewFromStorage(ctx, sharedStorage) + +// 2. Pass store to provider factory +cfg.Factory.SetOAuthStore(oauthStore) + +// 3. Create handler +oauthHandler = admin.NewOAuthHandler(oauthStore, configuredProviders) + +// 4. Wire into server config +serverCfg.OAuthHandler = oauthHandler +``` + +### Route registration (`internal/server/http.go`) + +```go +if cfg != nil && cfg.AdminEndpointsEnabled && cfg.AdminHandler != nil { + adminGroup := e.Group("/admin/api/v1") + cfg.AdminHandler.RegisterRoutes(adminGroup) + admin.RegisterOAuthRoutes(adminGroup, cfg.OAuthHandler) +} +``` + +**Note**: Echo v5 uses `c.Param()`, not `c.PathParam()`. + +--- + +## Dashboard UI + +### Files + +| File | Change | +|---|---| +| `internal/admin/dashboard/templates/page-oauth.html` | OAuth page template | +| `internal/admin/dashboard/templates/index.html` | Added `{{template "dashboard-page-oauth" .}}` | +| `internal/admin/dashboard/templates/layout.html` | Added ` + diff --git a/internal/admin/dashboard/templates/page-oauth.html b/internal/admin/dashboard/templates/page-oauth.html new file mode 100644 index 00000000..72c61f1c --- /dev/null +++ b/internal/admin/dashboard/templates/page-oauth.html @@ -0,0 +1,248 @@ +{{define "dashboard-page-oauth"}} + + +{{end}} diff --git a/internal/admin/dashboard/templates/sidebar.html b/internal/admin/dashboard/templates/sidebar.html index f93fd0cc..f3ddb275 100644 --- a/internal/admin/dashboard/templates/sidebar.html +++ b/internal/admin/dashboard/templates/sidebar.html @@ -36,6 +36,10 @@

GoModel

API Keys + + + OAuth + Workflows diff --git a/internal/admin/handler_oauth.go b/internal/admin/handler_oauth.go new file mode 100644 index 00000000..c1e2e6e1 --- /dev/null +++ b/internal/admin/handler_oauth.go @@ -0,0 +1,613 @@ +package admin + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/labstack/echo/v5" + + "gomodel/internal/oauth" + "gomodel/internal/oauthstore" + "gomodel/internal/oauthusage" + "gomodel/internal/providers" +) + +// anthropicManualCallbackURI is the Anthropic-hosted callback page accepted by +// the public client ID. Used in remote mode when GoModel cannot receive the +// redirect directly. The user copies the resulting URL and submits it via the +// dashboard's manual callback endpoint. +const anthropicManualCallbackURI = "https://platform.claude.com/oauth/code/callback" + +// OAuthProviderStatus is the admin-facing view of a single OAuth provider. +type OAuthProviderStatus struct { + ProviderName string `json:"provider_name"` + ProviderType string `json:"provider_type"` + Status string `json:"status"` // "pending", "authenticated", "expired", "error" + Authenticated bool `json:"authenticated"` + AccountEmail string `json:"account_email,omitempty"` + DisplayName string `json:"display_name,omitempty"` + SubscriptionType string `json:"subscription_type,omitempty"` + TokenExpiresAt *time.Time `json:"token_expires_at,omitempty"` + LastRefreshedAt *time.Time `json:"last_refreshed_at,omitempty"` +} + +// OAuthUsageWindowResponse is the admin-facing view of a usage window. +type OAuthUsageWindowResponse struct { + Utilization float64 `json:"utilization"` + UtilizationPercent int `json:"utilization_percent"` + ResetsAt time.Time `json:"resets_at"` +} + +// OAuthExtraUsageResponse is the admin-facing view of extra credit usage. +type OAuthExtraUsageResponse struct { + IsEnabled bool `json:"is_enabled"` + MonthlyLimit float64 `json:"monthly_limit,omitempty"` + UsedCredits float64 `json:"used_credits,omitempty"` + Utilization float64 `json:"utilization,omitempty"` + DisabledReason string `json:"disabled_reason,omitempty"` +} + +// OAuthUsageResponse is the admin-facing view of OAuth usage data. +type OAuthUsageResponse struct { + ProviderName string `json:"provider_name"` + AccountEmail string `json:"account_email"` + SubscriptionType string `json:"subscription_type"` + FiveHour *OAuthUsageWindowResponse `json:"five_hour,omitempty"` + SevenDay *OAuthUsageWindowResponse `json:"seven_day,omitempty"` + SevenDayOAuthApps *OAuthUsageWindowResponse `json:"seven_day_oauth_apps,omitempty"` + SevenDayOpus *OAuthUsageWindowResponse `json:"seven_day_opus,omitempty"` + SevenDaySonnet *OAuthUsageWindowResponse `json:"seven_day_sonnet,omitempty"` + ExtraUsage *OAuthExtraUsageResponse `json:"extra_usage,omitempty"` + FetchedAt time.Time `json:"fetched_at"` +} + +// oauthFlowState holds in-progress PKCE state for a pending OAuth flow. +type oauthFlowState struct { + verifier string + state string + providerName string + providerType string + redirectURI string // full redirect URI used in AuthorizationURL + exchangeURI string // redirect URI to use in token exchange (may differ) + callbackPort int // non-zero only for local callback server mode + server *oauth.CallbackServer // non-nil only for local callback server mode + createdAt time.Time +} + +// OAuthHandler handles OAuth-related admin endpoints. +type OAuthHandler struct { + store oauthstore.Store + usageFetcher *oauthusage.CachingFetcher + configuredProviders []providers.SanitizedProviderConfig + + flowMu sync.Mutex + flows map[string]*oauthFlowState // keyed by state token +} + +// NewOAuthHandler creates a new OAuthHandler. +func NewOAuthHandler(store oauthstore.Store, configuredProviders []providers.SanitizedProviderConfig) *OAuthHandler { + return &OAuthHandler{ + store: store, + usageFetcher: oauthusage.NewCachingFetcher(oauthusage.NewHTTPFetcher()), + configuredProviders: configuredProviders, + flows: make(map[string]*oauthFlowState), + } +} + +// RegisterOAuthRoutes mounts the OAuth admin routes on the given registrar. +func (h *OAuthHandler) RegisterOAuthRoutes(g RouteRegistrar) { + g.GET("/oauth/providers", h.ListOAuthProviders) + g.POST("/oauth/start", h.StartOAuth) + g.GET("/oauth/callback", h.OAuthCallback) + g.POST("/oauth/callback-manual", h.OAuthCallbackManual) + g.POST("/oauth/revoke", h.RevokeOAuth) + g.GET("/oauth/usage/:provider_name", h.GetOAuthUsage) + g.GET("/oauth/status/:provider_name", h.GetOAuthStatus) +} + +// ListOAuthProviders returns all providers configured with api_key: "oauth". +func (h *OAuthHandler) ListOAuthProviders(c *echo.Context) error { + ctx := c.Request().Context() + statuses, err := h.buildProviderStatuses(ctx) + if err != nil { + return handleError(c, err) + } + return c.JSON(http.StatusOK, map[string]any{ + "providers": statuses, + "total": len(statuses), + }) +} + +// StartOAuth initiates the OAuth flow for a provider. +// Body: {"provider_name": "anthropic_oauth"} +func (h *OAuthHandler) StartOAuth(c *echo.Context) error { + var req struct { + ProviderName string `json:"provider_name"` + } + if err := c.Bind(&req); err != nil { + return handleError(c, fmt.Errorf("invalid request body: %w", err)) + } + req.ProviderName = strings.TrimSpace(req.ProviderName) + if req.ProviderName == "" { + return handleError(c, fmt.Errorf("provider_name is required")) + } + + // Verify the provider exists and is OAuth-configured + provCfg, ok := h.findOAuthProvider(req.ProviderName) + if !ok { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": fmt.Sprintf("provider %q not found or not configured for OAuth", req.ProviderName), + }) + } + + // Generate PKCE pair and state + pkce, err := oauth.NewPKCEPair() + if err != nil { + return handleError(c, fmt.Errorf("generate PKCE: %w", err)) + } + state, err := oauth.NewState() + if err != nil { + return handleError(c, fmt.Errorf("generate state: %w", err)) + } + + // Start local callback server (works when browser and server are on the same machine). + fallbackPorts := []int{54546, 54547, 54548, 54549, 54550} + cs, actualPort, err := oauth.TryCallbackPorts(oauth.DefaultCallbackPort, fallbackPorts...) + if err != nil { + return handleError(c, fmt.Errorf("start OAuth callback server: %w", err)) + } + localRedirectURI := oauth.LocalCallbackURI(actualPort) + + // Store flow state with both local and manual redirect URIs. + // completeOAuthFlow will use whichever path completes first. + h.flowMu.Lock() + h.cleanExpiredFlows() + h.flows[state] = &oauthFlowState{ + verifier: pkce.Verifier, + state: state, + providerName: req.ProviderName, + providerType: provCfg.Type, + redirectURI: localRedirectURI, + exchangeURI: localRedirectURI, + callbackPort: actualPort, + server: cs, + createdAt: time.Now(), + } + h.flowMu.Unlock() + + // Build two authorization URLs: one for local callback, one for manual flow. + oauthProv := oauth.NewAnthropicProvider() + authURL := oauthProv.AuthorizationURL(state, pkce.Verifier, localRedirectURI) + manualAuthURL := oauthProv.AuthorizationURL(state, pkce.Verifier, anthropicManualCallbackURI) + + // Wait for local callback in background. + go h.waitForCallback(state) + + return c.JSON(http.StatusOK, map[string]any{ + "auth_url": authURL, + "manual_auth_url": manualAuthURL, + "manual_uri": anthropicManualCallbackURI, + "state": state, + "callback_port": actualPort, + }) +} + +// OAuthCallback handles the redirect from the OAuth provider. +func (h *OAuthHandler) OAuthCallback(c *echo.Context) error { + code := c.QueryParam("code") + state := c.QueryParam("state") + errParam := c.QueryParam("error") + + if errParam != "" { + return c.HTML(http.StatusBadRequest, oauthErrorHTML(errParam)) + } + if code == "" || state == "" { + return c.HTML(http.StatusBadRequest, oauthErrorHTML("missing code or state parameter")) + } + + h.flowMu.Lock() + flow, ok := h.flows[state] + h.flowMu.Unlock() + + if !ok { + return c.HTML(http.StatusBadRequest, oauthErrorHTML("invalid or expired OAuth state — please try again")) + } + + ctx := c.Request().Context() + if err := h.completeOAuthFlow(ctx, flow, code, state, flow.exchangeURI); err != nil { + slog.Error("oauth callback: flow completion failed", "provider", flow.providerName, "error", err) + return c.HTML(http.StatusInternalServerError, oauthErrorHTML("authentication failed: "+err.Error())) + } + + h.flowMu.Lock() + delete(h.flows, state) + h.flowMu.Unlock() + + return c.HTML(http.StatusOK, oauthSuccessHTML()) +} + +// OAuthCallbackManual handles the manual callback flow for remote servers. +// The user pastes the full callback URL (or just the code) from the Anthropic +// callback page into the dashboard, which POSTs it here. +// Body: {"callback_url": "https://console.anthropic.com/oauth/code/callback?code=...&state=..."} +// or {"code": "...", "state": "..."} +func (h *OAuthHandler) OAuthCallbackManual(c *echo.Context) error { + var req struct { + CallbackURL string `json:"callback_url"` + Code string `json:"code"` + State string `json:"state"` + } + if err := c.Bind(&req); err != nil { + return handleError(c, fmt.Errorf("invalid request body: %w", err)) + } + + // Extract code and state from URL if provided, otherwise treat as raw code. + if req.CallbackURL != "" { + trimmed := strings.TrimSpace(req.CallbackURL) + if strings.HasPrefix(trimmed, "http://") || strings.HasPrefix(trimmed, "https://") { + parsed, err := url.Parse(trimmed) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid callback URL"}) + } + if e := parsed.Query().Get("error"); e != "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "OAuth error: " + e}) + } + if c := parsed.Query().Get("code"); c != "" { + req.Code = c + } + if s := parsed.Query().Get("state"); s != "" { + req.State = s + } + } else { + // Treat as raw authorization code. + req.Code = trimmed + } + } + + req.Code = strings.TrimSpace(req.Code) + req.State = strings.TrimSpace(req.State) + + if req.Code == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "code is required"}) + } + if req.State == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "state is required"}) + } + + h.flowMu.Lock() + flow, ok := h.flows[req.State] + h.flowMu.Unlock() + + if !ok { + return c.JSON(http.StatusBadRequest, map[string]string{ + "error": "invalid or expired OAuth state — please start the authentication flow again", + }) + } + + ctx := c.Request().Context() + if err := h.completeOAuthFlow(ctx, flow, req.Code, req.State, anthropicManualCallbackURI); err != nil { + slog.Error("oauth manual callback: flow completion failed", "provider", flow.providerName, "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) + } + + h.flowMu.Lock() + delete(h.flows, req.State) + h.flowMu.Unlock() + + slog.Info("oauth: manual callback completed", "provider", flow.providerName) + return c.JSON(http.StatusOK, map[string]string{"status": "authenticated"}) +} + +// RevokeOAuth removes the stored OAuth token for a provider. +// Body: {"provider_name": "anthropic_oauth"} +func (h *OAuthHandler) RevokeOAuth(c *echo.Context) error { + var req struct { + ProviderName string `json:"provider_name"` + } + if err := c.Bind(&req); err != nil { + return handleError(c, fmt.Errorf("invalid request body: %w", err)) + } + req.ProviderName = strings.TrimSpace(req.ProviderName) + if req.ProviderName == "" { + return handleError(c, fmt.Errorf("provider_name is required")) + } + + ctx := c.Request().Context() + if err := h.store.Delete(ctx, req.ProviderName); err != nil { + return handleError(c, fmt.Errorf("revoke OAuth token: %w", err)) + } + + if h.usageFetcher != nil { + h.usageFetcher.Invalidate(req.ProviderName) + } + + slog.Info("oauth: token revoked", "provider", req.ProviderName) + return c.JSON(http.StatusOK, map[string]string{"status": "revoked"}) +} + +// GetOAuthUsage returns usage data for an OAuth provider. +// Always invalidates the cache so the user gets fresh data on demand. +func (h *OAuthHandler) GetOAuthUsage(c *echo.Context) error { + providerName := strings.TrimSpace(c.Param("provider_name")) + if providerName == "" { + return handleError(c, fmt.Errorf("provider_name is required")) + } + + ctx := c.Request().Context() + token, err := h.store.Get(ctx, providerName) + if err != nil { + if errors.Is(err, oauthstore.ErrNotFound) { + return c.JSON(http.StatusNotFound, map[string]string{ + "error": fmt.Sprintf("provider %q is not authenticated", providerName), + }) + } + return handleError(c, err) + } + + // Invalidate cache so every explicit refresh fetches fresh data. + if h.usageFetcher != nil { + h.usageFetcher.Invalidate(providerName) + } + + usage, err := h.usageFetcher.FetchUsage(ctx, providerName, token.AccessToken) + if err != nil { + return handleError(c, fmt.Errorf("fetch OAuth usage: %w", err)) + } + if usage == nil { + return c.JSON(http.StatusOK, map[string]any{ + "provider_name": providerName, + "account_email": token.AccountEmail, + "note": "usage data not available for this account", + }) + } + + return c.JSON(http.StatusOK, buildUsageResponse(providerName, token, usage)) +} + +// GetOAuthStatus returns the authentication status for a single OAuth provider. +func (h *OAuthHandler) GetOAuthStatus(c *echo.Context) error { + providerName := strings.TrimSpace(c.Param("provider_name")) + if providerName == "" { + return handleError(c, fmt.Errorf("provider_name is required")) + } + + ctx := c.Request().Context() + status := h.buildProviderStatus(ctx, providerName, "") + return c.JSON(http.StatusOK, status) +} + +// --- helpers --- + +func (h *OAuthHandler) buildProviderStatuses(ctx context.Context) ([]OAuthProviderStatus, error) { + result := make([]OAuthProviderStatus, 0) + for _, p := range h.configuredProviders { + if !isOAuthProviderConfig(p) { + continue + } + result = append(result, h.buildProviderStatus(ctx, p.Name, p.Type)) + } + return result, nil +} + +func (h *OAuthHandler) buildProviderStatus(ctx context.Context, providerName, providerType string) OAuthProviderStatus { + status := OAuthProviderStatus{ + ProviderName: providerName, + ProviderType: providerType, + Status: "pending", + } + + token, err := h.store.Get(ctx, providerName) + if err != nil { + if !errors.Is(err, oauthstore.ErrNotFound) { + slog.Warn("oauth: failed to load token for status", "provider", providerName, "error", err) + } + return status + } + + if providerType == "" { + status.ProviderType = token.ProviderType + } + status.AccountEmail = token.AccountEmail + status.DisplayName = token.DisplayName + status.SubscriptionType = token.SubscriptionType + expiresAt := token.ExpiresAt + status.TokenExpiresAt = &expiresAt + updatedAt := token.UpdatedAt + status.LastRefreshedAt = &updatedAt + + if token.IsExpired() { + status.Status = "expired" + status.Authenticated = false + } else { + status.Status = "authenticated" + status.Authenticated = true + } + + return status +} + +func (h *OAuthHandler) findOAuthProvider(name string) (providers.SanitizedProviderConfig, bool) { + for _, p := range h.configuredProviders { + if p.Name == name && isOAuthProviderConfig(p) { + return p, true + } + } + return providers.SanitizedProviderConfig{}, false +} + +func isOAuthProviderConfig(p providers.SanitizedProviderConfig) bool { + return p.IsOAuth +} + +func (h *OAuthHandler) waitForCallback(state string) { + h.flowMu.Lock() + flow, ok := h.flows[state] + h.flowMu.Unlock() + if !ok { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + + result, err := flow.server.Wait(ctx) + if err != nil { + slog.Warn("oauth: callback wait failed", "provider", flow.providerName, "error", err) + h.flowMu.Lock() + delete(h.flows, state) + h.flowMu.Unlock() + return + } + + if err := h.completeOAuthFlow(ctx, flow, result.Code, result.State, flow.exchangeURI); err != nil { + slog.Error("oauth: flow completion failed (background)", "provider", flow.providerName, "error", err) + } + + h.flowMu.Lock() + delete(h.flows, state) + h.flowMu.Unlock() +} + +func (h *OAuthHandler) completeOAuthFlow(ctx context.Context, flow *oauthFlowState, code, state, exchangeURI string) error { + oauthProv := oauth.NewAnthropicProvider() + + tokens, err := oauthProv.ExchangeCode(ctx, code, flow.verifier, state, exchangeURI) + if err != nil { + return fmt.Errorf("exchange code: %w", err) + } + + profile, err := oauthProv.FetchProfile(ctx, tokens.AccessToken) + if err != nil { + slog.Warn("oauth: profile fetch failed, continuing without profile", "provider", flow.providerName, "error", err) + } + + expiresAt := time.Now().Add(time.Duration(tokens.ExpiresIn) * time.Second) + if tokens.ExpiresIn <= 0 { + expiresAt = time.Now().Add(24 * time.Hour) // safe default + } + + token := &oauthstore.Token{ + ProviderName: flow.providerName, + ProviderType: flow.providerType, + AccessToken: tokens.AccessToken, + RefreshToken: tokens.RefreshToken, + ExpiresAt: expiresAt, + Scopes: tokens.Scopes, + SubscriptionType: tokens.SubscriptionType, + } + if profile != nil { + token.AccountEmail = profile.Email + token.AccountID = profile.AccountID + token.DisplayName = profile.DisplayName + if profile.SubscriptionType != "" { + token.SubscriptionType = profile.SubscriptionType + } + } + + if err := h.store.Save(ctx, token); err != nil { + return fmt.Errorf("save token: %w", err) + } + + slog.Info("oauth: authentication successful", + "provider", flow.providerName, + "email", token.AccountEmail, + "subscription", token.SubscriptionType, + ) + return nil +} + +func (h *OAuthHandler) cleanExpiredFlows() { + cutoff := time.Now().Add(-5 * time.Minute) + for state, flow := range h.flows { + if flow.createdAt.Before(cutoff) { + delete(h.flows, state) + } + } +} + +func buildUsageResponse(providerName string, token *oauthstore.Token, usage *oauthusage.Usage) OAuthUsageResponse { + resp := OAuthUsageResponse{ + ProviderName: providerName, + AccountEmail: token.AccountEmail, + SubscriptionType: token.SubscriptionType, + FetchedAt: usage.FetchedAt, + } + resp.FiveHour = toWindowResponse(usage.FiveHour) + resp.SevenDay = toWindowResponse(usage.SevenDay) + resp.SevenDayOAuthApps = toWindowResponse(usage.SevenDayOAuthApps) + resp.SevenDayOpus = toWindowResponse(usage.SevenDayOpus) + resp.SevenDaySonnet = toWindowResponse(usage.SevenDaySonnet) + if usage.ExtraUsage != nil { + resp.ExtraUsage = &OAuthExtraUsageResponse{ + IsEnabled: usage.ExtraUsage.IsEnabled, + MonthlyLimit: usage.ExtraUsage.MonthlyLimit, + UsedCredits: usage.ExtraUsage.UsedCredits, + Utilization: usage.ExtraUsage.Utilization, + DisabledReason: usage.ExtraUsage.DisabledReason, + } + } + return resp +} + +func toWindowResponse(w *oauthusage.UsageWindow) *OAuthUsageWindowResponse { + if w == nil { + return nil + } + return &OAuthUsageWindowResponse{ + Utilization: w.Utilization, + UtilizationPercent: w.UtilizationPercent(), + ResetsAt: w.ResetsAt, + } +} + +func oauthSuccessHTML() string { + return ` + +Authentication Successful + + +
+

✓ Authentication Successful

+

You can close this window and return to the dashboard.

+
+ +` +} + +func oauthErrorHTML(errMsg string) string { + return ` + +Authentication Failed + + +
+

✗ Authentication Failed

+

Please close this window and try again.

+
` + errMsg + `
+
+` +} diff --git a/internal/admin/routes.go b/internal/admin/routes.go index 7f618cd5..04b95115 100644 --- a/internal/admin/routes.go +++ b/internal/admin/routes.go @@ -66,3 +66,12 @@ func (h *Handler) RegisterRoutes(g RouteRegistrar) { g.POST("/workflows", h.CreateWorkflow) g.POST("/workflows/:id/deactivate", h.DeactivateWorkflow) } + +// RegisterOAuthRoutes mounts the OAuth admin routes on the given route group. +// oauthHandler may be nil — in that case no OAuth routes are registered. +func RegisterOAuthRoutes(g RouteRegistrar, oauthHandler *OAuthHandler) { + if oauthHandler == nil { + return + } + oauthHandler.RegisterOAuthRoutes(g) +} diff --git a/internal/app/app.go b/internal/app/app.go index 0a929dd3..8bc64b47 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -26,6 +26,7 @@ import ( "gomodel/internal/fallback" "gomodel/internal/guardrails" "gomodel/internal/modeloverrides" + "gomodel/internal/oauthstore" "gomodel/internal/providers" "gomodel/internal/responsecache" "gomodel/internal/server" @@ -98,6 +99,29 @@ func New(ctx context.Context, cfg Config) (*App, error) { config: appCfg, } + // Initialize a temporary storage connection to back the OAuth token store. + // We reuse the same backend as the rest of the app so tokens land in the + // same database. The connection is shared (not owned) — it will be closed + // by whichever subsystem owns it (audit, usage, etc.). + // We open a dedicated connection here only to bootstrap the factory before + // providers.Init; the store itself is lightweight (no background goroutines). + var oauthStore oauthstore.Store + { + tmpStorage, storageErr := storage.New(ctx, cfg.AppConfig.Config.Storage.BackendConfig()) + if storageErr != nil { + slog.Warn("oauth store unavailable: failed to open storage", "error", storageErr) + } else { + s, storeErr := oauthstore.NewFromStorage(ctx, tmpStorage) + if storeErr != nil { + slog.Warn("oauth store unavailable: failed to create store", "error", storeErr) + _ = tmpStorage.Close() + } else { + oauthStore = s + cfg.Factory.SetOAuthStore(oauthStore) + } + } + } + providerResult, err := providers.Init(ctx, cfg.AppConfig, cfg.Factory) if err != nil { return nil, fmt.Errorf("failed to initialize providers: %w", err) @@ -399,7 +423,7 @@ func New(ctx context.Context, cfg Config) (*App, error) { } usageEnabledForDashboard := usageResult.Logger.Config().Enabled if adminCfg.EndpointsEnabled { - adminHandler, dashHandler, adminErr := initAdmin( + adminHandler, oauthHandler, dashHandler, adminErr := initAdmin( auditResult.Storage, usageResult.Storage, providerResult.Registry, @@ -410,6 +434,7 @@ func New(ctx context.Context, cfg Config) (*App, error) { workflowResult.Service, app.guardrails.Service, budgetResult.Service, + oauthStore, app, dashboardRuntimeConfig(appCfg, usageEnabledForDashboard), usagePricingRecalculationConfigured(appCfg), @@ -421,6 +446,7 @@ func New(ctx context.Context, cfg Config) (*App, error) { } else { serverCfg.AdminEndpointsEnabled = true serverCfg.AdminHandler = adminHandler + serverCfg.OAuthHandler = oauthHandler slog.Info("admin API enabled", "api", config.JoinBasePath(appCfg.Server.BasePath, "/admin/api/v1")) if adminCfg.UIEnabled { serverCfg.AdminUIEnabled = true @@ -800,12 +826,13 @@ func initAdmin( workflowService *workflows.Service, guardrailService *guardrails.Service, budgetService *budget.Service, + oauthTokenStore oauthstore.Store, runtimeRefresher admin.RuntimeRefresher, runtimeConfig admin.DashboardConfigResponse, usagePricingRecalculationEnabled bool, basePath string, uiEnabled bool, -) (*admin.Handler, *dashboard.Handler, error) { +) (*admin.Handler, *admin.OAuthHandler, *dashboard.Handler, error) { // Find a storage connection for reading usage data var store storage.Storage if auditStorage != nil { @@ -821,7 +848,7 @@ func initAdmin( var err error reader, err = usage.NewReader(store) if err != nil { - return nil, nil, fmt.Errorf("failed to create usage reader: %w", err) + return nil, nil, nil, fmt.Errorf("failed to create usage reader: %w", err) } if usagePricingRecalculationEnabled { pricingRecalculator, err = usage.NewPricingRecalculator(store) @@ -840,7 +867,7 @@ func initAdmin( var err error auditReader, err = auditlog.NewReader(auditStorage) if err != nil { - return nil, nil, fmt.Errorf("failed to create audit reader: %w", err) + return nil, nil, nil, fmt.Errorf("failed to create audit reader: %w", err) } } @@ -860,16 +887,22 @@ func initAdmin( admin.WithDashboardRuntimeConfig(runtimeConfig), ) + // Create OAuth handler if an OAuth token store is available. + var oauthHandler *admin.OAuthHandler + if oauthTokenStore != nil { + oauthHandler = admin.NewOAuthHandler(oauthTokenStore, configuredProviders) + } + var dashHandler *dashboard.Handler if uiEnabled { var err error dashHandler, err = dashboard.NewWithBasePath(basePath) if err != nil { - return nil, nil, fmt.Errorf("failed to initialize dashboard: %w", err) + return nil, nil, nil, fmt.Errorf("failed to initialize dashboard: %w", err) } } - return adminHandler, dashHandler, nil + return adminHandler, oauthHandler, dashHandler, nil } func configGuardrailDefinitions(cfg config.GuardrailsConfig) ([]guardrails.Definition, error) { diff --git a/internal/oauth/anthropic.go b/internal/oauth/anthropic.go new file mode 100644 index 00000000..4e3c5b07 --- /dev/null +++ b/internal/oauth/anthropic.go @@ -0,0 +1,279 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + // AnthropicClientID is the public OAuth client ID for Claude/Anthropic. + // This is the same client ID used by Claude Code and other first-party tools. + AnthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + + anthropicAuthURL = "https://claude.ai/oauth/authorize" + anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" + anthropicProfileURL = "https://api.anthropic.com/api/oauth/profile" + + // DefaultCallbackPort is the preferred local port for the OAuth callback server. + DefaultCallbackPort = 54545 + + anthropicDefaultScopes = "org:create_api_key user:profile user:inference" +) + +// AnthropicProvider implements Provider for Anthropic OAuth. +type AnthropicProvider struct { + httpClient *http.Client + tokenURL string // overridable for tests + profileURL string // overridable for tests +} + +// NewAnthropicProvider creates an AnthropicProvider using the default HTTP client. +func NewAnthropicProvider() *AnthropicProvider { + return &AnthropicProvider{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + tokenURL: anthropicTokenURL, + profileURL: anthropicProfileURL, + } +} + +// NewAnthropicProviderWithClient creates an AnthropicProvider with a custom HTTP client. +func NewAnthropicProviderWithClient(client *http.Client) *AnthropicProvider { + if client == nil { + client = &http.Client{Timeout: 30 * time.Second} + } + return &AnthropicProvider{ + httpClient: client, + tokenURL: anthropicTokenURL, + profileURL: anthropicProfileURL, + } +} + +// AuthorizationURL builds the Anthropic OAuth authorization URL. +// Parameter order matches the reference implementation (cligate/claude-oauth.js). +func (p *AnthropicProvider) AuthorizationURL(state, verifier, redirectURI string) string { + challenge := deriveChallenge(verifier) + + // Build params in the exact order used by the reference Claude OAuth client. + // url.Values.Encode() sorts alphabetically which may differ from what the + // Anthropic endpoint expects, so we construct the query string manually. + query := "code=true" + + "&client_id=" + url.QueryEscape(AnthropicClientID) + + "&redirect_uri=" + url.QueryEscape(redirectURI) + + "&response_type=code" + + "&scope=" + url.QueryEscape(anthropicDefaultScopes) + + "&state=" + url.QueryEscape(state) + + "&code_challenge=" + url.QueryEscape(challenge) + + "&code_challenge_method=S256" + + return anthropicAuthURL + "?" + query +} + +// ExchangeCode exchanges an authorization code for tokens. +// Claude requires state in the token exchange body (non-standard). +func (p *AnthropicProvider) ExchangeCode(ctx context.Context, code, verifier, state, redirectURI string) (*TokenResponse, error) { + body := map[string]string{ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirectURI, + "client_id": AnthropicClientID, + "code_verifier": verifier, + } + if state != "" { + body["state"] = state + } + + resp, err := p.postJSON(ctx, p.tokenURL, body) + if err != nil { + return nil, fmt.Errorf("anthropic token exchange: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("anthropic token exchange failed (%d): %s", resp.StatusCode, string(raw)) + } + + var payload struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + SubscriptionType string `json:"subscription_type"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, fmt.Errorf("decode anthropic token response: %w", err) + } + if payload.AccessToken == "" { + return nil, fmt.Errorf("anthropic token exchange: no access_token in response") + } + + scopes := splitScopes(payload.Scope) + if len(scopes) == 0 { + scopes = splitScopes(anthropicDefaultScopes) + } + + return &TokenResponse{ + AccessToken: payload.AccessToken, + RefreshToken: payload.RefreshToken, + ExpiresIn: payload.ExpiresIn, + Scopes: scopes, + SubscriptionType: payload.SubscriptionType, + }, nil +} + +// RefreshToken obtains a new access token using a refresh token. +func (p *AnthropicProvider) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + if strings.TrimSpace(refreshToken) == "" { + return nil, fmt.Errorf("refresh token is required") + } + + body := map[string]string{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": AnthropicClientID, + } + + resp, err := p.postJSON(ctx, p.tokenURL, body) + if err != nil { + return nil, fmt.Errorf("anthropic token refresh: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("anthropic token refresh failed (%d): %s", resp.StatusCode, string(raw)) + } + + var payload struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + SubscriptionType string `json:"subscription_type"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, fmt.Errorf("decode anthropic refresh response: %w", err) + } + if payload.AccessToken == "" { + return nil, fmt.Errorf("anthropic token refresh: no access_token in response") + } + + scopes := splitScopes(payload.Scope) + if len(scopes) == 0 { + scopes = splitScopes(anthropicDefaultScopes) + } + + // Preserve the original refresh token if the provider did not rotate it. + newRefresh := payload.RefreshToken + if newRefresh == "" { + newRefresh = refreshToken + } + + return &TokenResponse{ + AccessToken: payload.AccessToken, + RefreshToken: newRefresh, + ExpiresIn: payload.ExpiresIn, + Scopes: scopes, + SubscriptionType: payload.SubscriptionType, + }, nil +} + +// FetchProfile retrieves the authenticated user's profile from Anthropic. +func (p *AnthropicProvider) FetchProfile(ctx context.Context, accessToken string) (*Profile, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.profileURL, nil) + if err != nil { + return nil, fmt.Errorf("build profile request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch anthropic profile: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + return nil, fmt.Errorf("fetch anthropic profile failed (%d): %s", resp.StatusCode, string(raw)) + } + + var payload struct { + Account struct { + UUID string `json:"uuid"` + Email string `json:"email"` + FullName string `json:"full_name"` + HasClaudePro bool `json:"has_claude_pro"` + HasClaudeMax bool `json:"has_claude_max"` + } `json:"account"` + Organization struct { + Name string `json:"name"` + } `json:"organization"` + } + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return nil, fmt.Errorf("decode anthropic profile: %w", err) + } + + subscriptionType := "free" + switch { + case payload.Account.HasClaudeMax: + subscriptionType = "max" + case payload.Account.HasClaudePro: + subscriptionType = "pro" + } + + return &Profile{ + AccountID: payload.Account.UUID, + Email: payload.Account.Email, + DisplayName: payload.Account.FullName, + SubscriptionType: subscriptionType, + HasClaudePro: payload.Account.HasClaudePro, + HasClaudeMax: payload.Account.HasClaudeMax, + OrganizationName: payload.Organization.Name, + }, nil +} + +// postJSON sends a JSON POST request and returns the raw response. +func (p *AnthropicProvider) postJSON(ctx context.Context, endpoint string, body map[string]string) (*http.Response, error) { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(data))) + if err != nil { + return nil, fmt.Errorf("build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + return p.httpClient.Do(req) +} + +// LocalCallbackURI builds the local redirect URI for the given port. +// Use this when GoModel is running on the same machine as the browser. +func LocalCallbackURI(port int) string { + return fmt.Sprintf("http://localhost:%d/callback", port) +} + +// splitScopes splits a space-separated scope string into a slice. +func splitScopes(s string) []string { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + parts := strings.Fields(s) + result := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/internal/oauth/export_test.go b/internal/oauth/export_test.go new file mode 100644 index 00000000..99fa5f6b --- /dev/null +++ b/internal/oauth/export_test.go @@ -0,0 +1,21 @@ +package oauth + +// Test helpers — exported only for use in _test packages. +// These allow injecting custom URLs in unit tests without exposing them +// in the production API. + +// NewAnthropicProviderWithTestTokenURL creates an AnthropicProvider that +// sends token requests to the given URL instead of the real Anthropic endpoint. +func NewAnthropicProviderWithTestTokenURL(tokenURL string) *AnthropicProvider { + p := NewAnthropicProvider() + p.tokenURL = tokenURL + return p +} + +// NewAnthropicProviderWithTestProfileURL creates an AnthropicProvider that +// sends profile requests to the given URL instead of the real Anthropic endpoint. +func NewAnthropicProviderWithTestProfileURL(profileURL string) *AnthropicProvider { + p := NewAnthropicProvider() + p.profileURL = profileURL + return p +} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 00000000..4b445880 --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,103 @@ +// Package oauth provides OAuth 2.0 with PKCE support for provider authentication. +package oauth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" +) + +// ErrStateMismatch is returned when the OAuth state parameter does not match. +var ErrStateMismatch = errors.New("oauth state mismatch") + +// ErrCallbackTimeout is returned when the local callback server times out. +var ErrCallbackTimeout = errors.New("oauth callback timeout") + +// TokenResponse holds the tokens returned by the OAuth token endpoint. +type TokenResponse struct { + AccessToken string + RefreshToken string + ExpiresIn int // seconds + Scopes []string + SubscriptionType string +} + +// Profile holds the authenticated user's profile information. +type Profile struct { + AccountID string + Email string + DisplayName string + SubscriptionType string + HasClaudePro bool + HasClaudeMax bool + OrganizationName string +} + +// Provider defines the operations needed to complete an OAuth flow. +type Provider interface { + // AuthorizationURL returns the URL the user must visit to authorize. + // state is a random CSRF token; verifier is the PKCE code verifier. + // redirectURI is the full callback URI (e.g. "http://localhost:54545/callback" + // or "https://example.com/admin/api/v1/oauth/callback"). + AuthorizationURL(state, verifier, redirectURI string) string + + // ExchangeCode exchanges an authorization code for tokens. + ExchangeCode(ctx context.Context, code, verifier, state, redirectURI string) (*TokenResponse, error) + + // RefreshToken obtains a new access token using a refresh token. + RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) + + // FetchProfile retrieves the authenticated user's profile. + FetchProfile(ctx context.Context, accessToken string) (*Profile, error) +} + +// generateVerifier creates a cryptographically random PKCE code verifier. +func generateVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// deriveChallenge computes the S256 PKCE code challenge from a verifier. +func deriveChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// generateState creates a random CSRF state token as a hex string, +// matching the format expected by the Anthropic OAuth endpoint. +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// PKCEPair holds a verifier and its derived challenge. +type PKCEPair struct { + Verifier string + Challenge string +} + +// NewPKCEPair generates a fresh PKCE verifier/challenge pair. +func NewPKCEPair() (PKCEPair, error) { + verifier, err := generateVerifier() + if err != nil { + return PKCEPair{}, err + } + return PKCEPair{ + Verifier: verifier, + Challenge: deriveChallenge(verifier), + }, nil +} + +// NewState generates a random CSRF state token. +func NewState() (string, error) { + return generateState() +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go new file mode 100644 index 00000000..ba764cc2 --- /dev/null +++ b/internal/oauth/oauth_test.go @@ -0,0 +1,254 @@ +package oauth_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gomodel/internal/oauth" +) + +// --- PKCE tests --- + +func TestNewPKCEPair(t *testing.T) { + pair1, err := oauth.NewPKCEPair() + require.NoError(t, err) + assert.NotEmpty(t, pair1.Verifier) + assert.NotEmpty(t, pair1.Challenge) + assert.NotEqual(t, pair1.Verifier, pair1.Challenge) + + // Two pairs must be distinct + pair2, err := oauth.NewPKCEPair() + require.NoError(t, err) + assert.NotEqual(t, pair1.Verifier, pair2.Verifier) + assert.NotEqual(t, pair1.Challenge, pair2.Challenge) +} + +func TestNewState(t *testing.T) { + s1, err := oauth.NewState() + require.NoError(t, err) + assert.NotEmpty(t, s1) + + s2, err := oauth.NewState() + require.NoError(t, err) + assert.NotEqual(t, s1, s2) +} + +// --- AnthropicProvider tests --- + +func TestAnthropicProvider_AuthorizationURL(t *testing.T) { + p := oauth.NewAnthropicProvider() + pair, err := oauth.NewPKCEPair() + require.NoError(t, err) + + authURL := p.AuthorizationURL("test-state", pair.Verifier, oauth.LocalCallbackURI(54545)) + + assert.Contains(t, authURL, "claude.ai/oauth/authorize") + assert.Contains(t, authURL, oauth.AnthropicClientID) + assert.Contains(t, authURL, "test-state") + assert.Contains(t, authURL, "code_challenge_method=S256") + assert.Contains(t, authURL, "localhost%3A54545") // encoded port in redirect_uri +} + +func TestAnthropicProvider_ExchangeCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + var body map[string]string + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, "authorization_code", body["grant_type"]) + assert.Equal(t, "test-code", body["code"]) + assert.Equal(t, "test-verifier", body["code_verifier"]) + assert.Equal(t, "test-state", body["state"]) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-abc", + "refresh_token": "refresh-xyz", + "expires_in": 3600, + "scope": "org:create_api_key user:profile user:inference", + "subscription_type": "pro", + }) + })) + defer srv.Close() + + p := oauth.NewAnthropicProviderWithTestTokenURL(srv.URL) + resp, err := p.ExchangeCode(context.Background(), "test-code", "test-verifier", "test-state", oauth.LocalCallbackURI(54545)) + require.NoError(t, err) + + assert.Equal(t, "access-abc", resp.AccessToken) + assert.Equal(t, "refresh-xyz", resp.RefreshToken) + assert.Equal(t, 3600, resp.ExpiresIn) + assert.Equal(t, "pro", resp.SubscriptionType) + assert.Contains(t, resp.Scopes, "user:profile") +} + +func TestAnthropicProvider_ExchangeCode_Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer srv.Close() + + p := oauth.NewAnthropicProviderWithTestTokenURL(srv.URL) + _, err := p.ExchangeCode(context.Background(), "bad-code", "verifier", "state", oauth.LocalCallbackURI(54545)) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +func TestAnthropicProvider_RefreshToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]string + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, "refresh_token", body["grant_type"]) + assert.Equal(t, "old-refresh", body["refresh_token"]) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access", + "refresh_token": "new-refresh", + "expires_in": 3600, + }) + })) + defer srv.Close() + + p := oauth.NewAnthropicProviderWithTestTokenURL(srv.URL) + resp, err := p.RefreshToken(context.Background(), "old-refresh") + require.NoError(t, err) + + assert.Equal(t, "new-access", resp.AccessToken) + assert.Equal(t, "new-refresh", resp.RefreshToken) +} + +func TestAnthropicProvider_RefreshToken_PreservesOriginalWhenNotRotated(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Provider does not return a new refresh token + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "new-access", + "expires_in": 3600, + }) + })) + defer srv.Close() + + p := oauth.NewAnthropicProviderWithTestTokenURL(srv.URL) + resp, err := p.RefreshToken(context.Background(), "original-refresh") + require.NoError(t, err) + + // Original refresh token must be preserved + assert.Equal(t, "original-refresh", resp.RefreshToken) +} + +func TestAnthropicProvider_RefreshToken_EmptyToken(t *testing.T) { + p := oauth.NewAnthropicProvider() + _, err := p.RefreshToken(context.Background(), "") + require.Error(t, err) + assert.Contains(t, err.Error(), "refresh token is required") +} + +func TestAnthropicProvider_FetchProfile(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "account": map[string]any{ + "uuid": "acc-123", + "email": "user@example.com", + "full_name": "Test User", + "has_claude_pro": true, + "has_claude_max": false, + }, + "organization": map[string]any{ + "name": "Test Org", + }, + }) + })) + defer srv.Close() + + p := oauth.NewAnthropicProviderWithTestProfileURL(srv.URL) + profile, err := p.FetchProfile(context.Background(), "test-token") + require.NoError(t, err) + + assert.Equal(t, "acc-123", profile.AccountID) + assert.Equal(t, "user@example.com", profile.Email) + assert.Equal(t, "Test User", profile.DisplayName) + assert.Equal(t, "pro", profile.SubscriptionType) + assert.True(t, profile.HasClaudePro) + assert.False(t, profile.HasClaudeMax) + assert.Equal(t, "Test Org", profile.OrganizationName) +} + +// --- CallbackServer tests --- + +func TestCallbackServer_ReceivesCode(t *testing.T) { + cs := oauth.NewCallbackServer(0) // OS picks port + port, err := cs.Start() + require.NoError(t, err) + assert.Greater(t, port, 0) + + // Simulate the OAuth provider redirecting back + go func() { + time.Sleep(50 * time.Millisecond) + resp, err := http.Get(strings.Replace( + "http://localhost:PORT/callback?code=auth-code-123&state=csrf-state", + "PORT", strconv.Itoa(port), 1, + )) + if err == nil { + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + result, err := cs.Wait(ctx) + require.NoError(t, err) + assert.Equal(t, "auth-code-123", result.Code) + assert.Equal(t, "csrf-state", result.State) +} + +func TestCallbackServer_ErrorFromProvider(t *testing.T) { + cs := oauth.NewCallbackServer(0) + port, err := cs.Start() + require.NoError(t, err) + + go func() { + time.Sleep(50 * time.Millisecond) + resp, err := http.Get(strings.Replace( + "http://localhost:PORT/callback?error=access_denied", + "PORT", strconv.Itoa(port), 1, + )) + if err == nil { + resp.Body.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err = cs.Wait(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "access_denied") +} + +func TestCallbackServer_ContextCancelled(t *testing.T) { + cs := oauth.NewCallbackServer(0) + _, err := cs.Start() + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err = cs.Wait(ctx) + require.Error(t, err) +} diff --git a/internal/oauth/server.go b/internal/oauth/server.go new file mode 100644 index 00000000..f1946097 --- /dev/null +++ b/internal/oauth/server.go @@ -0,0 +1,191 @@ +package oauth + +import ( + "context" + "fmt" + "net" + "net/http" + "time" +) + +const ( + callbackPath = "/callback" + defaultCallbackTimeout = 2 * time.Minute +) + +// CallbackResult holds the authorization code and state received from the +// OAuth provider after the user authorizes the application. +type CallbackResult struct { + Code string + State string +} + +// CallbackServer listens on a local port for the OAuth redirect and captures +// the authorization code. It shuts down automatically after receiving one +// successful callback or when the context is cancelled. +type CallbackServer struct { + port int + server *http.Server + result chan CallbackResult + errCh chan error +} + +// NewCallbackServer creates a callback server bound to the given port. +// Use port 0 to let the OS pick a free port; call Port() after Start(). +func NewCallbackServer(port int) *CallbackServer { + cs := &CallbackServer{ + port: port, + result: make(chan CallbackResult, 1), + errCh: make(chan error, 1), + } + + mux := http.NewServeMux() + mux.HandleFunc(callbackPath, cs.handleCallback) + + cs.server = &http.Server{ + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + return cs +} + +// Start begins listening. It returns the actual bound port (useful when port 0 +// was requested) and any bind error. +func (cs *CallbackServer) Start() (int, error) { + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cs.port)) + if err != nil { + return 0, fmt.Errorf("oauth callback: bind port %d: %w", cs.port, err) + } + cs.port = ln.Addr().(*net.TCPAddr).Port + + go func() { + if err := cs.server.Serve(ln); err != nil && err != http.ErrServerClosed { + select { + case cs.errCh <- err: + default: + } + } + }() + + return cs.port, nil +} + +// Wait blocks until the callback is received, the context is cancelled, or the +// timeout elapses. It shuts down the server before returning. +func (cs *CallbackServer) Wait(ctx context.Context) (*CallbackResult, error) { + defer cs.shutdown() + + timer := time.NewTimer(defaultCallbackTimeout) + defer timer.Stop() + + select { + case result := <-cs.result: + return &result, nil + case err := <-cs.errCh: + return nil, fmt.Errorf("oauth callback server error: %w", err) + case <-timer.C: + return nil, ErrCallbackTimeout + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Port returns the port the server is (or will be) listening on. +func (cs *CallbackServer) Port() int { + return cs.port +} + +func (cs *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + if errParam := q.Get("error"); errParam != "" { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(errorHTML(errParam))) //nolint:errcheck + select { + case cs.errCh <- fmt.Errorf("oauth provider error: %s", errParam): + default: + } + return + } + + code := q.Get("code") + state := q.Get("state") + if code == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing authorization code")) //nolint:errcheck + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write([]byte(successHTML())) //nolint:errcheck + + select { + case cs.result <- CallbackResult{Code: code, State: state}: + default: + } +} + +func (cs *CallbackServer) shutdown() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cs.server.Shutdown(ctx) //nolint:errcheck +} + +// TryCallbackPorts attempts to start a callback server on the preferred port, +// falling back to a list of alternatives if the preferred port is busy. +func TryCallbackPorts(preferred int, fallbacks ...int) (*CallbackServer, int, error) { + ports := append([]int{preferred}, fallbacks...) + for _, port := range ports { + cs := NewCallbackServer(port) + actualPort, err := cs.Start() + if err == nil { + return cs, actualPort, nil + } + } + return nil, 0, fmt.Errorf("oauth callback: no available port in %v", ports) +} + +func successHTML() string { + return ` + +Authentication Successful + + +
+

✓ Authentication Successful

+

You can close this window and return to the dashboard.

+
+ +` +} + +func errorHTML(errMsg string) string { + return ` + +Authentication Failed + + +
+

✗ Authentication Failed

+

Please close this window and try again.

+
` + errMsg + `
+
+` +} diff --git a/internal/oauthstore/factory.go b/internal/oauthstore/factory.go new file mode 100644 index 00000000..cafa4b73 --- /dev/null +++ b/internal/oauthstore/factory.go @@ -0,0 +1,22 @@ +package oauthstore + +import ( + "context" + "database/sql" + + "github.com/jackc/pgx/v5/pgxpool" + "go.mongodb.org/mongo-driver/v2/mongo" + + "gomodel/internal/storage" +) + +// NewFromStorage creates an oauthstore.Store backed by the given shared storage connection. +// It supports SQLite, PostgreSQL, and MongoDB backends. +func NewFromStorage(ctx context.Context, shared storage.Storage) (Store, error) { + return storage.ResolveBackend[Store]( + shared, + func(db *sql.DB) (Store, error) { return NewSQLiteStore(db) }, + func(pool *pgxpool.Pool) (Store, error) { return NewPostgreSQLStore(ctx, pool) }, + func(db *mongo.Database) (Store, error) { return NewMongoDBStore(db) }, + ) +} diff --git a/internal/oauthstore/store.go b/internal/oauthstore/store.go new file mode 100644 index 00000000..ef7f4793 --- /dev/null +++ b/internal/oauthstore/store.go @@ -0,0 +1,128 @@ +// Package oauthstore provides persistence for OAuth tokens used by providers +// configured with api_key: "oauth". +package oauthstore + +import ( + "context" + "errors" + "strings" + "time" +) + +var ( + // ErrNotFound indicates no OAuth token exists for the given provider name. + ErrNotFound = errors.New("oauth token not found") +) + +// Token holds a persisted OAuth token for a named provider instance. +type Token struct { + ProviderName string // configured provider name (e.g. "anthropic_oauth") + ProviderType string // provider type (e.g. "anthropic") + AccessToken string // current bearer token + RefreshToken string // used to obtain a new access token; may be empty + ExpiresAt time.Time // when the access token expires + Scopes []string // granted OAuth scopes + AccountEmail string // authenticated account email + AccountID string // provider account/org ID + DisplayName string // human-readable account name + SubscriptionType string // e.g. "free", "pro", "max" + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsExpired reports whether the access token has expired, using a 5-minute +// safety margin so callers can refresh before the token actually expires. +func (t *Token) IsExpired() bool { + if t == nil { + return true + } + return time.Now().Add(5 * time.Minute).After(t.ExpiresAt) +} + +// Store defines persistence operations for OAuth tokens. +type Store interface { + // Save creates or replaces the token for the given provider name. + Save(ctx context.Context, token *Token) error + // Get returns the token for the given provider name, or ErrNotFound. + Get(ctx context.Context, providerName string) (*Token, error) + // Delete removes the token for the given provider name. + // Returns nil if the token did not exist. + Delete(ctx context.Context, providerName string) error + // List returns all stored tokens ordered by provider name. + List(ctx context.Context) ([]*Token, error) + // Close releases any resources held by the store. + Close() error +} + +// rowScanner is implemented by *sql.Row, *sql.Rows, and pgx equivalents. +type rowScanner interface { + Scan(dest ...any) error +} + +// scanTokenRow scans a token from any row scanner (SQLite or PostgreSQL). +// Column order must match the SELECT used in each store implementation. +func scanTokenRow(scanner rowScanner) (*Token, error) { + var t Token + var expiresAt int64 + var createdAt int64 + var updatedAt int64 + var scopes string + + if err := scanner.Scan( + &t.ProviderName, + &t.ProviderType, + &t.AccessToken, + &t.RefreshToken, + &expiresAt, + &scopes, + &t.AccountEmail, + &t.AccountID, + &t.DisplayName, + &t.SubscriptionType, + &createdAt, + &updatedAt, + ); err != nil { + return nil, err + } + + t.ExpiresAt = time.Unix(expiresAt, 0).UTC() + t.CreatedAt = time.Unix(createdAt, 0).UTC() + t.UpdatedAt = time.Unix(updatedAt, 0).UTC() + t.Scopes = splitScopes(scopes) + return &t, nil +} + +// normalizeProviderName trims and lowercases the provider name for consistent +// storage and lookup. +func normalizeProviderName(name string) string { + return strings.TrimSpace(name) +} + +// joinScopes serialises a scope slice to a space-separated string. +func joinScopes(scopes []string) string { + filtered := make([]string, 0, len(scopes)) + for _, s := range scopes { + s = strings.TrimSpace(s) + if s != "" { + filtered = append(filtered, s) + } + } + return strings.Join(filtered, " ") +} + +// splitScopes deserialises a space-separated scope string. +func splitScopes(s string) []string { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + parts := strings.Fields(s) + result := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/internal/oauthstore/store_mongodb.go b/internal/oauthstore/store_mongodb.go new file mode 100644 index 00000000..71f76ccc --- /dev/null +++ b/internal/oauthstore/store_mongodb.go @@ -0,0 +1,170 @@ +package oauthstore + +import ( + "context" + "fmt" + "strings" + "time" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +type mongoOAuthTokenDocument struct { + ProviderName string `bson:"_id"` + ProviderType string `bson:"provider_type"` + AccessToken string `bson:"access_token"` + RefreshToken string `bson:"refresh_token"` + ExpiresAt time.Time `bson:"expires_at"` + Scopes string `bson:"scopes"` + AccountEmail string `bson:"account_email"` + AccountID string `bson:"account_id"` + DisplayName string `bson:"display_name"` + SubscriptionType string `bson:"subscription_type"` + CreatedAt time.Time `bson:"created_at"` + UpdatedAt time.Time `bson:"updated_at"` +} + +// MongoDBStore stores OAuth tokens in MongoDB. +type MongoDBStore struct { + collection *mongo.Collection +} + +// NewMongoDBStore creates collection indexes if needed. +func NewMongoDBStore(database *mongo.Database) (*MongoDBStore, error) { + if database == nil { + return nil, fmt.Errorf("database is required") + } + coll := database.Collection("oauth_tokens") + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + indexes := []mongo.IndexModel{ + {Keys: bson.D{{Key: "expires_at", Value: 1}}}, + {Keys: bson.D{{Key: "provider_type", Value: 1}}}, + } + if _, err := coll.Indexes().CreateMany(ctx, indexes); err != nil { + return nil, fmt.Errorf("create oauth_tokens indexes: %w", err) + } + return &MongoDBStore{collection: coll}, nil +} + +func (s *MongoDBStore) Save(ctx context.Context, token *Token) error { + if token == nil { + return fmt.Errorf("token is required") + } + name := normalizeProviderName(token.ProviderName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + + now := time.Now().UTC() + createdAt := now + existing, err := s.Get(ctx, name) + if err == nil { + createdAt = existing.CreatedAt + } + + doc := mongoOAuthTokenDocument{ + ProviderName: name, + ProviderType: strings.TrimSpace(token.ProviderType), + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + ExpiresAt: token.ExpiresAt.UTC(), + Scopes: joinScopes(token.Scopes), + AccountEmail: strings.TrimSpace(token.AccountEmail), + AccountID: strings.TrimSpace(token.AccountID), + DisplayName: strings.TrimSpace(token.DisplayName), + SubscriptionType: strings.TrimSpace(token.SubscriptionType), + CreatedAt: createdAt, + UpdatedAt: now, + } + + upsert := true + _, err = s.collection.ReplaceOne( + ctx, + bson.M{"_id": name}, + doc, + options.Replace().SetUpsert(upsert), + ) + if err != nil { + return fmt.Errorf("save oauth token: %w", err) + } + return nil +} + +func (s *MongoDBStore) Get(ctx context.Context, providerName string) (*Token, error) { + name := normalizeProviderName(providerName) + if name == "" { + return nil, fmt.Errorf("provider_name is required") + } + + var doc mongoOAuthTokenDocument + err := s.collection.FindOne(ctx, bson.M{"_id": name}).Decode(&doc) + if err != nil { + if err == mongo.ErrNoDocuments { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get oauth token: %w", err) + } + return tokenFromMongo(doc), nil +} + +func (s *MongoDBStore) Delete(ctx context.Context, providerName string) error { + name := normalizeProviderName(providerName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + _, err := s.collection.DeleteOne(ctx, bson.M{"_id": name}) + if err != nil { + return fmt.Errorf("delete oauth token: %w", err) + } + return nil +} + +func (s *MongoDBStore) List(ctx context.Context) ([]*Token, error) { + cursor, err := s.collection.Find( + ctx, + bson.M{}, + options.Find().SetSort(bson.D{{Key: "_id", Value: 1}}), + ) + if err != nil { + return nil, fmt.Errorf("list oauth tokens: %w", err) + } + defer cursor.Close(ctx) + + result := make([]*Token, 0) + for cursor.Next(ctx) { + var doc mongoOAuthTokenDocument + if err := cursor.Decode(&doc); err != nil { + return nil, fmt.Errorf("decode oauth token: %w", err) + } + result = append(result, tokenFromMongo(doc)) + } + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("iterate oauth tokens: %w", err) + } + return result, nil +} + +func (s *MongoDBStore) Close() error { + return nil +} + +func tokenFromMongo(doc mongoOAuthTokenDocument) *Token { + return &Token{ + ProviderName: doc.ProviderName, + ProviderType: doc.ProviderType, + AccessToken: doc.AccessToken, + RefreshToken: doc.RefreshToken, + ExpiresAt: doc.ExpiresAt.UTC(), + Scopes: splitScopes(doc.Scopes), + AccountEmail: doc.AccountEmail, + AccountID: doc.AccountID, + DisplayName: doc.DisplayName, + SubscriptionType: doc.SubscriptionType, + CreatedAt: doc.CreatedAt.UTC(), + UpdatedAt: doc.UpdatedAt.UTC(), + } +} diff --git a/internal/oauthstore/store_postgresql.go b/internal/oauthstore/store_postgresql.go new file mode 100644 index 00000000..af082260 --- /dev/null +++ b/internal/oauthstore/store_postgresql.go @@ -0,0 +1,187 @@ +package oauthstore + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// PostgreSQLStore stores OAuth tokens in PostgreSQL. +type PostgreSQLStore struct { + pool *pgxpool.Pool +} + +// NewPostgreSQLStore creates the oauth_tokens table and indexes if needed. +func NewPostgreSQLStore(ctx context.Context, pool *pgxpool.Pool) (*PostgreSQLStore, error) { + if ctx == nil { + return nil, fmt.Errorf("context is required") + } + if pool == nil { + return nil, fmt.Errorf("connection pool is required") + } + + _, err := pool.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS oauth_tokens ( + provider_name TEXT PRIMARY KEY, + provider_type TEXT NOT NULL DEFAULT '', + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL DEFAULT '', + expires_at BIGINT NOT NULL, + scopes TEXT NOT NULL DEFAULT '', + account_email TEXT NOT NULL DEFAULT '', + account_id TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + subscription_type TEXT NOT NULL DEFAULT '', + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL + ) + `) + if err != nil { + return nil, fmt.Errorf("failed to create oauth_tokens table: %w", err) + } + + for _, index := range []string{ + `CREATE INDEX IF NOT EXISTS idx_oauth_tokens_expires ON oauth_tokens(expires_at)`, + `CREATE INDEX IF NOT EXISTS idx_oauth_tokens_type ON oauth_tokens(provider_type)`, + } { + if _, err := pool.Exec(ctx, index); err != nil { + return nil, fmt.Errorf("failed to create oauth_tokens index: %w", err) + } + } + + return &PostgreSQLStore{pool: pool}, nil +} + +func (s *PostgreSQLStore) Save(ctx context.Context, token *Token) error { + if token == nil { + return fmt.Errorf("token is required") + } + name := normalizeProviderName(token.ProviderName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + + now := time.Now().UTC() + createdAt := now + existing, err := s.Get(ctx, name) + if err == nil { + createdAt = existing.CreatedAt + } + + _, err = s.pool.Exec(ctx, ` + INSERT INTO oauth_tokens + (provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (provider_name) DO UPDATE SET + provider_type = EXCLUDED.provider_type, + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + expires_at = EXCLUDED.expires_at, + scopes = EXCLUDED.scopes, + account_email = EXCLUDED.account_email, + account_id = EXCLUDED.account_id, + display_name = EXCLUDED.display_name, + subscription_type = EXCLUDED.subscription_type, + updated_at = EXCLUDED.updated_at + `, + name, + strings.TrimSpace(token.ProviderType), + token.AccessToken, + token.RefreshToken, + token.ExpiresAt.UTC().Unix(), + joinScopes(token.Scopes), + strings.TrimSpace(token.AccountEmail), + strings.TrimSpace(token.AccountID), + strings.TrimSpace(token.DisplayName), + strings.TrimSpace(token.SubscriptionType), + createdAt.Unix(), + now.Unix(), + ) + if err != nil { + return fmt.Errorf("save oauth token: %w", err) + } + return nil +} + +func (s *PostgreSQLStore) Get(ctx context.Context, providerName string) (*Token, error) { + name := normalizeProviderName(providerName) + if name == "" { + return nil, fmt.Errorf("provider_name is required") + } + + row := s.pool.QueryRow(ctx, ` + SELECT provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at + FROM oauth_tokens + WHERE provider_name = $1 + `, name) + + token, err := scanPostgreSQLToken(row) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get oauth token: %w", err) + } + return token, nil +} + +func (s *PostgreSQLStore) Delete(ctx context.Context, providerName string) error { + name := normalizeProviderName(providerName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + _, err := s.pool.Exec(ctx, `DELETE FROM oauth_tokens WHERE provider_name = $1`, name) + if err != nil { + return fmt.Errorf("delete oauth token: %w", err) + } + return nil +} + +//nolint:dupl // SQLite and PostgreSQL List methods are structurally identical but use incompatible driver interfaces +func (s *PostgreSQLStore) List(ctx context.Context) ([]*Token, error) { + rows, err := s.pool.Query(ctx, ` + SELECT provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at + FROM oauth_tokens + ORDER BY provider_name ASC + `) + if err != nil { + return nil, fmt.Errorf("list oauth tokens: %w", err) + } + defer rows.Close() + + result := make([]*Token, 0) + for rows.Next() { + token, err := scanPostgreSQLToken(rows) + if err != nil { + return nil, fmt.Errorf("scan oauth token: %w", err) + } + result = append(result, token) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate oauth tokens: %w", err) + } + return result, nil +} + +func (s *PostgreSQLStore) Close() error { + return nil +} + +type pgScanner interface { + Scan(dest ...any) error +} + +func scanPostgreSQLToken(scanner pgScanner) (*Token, error) { + return scanTokenRow(scanner) +} diff --git a/internal/oauthstore/store_sqlite.go b/internal/oauthstore/store_sqlite.go new file mode 100644 index 00000000..cb110202 --- /dev/null +++ b/internal/oauthstore/store_sqlite.go @@ -0,0 +1,183 @@ +package oauthstore + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" +) + +// SQLiteStore stores OAuth tokens in SQLite. +type SQLiteStore struct { + db *sql.DB +} + +// NewSQLiteStore creates the oauth_tokens table and indexes if needed. +func NewSQLiteStore(db *sql.DB) (*SQLiteStore, error) { + if db == nil { + return nil, fmt.Errorf("database connection is required") + } + + _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS oauth_tokens ( + provider_name TEXT PRIMARY KEY, + provider_type TEXT NOT NULL DEFAULT '', + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL DEFAULT '', + expires_at INTEGER NOT NULL, + scopes TEXT NOT NULL DEFAULT '', + account_email TEXT NOT NULL DEFAULT '', + account_id TEXT NOT NULL DEFAULT '', + display_name TEXT NOT NULL DEFAULT '', + subscription_type TEXT NOT NULL DEFAULT '', + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ) + `) + if err != nil { + return nil, fmt.Errorf("failed to create oauth_tokens table: %w", err) + } + + for _, index := range []string{ + `CREATE INDEX IF NOT EXISTS idx_oauth_tokens_expires ON oauth_tokens(expires_at)`, + `CREATE INDEX IF NOT EXISTS idx_oauth_tokens_type ON oauth_tokens(provider_type)`, + } { + if _, err := db.Exec(index); err != nil { + return nil, fmt.Errorf("failed to create oauth_tokens index: %w", err) + } + } + + return &SQLiteStore{db: db}, nil +} + +func (s *SQLiteStore) Save(ctx context.Context, token *Token) error { + if token == nil { + return fmt.Errorf("token is required") + } + name := normalizeProviderName(token.ProviderName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + + now := time.Now().UTC() + createdAt := now + // Preserve original created_at if the record already exists. + existing, err := s.Get(ctx, name) + if err == nil { + createdAt = existing.CreatedAt + } + + _, err = s.db.ExecContext(ctx, ` + INSERT INTO oauth_tokens + (provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(provider_name) DO UPDATE SET + provider_type = excluded.provider_type, + access_token = excluded.access_token, + refresh_token = excluded.refresh_token, + expires_at = excluded.expires_at, + scopes = excluded.scopes, + account_email = excluded.account_email, + account_id = excluded.account_id, + display_name = excluded.display_name, + subscription_type = excluded.subscription_type, + updated_at = excluded.updated_at + `, + name, + strings.TrimSpace(token.ProviderType), + token.AccessToken, + token.RefreshToken, + token.ExpiresAt.UTC().Unix(), + joinScopes(token.Scopes), + strings.TrimSpace(token.AccountEmail), + strings.TrimSpace(token.AccountID), + strings.TrimSpace(token.DisplayName), + strings.TrimSpace(token.SubscriptionType), + createdAt.Unix(), + now.Unix(), + ) + if err != nil { + return fmt.Errorf("save oauth token: %w", err) + } + return nil +} + +func (s *SQLiteStore) Get(ctx context.Context, providerName string) (*Token, error) { + name := normalizeProviderName(providerName) + if name == "" { + return nil, fmt.Errorf("provider_name is required") + } + + row := s.db.QueryRowContext(ctx, ` + SELECT provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at + FROM oauth_tokens + WHERE provider_name = ? + `, name) + + token, err := scanSQLiteToken(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get oauth token: %w", err) + } + return token, nil +} + +func (s *SQLiteStore) Delete(ctx context.Context, providerName string) error { + name := normalizeProviderName(providerName) + if name == "" { + return fmt.Errorf("provider_name is required") + } + _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_tokens WHERE provider_name = ?`, name) + if err != nil { + return fmt.Errorf("delete oauth token: %w", err) + } + return nil +} + +//nolint:dupl // SQLite and PostgreSQL List methods are structurally identical but use incompatible driver interfaces +func (s *SQLiteStore) List(ctx context.Context) ([]*Token, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT provider_name, provider_type, access_token, refresh_token, expires_at, + scopes, account_email, account_id, display_name, subscription_type, + created_at, updated_at + FROM oauth_tokens + ORDER BY provider_name ASC + `) + if err != nil { + return nil, fmt.Errorf("list oauth tokens: %w", err) + } + defer rows.Close() + + result := make([]*Token, 0) + for rows.Next() { + token, err := scanSQLiteToken(rows) + if err != nil { + return nil, fmt.Errorf("scan oauth token: %w", err) + } + result = append(result, token) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate oauth tokens: %w", err) + } + return result, nil +} + +func (s *SQLiteStore) Close() error { + return nil +} + +type sqliteScanner interface { + Scan(dest ...any) error +} + +func scanSQLiteToken(scanner sqliteScanner) (*Token, error) { + return scanTokenRow(scanner) +} diff --git a/internal/oauthstore/store_test.go b/internal/oauthstore/store_test.go new file mode 100644 index 00000000..193b8228 --- /dev/null +++ b/internal/oauthstore/store_test.go @@ -0,0 +1,168 @@ +package oauthstore_test + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" + + "gomodel/internal/oauthstore" +) + +func newTestSQLiteStore(t *testing.T) oauthstore.Store { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + store, err := oauthstore.NewSQLiteStore(db) + require.NoError(t, err) + return store +} + +func sampleToken(providerName string) *oauthstore.Token { + return &oauthstore.Token{ + ProviderName: providerName, + ProviderType: "anthropic", + AccessToken: "access-token-abc", + RefreshToken: "refresh-token-xyz", + ExpiresAt: time.Now().Add(time.Hour).UTC().Truncate(time.Second), + Scopes: []string{"org:create_api_key", "user:profile", "user:inference"}, + AccountEmail: "user@example.com", + AccountID: "acc-123", + DisplayName: "Test User", + SubscriptionType: "pro", + } +} + +func TestSQLiteStore_SaveAndGet(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + token := sampleToken("anthropic_oauth") + require.NoError(t, store.Save(ctx, token)) + + got, err := store.Get(ctx, "anthropic_oauth") + require.NoError(t, err) + + assert.Equal(t, token.ProviderName, got.ProviderName) + assert.Equal(t, token.ProviderType, got.ProviderType) + assert.Equal(t, token.AccessToken, got.AccessToken) + assert.Equal(t, token.RefreshToken, got.RefreshToken) + assert.Equal(t, token.ExpiresAt.Unix(), got.ExpiresAt.Unix()) + assert.Equal(t, token.Scopes, got.Scopes) + assert.Equal(t, token.AccountEmail, got.AccountEmail) + assert.Equal(t, token.AccountID, got.AccountID) + assert.Equal(t, token.DisplayName, got.DisplayName) + assert.Equal(t, token.SubscriptionType, got.SubscriptionType) + assert.False(t, got.CreatedAt.IsZero()) + assert.False(t, got.UpdatedAt.IsZero()) +} + +func TestSQLiteStore_Save_UpdatesExisting(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + token := sampleToken("anthropic_oauth") + require.NoError(t, store.Save(ctx, token)) + + first, err := store.Get(ctx, "anthropic_oauth") + require.NoError(t, err) + originalCreatedAt := first.CreatedAt + + // Update the token + token.AccessToken = "new-access-token" + token.AccountEmail = "other@example.com" + require.NoError(t, store.Save(ctx, token)) + + updated, err := store.Get(ctx, "anthropic_oauth") + require.NoError(t, err) + + assert.Equal(t, "new-access-token", updated.AccessToken) + assert.Equal(t, "other@example.com", updated.AccountEmail) + // created_at must be preserved + assert.Equal(t, originalCreatedAt.Unix(), updated.CreatedAt.Unix()) +} + +func TestSQLiteStore_Get_NotFound(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + _, err := store.Get(ctx, "nonexistent") + assert.ErrorIs(t, err, oauthstore.ErrNotFound) +} + +func TestSQLiteStore_Delete(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + token := sampleToken("anthropic_oauth") + require.NoError(t, store.Save(ctx, token)) + + require.NoError(t, store.Delete(ctx, "anthropic_oauth")) + + _, err := store.Get(ctx, "anthropic_oauth") + assert.ErrorIs(t, err, oauthstore.ErrNotFound) +} + +func TestSQLiteStore_Delete_NonExistent(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + // Deleting a non-existent token should not error + assert.NoError(t, store.Delete(ctx, "nonexistent")) +} + +func TestSQLiteStore_List(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + require.NoError(t, store.Save(ctx, sampleToken("provider_b"))) + require.NoError(t, store.Save(ctx, sampleToken("provider_a"))) + require.NoError(t, store.Save(ctx, sampleToken("provider_c"))) + + tokens, err := store.List(ctx) + require.NoError(t, err) + require.Len(t, tokens, 3) + + // Should be ordered by provider_name ASC + assert.Equal(t, "provider_a", tokens[0].ProviderName) + assert.Equal(t, "provider_b", tokens[1].ProviderName) + assert.Equal(t, "provider_c", tokens[2].ProviderName) +} + +func TestSQLiteStore_List_Empty(t *testing.T) { + store := newTestSQLiteStore(t) + ctx := context.Background() + + tokens, err := store.List(ctx) + require.NoError(t, err) + assert.Empty(t, tokens) +} + +func TestToken_IsExpired(t *testing.T) { + t.Run("not expired", func(t *testing.T) { + token := &oauthstore.Token{ExpiresAt: time.Now().Add(time.Hour)} + assert.False(t, token.IsExpired()) + }) + + t.Run("expired", func(t *testing.T) { + token := &oauthstore.Token{ExpiresAt: time.Now().Add(-time.Minute)} + assert.True(t, token.IsExpired()) + }) + + t.Run("within safety margin", func(t *testing.T) { + // Expires in 3 minutes — within the 5-minute safety margin + token := &oauthstore.Token{ExpiresAt: time.Now().Add(3 * time.Minute)} + assert.True(t, token.IsExpired()) + }) + + t.Run("nil token", func(t *testing.T) { + var token *oauthstore.Token + assert.True(t, token.IsExpired()) + }) +} diff --git a/internal/oauthusage/usage.go b/internal/oauthusage/usage.go new file mode 100644 index 00000000..1078452f --- /dev/null +++ b/internal/oauthusage/usage.go @@ -0,0 +1,332 @@ +// Package oauthusage fetches and caches OAuth usage data from the Anthropic API. +// Usage data includes rate-limit windows (5-hour, 7-day) and extra credit usage. +package oauthusage + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +const ( + anthropicUsageURL = "https://api.anthropic.com/api/oauth/usage" + cacheStaleAfter = 5 * time.Minute + unsupportedRetryIn = 30 * time.Minute +) + +// UsageWindow describes utilization within a rolling time window. +type UsageWindow struct { + // Utilization is a value between 0 and 1 (e.g. 0.45 = 45%). + Utilization float64 `json:"utilization"` + ResetsAt time.Time `json:"resets_at"` +} + +// UtilizationPercent returns the utilization as an integer percentage (0–100). +func (w *UsageWindow) UtilizationPercent() int { + if w == nil { + return 0 + } + pct := w.Utilization * 100 + if pct < 0 { + return 0 + } + if pct > 100 { + return 100 + } + return int(pct) +} + +// ExtraUsage describes pay-as-you-go credit usage beyond the subscription. +type ExtraUsage struct { + IsEnabled bool `json:"is_enabled"` + MonthlyLimit float64 `json:"monthly_limit,omitempty"` + UsedCredits float64 `json:"used_credits,omitempty"` + Utilization float64 `json:"utilization,omitempty"` + DisabledReason string `json:"disabled_reason,omitempty"` +} + +// Usage holds the full usage snapshot for an OAuth-authenticated account. +type Usage struct { + FiveHour *UsageWindow `json:"five_hour,omitempty"` + SevenDay *UsageWindow `json:"seven_day,omitempty"` + SevenDayOAuthApps *UsageWindow `json:"seven_day_oauth_apps,omitempty"` + SevenDayOpus *UsageWindow `json:"seven_day_opus,omitempty"` + SevenDaySonnet *UsageWindow `json:"seven_day_sonnet,omitempty"` + ExtraUsage *ExtraUsage `json:"extra_usage,omitempty"` + FetchedAt time.Time `json:"fetched_at"` +} + +// Fetcher retrieves OAuth usage data for an access token. +type Fetcher interface { + FetchUsage(ctx context.Context, accessToken string) (*Usage, error) +} + +// cacheEntry holds a cached usage result for one provider. +type cacheEntry struct { + usage *Usage + fetchedAt time.Time + unsupported bool // true when the API returned "not supported" for this token +} + +func (e *cacheEntry) isStale() bool { + if e == nil || e.fetchedAt.IsZero() { + return true + } + return time.Since(e.fetchedAt) > cacheStaleAfter +} + +func (e *cacheEntry) unsupportedStillFresh() bool { + if e == nil || !e.unsupported { + return false + } + return time.Since(e.fetchedAt) < unsupportedRetryIn +} + +// CachingFetcher wraps an HTTP fetcher with an in-memory cache keyed by +// provider name. Cache entries expire after 5 minutes. +type CachingFetcher struct { + mu sync.Mutex + cache map[string]*cacheEntry + fetcher Fetcher +} + +// NewCachingFetcher creates a CachingFetcher backed by the given Fetcher. +func NewCachingFetcher(fetcher Fetcher) *CachingFetcher { + return &CachingFetcher{ + cache: make(map[string]*cacheEntry), + fetcher: fetcher, + } +} + +// FetchUsage returns cached usage if fresh, otherwise fetches from the API. +func (c *CachingFetcher) FetchUsage(ctx context.Context, providerName, accessToken string) (*Usage, error) { + c.mu.Lock() + entry := c.cache[providerName] + if entry != nil && !entry.isStale() { + usage := entry.usage + c.mu.Unlock() + return usage, nil + } + if entry != nil && entry.unsupportedStillFresh() { + c.mu.Unlock() + return nil, nil // unsupported, skip silently + } + c.mu.Unlock() + + usage, err := c.fetcher.FetchUsage(ctx, accessToken) + c.mu.Lock() + defer c.mu.Unlock() + + if err != nil { + // Mark as unsupported if the API explicitly says so + if isUnsupportedError(err) { + c.cache[providerName] = &cacheEntry{ + fetchedAt: time.Now(), + unsupported: true, + } + return nil, nil + } + return nil, err + } + + c.cache[providerName] = &cacheEntry{ + usage: usage, + fetchedAt: time.Now(), + } + return usage, nil +} + +// Invalidate removes the cached entry for the given provider. +func (c *CachingFetcher) Invalidate(providerName string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.cache, providerName) +} + +// HTTPFetcher fetches usage data from the Anthropic OAuth usage API. +type HTTPFetcher struct { + client *http.Client + usageURL string // overridable for tests +} + +// NewHTTPFetcher creates an HTTPFetcher using the default HTTP client. +func NewHTTPFetcher() *HTTPFetcher { + return &HTTPFetcher{ + client: &http.Client{Timeout: 15 * time.Second}, + usageURL: anthropicUsageURL, + } +} + +// NewHTTPFetcherWithURL creates an HTTPFetcher with a custom usage URL (for tests). +func NewHTTPFetcherWithURL(usageURL string) *HTTPFetcher { + return &HTTPFetcher{ + client: &http.Client{Timeout: 15 * time.Second}, + usageURL: usageURL, + } +} + +// FetchUsage calls the Anthropic OAuth usage API and returns normalized data. +func (f *HTTPFetcher) FetchUsage(ctx context.Context, accessToken string) (*Usage, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, f.usageURL, nil) + if err != nil { + return nil, fmt.Errorf("build usage request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + + resp, err := f.client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch oauth usage: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + + if resp.StatusCode == http.StatusUnauthorized { + // Check for explicit "not supported" message + if isUnsupportedBody(body) { + return nil, &UnsupportedError{Message: string(body)} + } + return nil, fmt.Errorf("oauth usage: unauthorized (%d)", resp.StatusCode) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oauth usage: unexpected status %d: %s", resp.StatusCode, truncate(string(body), 200)) + } + + return parseUsageResponse(body) +} + +// UnsupportedError indicates the usage API does not support this OAuth token. +type UnsupportedError struct { + Message string +} + +func (e *UnsupportedError) Error() string { + return "oauth usage API not supported for this token: " + e.Message +} + +func isUnsupportedError(err error) bool { + _, ok := err.(*UnsupportedError) + return ok +} + +func isUnsupportedBody(body []byte) bool { + lower := string(body) + return len(lower) > 0 && (contains(lower, "not supported") || contains(lower, "oauth authentication is currently not supported")) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContains(s, substr)) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} + +// rawUsagePayload mirrors the Anthropic OAuth usage API response shape. +type rawUsagePayload struct { + FiveHour *rawWindow `json:"five_hour"` + SevenDay *rawWindow `json:"seven_day"` + SevenDayOAuthApps *rawWindow `json:"seven_day_oauth_apps"` + SevenDayOpus *rawWindow `json:"seven_day_opus"` + SevenDaySonnet *rawWindow `json:"seven_day_sonnet"` + ExtraUsage *rawExtraUsage `json:"extra_usage"` +} + +type rawWindow struct { + Utilization float64 `json:"utilization"` + ResetsAt string `json:"resets_at"` +} + +type rawExtraUsage struct { + MonthlyLimit *float64 `json:"monthly_limit"` + UsedCredits *float64 `json:"used_credits"` + Utilization *float64 `json:"utilization"` + DisabledReason string `json:"disabled_reason"` +} + +func parseUsageResponse(body []byte) (*Usage, error) { + var raw rawUsagePayload + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("parse oauth usage response: %w", err) + } + + usage := &Usage{ + FetchedAt: time.Now().UTC(), + FiveHour: normalizeWindow(raw.FiveHour), + SevenDay: normalizeWindow(raw.SevenDay), + SevenDayOAuthApps: normalizeWindow(raw.SevenDayOAuthApps), + SevenDayOpus: normalizeWindow(raw.SevenDayOpus), + SevenDaySonnet: normalizeWindow(raw.SevenDaySonnet), + ExtraUsage: normalizeExtraUsage(raw.ExtraUsage), + } + return usage, nil +} + +func normalizeWindow(raw *rawWindow) *UsageWindow { + if raw == nil { + return nil + } + w := &UsageWindow{ + Utilization: clampUtilization(raw.Utilization), + } + if raw.ResetsAt != "" { + if t, err := time.Parse(time.RFC3339, raw.ResetsAt); err == nil { + w.ResetsAt = t.UTC() + } + } + return w +} + +func normalizeExtraUsage(raw *rawExtraUsage) *ExtraUsage { + if raw == nil { + return nil + } + e := &ExtraUsage{ + IsEnabled: raw.DisabledReason == "", + DisabledReason: raw.DisabledReason, + } + if raw.MonthlyLimit != nil { + e.MonthlyLimit = *raw.MonthlyLimit + } + if raw.UsedCredits != nil { + e.UsedCredits = *raw.UsedCredits + } + if raw.Utilization != nil { + e.Utilization = clampUtilization(*raw.Utilization) + } + return e +} + +func clampUtilization(v float64) float64 { + if v < 0 { + return 0 + } + // The Anthropic API may return utilization as a fraction (0–1) or as a + // percentage (0–100). Values above 1 are treated as already-percentage. + // Normalize to 0–1 range to match UsageWindow.Utilization semantics. + if v > 1 { + v = v / 100 + } + if v > 1 { + return 1 + } + return v +} diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 949cc0a0..26048a66 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -15,6 +15,7 @@ import ( "gomodel/internal/core" "gomodel/internal/llmclient" + "gomodel/internal/oauth" "gomodel/internal/providers" "gomodel/internal/streaming" ) @@ -45,6 +46,7 @@ var allowedAnthropicImageMediaTypes = map[string]struct{}{ type Provider struct { client *llmclient.Client apiKey string + oauth *oauthState // non-nil when api_key == "oauth" batchEndpointsMu sync.RWMutex // batchResultEndpoints keeps endpoint hints by provider batch id and custom_id. @@ -58,6 +60,15 @@ func New(providerCfg providers.ProviderConfig, opts providers.ProviderOptions) c apiKey: providerCfg.APIKey, batchResultEndpoints: make(map[string]map[string]string), } + + if isOAuthAPIKey(providerCfg.APIKey) && opts.OAuthStore != nil { + p.oauth = &oauthState{ + store: opts.OAuthStore, + providerName: providerCfg.Name, + oauthProv: oauth.NewAnthropicProvider(), + } + } + clientCfg := llmclient.Config{ ProviderName: "anthropic", BaseURL: providers.ResolveBaseURL(providerCfg.BaseURL, defaultBaseURL), @@ -156,7 +167,11 @@ func (p *Provider) getBatchResultEndpoints(batchID string) map[string]string { // setHeaders sets the required headers for Anthropic API requests func (p *Provider) setHeaders(req *http.Request) { - req.Header.Set("x-api-key", p.apiKey) + if p.oauth != nil { + p.setOAuthHeader(req) + } else { + req.Header.Set("x-api-key", p.apiKey) + } req.Header.Set("anthropic-version", anthropicAPIVersion) // Forward request ID if present in context diff --git a/internal/providers/anthropic/oauth.go b/internal/providers/anthropic/oauth.go new file mode 100644 index 00000000..a387532c --- /dev/null +++ b/internal/providers/anthropic/oauth.go @@ -0,0 +1,98 @@ +package anthropic + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "gomodel/internal/oauth" + "gomodel/internal/oauthstore" +) + + +// oauthState holds the runtime OAuth state for a provider instance. +type oauthState struct { + mu sync.Mutex + store oauthstore.Store + providerName string + oauthProv *oauth.AnthropicProvider +} + +// isOAuthAPIKey reports whether the given api_key value signals OAuth mode. +// The sentinel value is "oauth" (case-insensitive, trimmed). +func isOAuthAPIKey(apiKey string) bool { + return strings.EqualFold(strings.TrimSpace(apiKey), "oauth") +} + +// getValidAccessToken returns a valid access token, refreshing it if needed. +// It is safe for concurrent use. +func (s *oauthState) getValidAccessToken(ctx context.Context) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + + token, err := s.store.Get(ctx, s.providerName) + if err != nil { + if err == oauthstore.ErrNotFound { + return "", fmt.Errorf( + "provider %q requires OAuth authentication — visit the dashboard to authenticate", + s.providerName, + ) + } + return "", fmt.Errorf("oauth: failed to load token for %q: %w", s.providerName, err) + } + + if !token.IsExpired() { + return token.AccessToken, nil + } + + // Token expired — attempt refresh. + if token.RefreshToken == "" { + return "", fmt.Errorf( + "oauth: access token for %q has expired and no refresh token is available — re-authenticate via the dashboard", + s.providerName, + ) + } + + slog.Info("oauth: refreshing access token", "provider", s.providerName) + resp, err := s.oauthProv.RefreshToken(ctx, token.RefreshToken) + if err != nil { + return "", fmt.Errorf("oauth: token refresh failed for %q: %w", s.providerName, err) + } + + token.AccessToken = resp.AccessToken + if resp.RefreshToken != "" { + token.RefreshToken = resp.RefreshToken + } + if resp.ExpiresIn > 0 { + token.ExpiresAt = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + } + token.UpdatedAt = time.Now() + + if err := s.store.Save(ctx, token); err != nil { + // Log but don't fail — the token is still usable for this request. + slog.Warn("oauth: failed to persist refreshed token", "provider", s.providerName, "error", err) + } + + return token.AccessToken, nil +} + +// setOAuthHeader sets the Authorization header using the stored OAuth token. +// If the token is unavailable (e.g. revoked), it cancels the request context +// so the llmclient aborts the upstream call immediately. +func (p *Provider) setOAuthHeader(req *http.Request) { + token, err := p.oauth.getValidAccessToken(req.Context()) + if err != nil { + slog.Error("oauth: cannot obtain access token", "provider", p.oauth.providerName, "error", err) + // Store the error in the request context so callers can surface it, + // then cancel the context to abort the upstream HTTP call. + ctx, cancel := context.WithCancelCause(req.Context()) + cancel(err) + *req = *req.WithContext(ctx) + return + } + req.Header.Set("Authorization", "Bearer "+token) +} diff --git a/internal/providers/config.go b/internal/providers/config.go index 28cbcc15..2c49d358 100644 --- a/internal/providers/config.go +++ b/internal/providers/config.go @@ -14,6 +14,7 @@ import ( // ProviderConfig holds the fully resolved provider configuration after merging // global defaults with per-provider overrides. type ProviderConfig struct { + Name string // configured provider name (e.g. "anthropic_oauth") Type string APIKey string BaseURL string @@ -409,6 +410,7 @@ func isUnresolvedEnvPlaceholder(value string) bool { } // filterEmptyProviders removes providers without valid credentials. +// Providers with api_key: "oauth" are kept — their token is managed at runtime. func filterEmptyProviders(raw map[string]config.RawProviderConfig, discovery map[string]DiscoveryConfig) map[string]config.RawProviderConfig { result := make(map[string]config.RawProviderConfig, len(raw)) for name, p := range raw { @@ -420,6 +422,11 @@ func filterEmptyProviders(raw map[string]config.RawProviderConfig, discovery map result[name] = p continue } + // Allow OAuth sentinel value through — token is stored at runtime. + if strings.EqualFold(strings.TrimSpace(p.APIKey), "oauth") { + result[name] = p + continue + } if p.APIKey != "" && !strings.Contains(p.APIKey, "${") { result[name] = p } @@ -432,15 +439,16 @@ func filterEmptyProviders(raw map[string]config.RawProviderConfig, discovery map func buildProviderConfigs(raw map[string]config.RawProviderConfig, global config.ResilienceConfig) map[string]ProviderConfig { result := make(map[string]ProviderConfig, len(raw)) for name, r := range raw { - result[name] = buildProviderConfig(r, global) + result[name] = buildProviderConfig(name, r, global) } return result } // buildProviderConfig merges a single RawProviderConfig with the global ResilienceConfig. // Non-nil fields in the raw config override the global defaults. -func buildProviderConfig(raw config.RawProviderConfig, global config.ResilienceConfig) ProviderConfig { +func buildProviderConfig(name string, raw config.RawProviderConfig, global config.ResilienceConfig) ProviderConfig { resolved := ProviderConfig{ + Name: name, Type: raw.Type, APIKey: raw.APIKey, BaseURL: raw.BaseURL, diff --git a/internal/providers/config_test.go b/internal/providers/config_test.go index 37f2736f..b0d71e97 100644 --- a/internal/providers/config_test.go +++ b/internal/providers/config_test.go @@ -63,7 +63,7 @@ var testDiscoveryConfigs = map[string]DiscoveryConfig{ func TestBuildProviderConfig_InheritsGlobal(t *testing.T) { raw := config.RawProviderConfig{Type: "openai", APIKey: "sk-test"} - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.Type != "openai" { t.Errorf("Type = %q, want openai", got.Type) @@ -75,7 +75,7 @@ func TestBuildProviderConfig_InheritsGlobal(t *testing.T) { func TestBuildProviderConfig_NilResilience(t *testing.T) { raw := config.RawProviderConfig{Type: "openai", APIKey: "sk", Resilience: nil} - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.Resilience.Retry != globalRetry { t.Error("nil Resilience should inherit global") @@ -88,7 +88,7 @@ func TestBuildProviderConfig_NilRetry(t *testing.T) { APIKey: "sk", Resilience: &config.RawResilienceConfig{Retry: nil}, } - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.Resilience.Retry != globalRetry { t.Error("nil Retry should inherit global") @@ -105,7 +105,7 @@ func TestBuildProviderConfig_PartialOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.Resilience.Retry.MaxRetries != 10 { t.Errorf("MaxRetries = %d, want 10", got.Resilience.Retry.MaxRetries) @@ -132,7 +132,7 @@ func TestBuildProviderConfig_FullOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) r := got.Resilience.Retry if r.MaxRetries != 7 { @@ -162,7 +162,7 @@ func TestBuildProviderConfig_ZeroValueOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.Resilience.Retry.MaxRetries != 0 { t.Errorf("explicit 0 should override global (3), got %d", got.Resilience.Retry.MaxRetries) @@ -176,7 +176,7 @@ func TestBuildProviderConfig_PreservesFields(t *testing.T) { BaseURL: "https://custom.endpoint.com", Models: []config.RawProviderModel{{ID: "gpt-4"}, {ID: "gpt-3.5-turbo"}}, } - got := buildProviderConfig(raw, globalResilience) + got := buildProviderConfig("test", raw, globalResilience) if got.APIKey != "sk-key" { t.Errorf("APIKey = %q, want sk-key", got.APIKey) @@ -1018,7 +1018,7 @@ func TestBuildProviderConfig_CircuitBreaker_InheritsGlobal(t *testing.T) { Timeout: 30 * time.Second, } raw := config.RawProviderConfig{Type: "openai", APIKey: "sk"} - got := buildProviderConfig(raw, global) + got := buildProviderConfig("test", raw, global) if got.Resilience.CircuitBreaker != global.CircuitBreaker { t.Errorf("expected global circuit breaker to be inherited\ngot: %+v\nwant: %+v", @@ -1034,7 +1034,7 @@ func TestBuildProviderConfig_CircuitBreaker_NilOverride(t *testing.T) { APIKey: "sk", Resilience: &config.RawResilienceConfig{CircuitBreaker: nil}, } - got := buildProviderConfig(raw, global) + got := buildProviderConfig("test", raw, global) if got.Resilience.CircuitBreaker != global.CircuitBreaker { t.Error("nil CircuitBreaker override should inherit global") @@ -1055,7 +1055,7 @@ func TestBuildProviderConfig_CircuitBreaker_PartialOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, global) + got := buildProviderConfig("test", raw, global) if got.Resilience.CircuitBreaker.FailureThreshold != 10 { t.Errorf("FailureThreshold = %d, want 10", got.Resilience.CircuitBreaker.FailureThreshold) @@ -1087,7 +1087,7 @@ func TestBuildProviderConfig_CircuitBreaker_FullOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, global) + got := buildProviderConfig("test", raw, global) cb := got.Resilience.CircuitBreaker if cb.FailureThreshold != 3 { @@ -1115,7 +1115,7 @@ func TestBuildProviderConfig_CircuitBreaker_ZeroValueOverride(t *testing.T) { }, }, } - got := buildProviderConfig(raw, global) + got := buildProviderConfig("test", raw, global) if got.Resilience.CircuitBreaker.FailureThreshold != 0 { t.Errorf("explicit 0 should override global, got %d", got.Resilience.CircuitBreaker.FailureThreshold) diff --git a/internal/providers/factory.go b/internal/providers/factory.go index efbf09e0..215a64bd 100644 --- a/internal/providers/factory.go +++ b/internal/providers/factory.go @@ -9,6 +9,7 @@ import ( "gomodel/config" "gomodel/internal/core" "gomodel/internal/llmclient" + "gomodel/internal/oauthstore" ) // ProviderOptions bundles runtime settings passed from the factory to provider constructors. @@ -16,6 +17,9 @@ type ProviderOptions struct { Hooks llmclient.Hooks Models []string Resilience config.ResilienceConfig + // OAuthStore is used by providers configured with api_key: "oauth". + // May be nil when OAuth storage is not configured. + OAuthStore oauthstore.Store } // ProviderConstructor is the constructor signature for providers. @@ -45,6 +49,7 @@ type ProviderFactory struct { discoveryConfigs map[string]DiscoveryConfig passthroughEnrichers map[string]core.PassthroughSemanticEnricher hooks llmclient.Hooks + oauthStore oauthstore.Store } // NewProviderFactory creates a new provider factory instance. @@ -63,6 +68,13 @@ func (f *ProviderFactory) SetHooks(hooks llmclient.Hooks) { f.hooks = hooks } +// SetOAuthStore configures the OAuth token store used by providers with api_key: "oauth". +func (f *ProviderFactory) SetOAuthStore(store oauthstore.Store) { + f.mu.Lock() + defer f.mu.Unlock() + f.oauthStore = store +} + // Add adds a provider constructor to the factory. // Panics if reg.Type is empty or reg.New is nil — both are programming errors // caught at startup, not runtime conditions. @@ -89,6 +101,7 @@ func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) { f.mu.RLock() builder, ok := f.builders[cfg.Type] hooks := f.hooks + oauthStore := f.oauthStore f.mu.RUnlock() if !ok { @@ -99,6 +112,7 @@ func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) { Hooks: hooks, Models: cfg.Models, Resilience: cfg.Resilience, + OAuthStore: oauthStore, } return builder(cfg, opts), nil diff --git a/internal/providers/provider_status.go b/internal/providers/provider_status.go index 54e5ebe8..7359dbab 100644 --- a/internal/providers/provider_status.go +++ b/internal/providers/provider_status.go @@ -36,6 +36,8 @@ type SanitizedProviderConfig struct { APIVersion string `json:"api_version,omitempty"` Models []string `json:"models,omitempty"` Resilience SanitizedResilienceConfig `json:"resilience"` + // IsOAuth is true when the provider is configured with api_key: "oauth". + IsOAuth bool `json:"is_oauth,omitempty"` } // ProviderRuntimeSnapshot describes runtime diagnostics for a configured provider. @@ -95,6 +97,7 @@ func SanitizeProviderConfigs(configs map[string]ProviderConfig) []SanitizedProvi BaseURL: strings.TrimSpace(cfg.BaseURL), APIVersion: strings.TrimSpace(cfg.APIVersion), Models: models, + IsOAuth: strings.EqualFold(strings.TrimSpace(cfg.APIKey), "oauth"), Resilience: SanitizedResilienceConfig{ Retry: SanitizedRetryConfig{ MaxRetries: cfg.Resilience.Retry.MaxRetries, diff --git a/internal/server/http.go b/internal/server/http.go index 84cd7478..693a3c33 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -72,6 +72,7 @@ type Config struct { AdminEndpointsEnabled bool // Whether admin API endpoints are enabled AdminUIEnabled bool // Whether admin dashboard UI is enabled AdminHandler *admin.Handler // Admin API handler (nil if disabled) + OAuthHandler *admin.OAuthHandler // OAuth admin handler (nil if no OAuth providers configured) DashboardHandler *dashboard.Handler // Dashboard UI handler (nil if disabled) SwaggerEnabled bool // Whether to expose the Swagger UI at /swagger/index.html ResponseCacheMiddleware *responsecache.ResponseCacheMiddleware // Optional: response cache middleware for cacheable endpoints @@ -315,7 +316,9 @@ func New(provider core.RoutableProvider, cfg *Config) *Server { // Admin API routes (behind ADMIN_ENDPOINTS_ENABLED flag) if cfg != nil && cfg.AdminEndpointsEnabled && cfg.AdminHandler != nil { - cfg.AdminHandler.RegisterRoutes(e.Group("/admin/api/v1")) + adminGroup := e.Group("/admin/api/v1") + cfg.AdminHandler.RegisterRoutes(adminGroup) + admin.RegisterOAuthRoutes(adminGroup, cfg.OAuthHandler) } // Admin dashboard UI routes (behind ADMIN_UI_ENABLED flag)