diff --git a/internal/providers/gemini/gemini.go b/internal/providers/gemini/gemini.go index 828c76de..5cd661a9 100644 --- a/internal/providers/gemini/gemini.go +++ b/internal/providers/gemini/gemini.go @@ -18,7 +18,7 @@ import ( "gomodel/internal/httpclient" "gomodel/internal/llmclient" "gomodel/internal/providers" - "gomodel/internal/providers/googleauth" + "gomodel/internal/providers/googlecommon" ) // Registration provides factory registration for the Gemini provider. @@ -190,7 +190,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht if p.configErr != nil || p.authType == geminiAuthTypeAPIKey { return base } - creds, err := googleauth.FindCredentials(context.Background(), googleauth.Config{ + creds, err := googlecommon.FindCredentials(context.Background(), googlecommon.Config{ AuthType: p.authType, ServiceAccountFile: providerCfg.ServiceAccountFile, ServiceAccountJSON: providerCfg.ServiceAccountJSON, @@ -208,7 +208,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht if strings.TrimSpace(quotaProject) == "" { quotaProject = strings.TrimSpace(providerCfg.VertexProject) } - return googleauth.HTTPClient(base, creds.TokenSource, quotaProject) + return googlecommon.HTTPClient(base, creds.TokenSource, quotaProject) } func (p *Provider) ready() error { @@ -273,7 +273,7 @@ func normalizeGeminiAuthType(backend string, cfg providers.ProviderConfig) strin switch authType { case "": if backend == geminiBackendVertex { - return googleauth.NormalizeAuthType(authType, googleauth.HasServiceAccount(googleauth.Config{ + return googlecommon.NormalizeAuthType(authType, googlecommon.HasServiceAccount(googlecommon.Config{ ServiceAccountFile: cfg.ServiceAccountFile, ServiceAccountJSON: cfg.ServiceAccountJSON, ServiceAccountJSONBase64: cfg.ServiceAccountJSONBase64, @@ -319,7 +319,7 @@ func useNativeAPIFromEnv() bool { func geminiBaseURLs(providerCfg providers.ProviderConfig, backend string) (openAICompatibleBaseURL, nativeBaseURL string) { if backend == geminiBackendVertex { - return vertexBaseURLs(providerCfg) + return googlecommon.VertexBaseURLs(providerCfg.BaseURL, providerCfg.VertexProject, providerCfg.VertexLocation) } configuredBaseURL := providerCfg.BaseURL baseURL := strings.TrimRight(strings.TrimSpace(configuredBaseURL), "/") @@ -349,47 +349,6 @@ func geminiModelsBaseURL(backend, nativeBaseURL string) string { return nativeBaseURL } -func vertexBaseURLs(providerCfg providers.ProviderConfig) (openAICompatibleBaseURL, nativeBaseURL string) { - baseURL := strings.TrimRight(strings.TrimSpace(providerCfg.BaseURL), "/") - if baseURL == "" { - project := strings.TrimSpace(providerCfg.VertexProject) - location := strings.TrimSpace(providerCfg.VertexLocation) - root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location) - return root + "/endpoints/openapi", root + "/publishers/google" - } - if nativeBaseURL, ok := vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok { - return baseURL, nativeBaseURL - } - if openAIBaseURL, ok := vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok { - return openAIBaseURL, baseURL - } - return baseURL, baseURL -} - -func vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) { - const suffix = "/endpoints/openapi" - if !strings.HasSuffix(baseURL, suffix) { - return "", false - } - root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") - if root == "" { - return "", false - } - return root + "/publishers/google", true -} - -func vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) { - const suffix = "/publishers/google" - if !strings.HasSuffix(baseURL, suffix) { - return "", false - } - root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") - if root == "" { - return "", false - } - return root + "/endpoints/openapi", true -} - func vertexPublisherModelsBaseURL(nativeBaseURL string) (string, bool) { const projectsPath = "/v1/projects/" nativeBaseURL = strings.TrimRight(strings.TrimSpace(nativeBaseURL), "/") diff --git a/internal/providers/gemini/gemini_test.go b/internal/providers/gemini/gemini_test.go index 2564befe..2c6c02e0 100644 --- a/internal/providers/gemini/gemini_test.go +++ b/internal/providers/gemini/gemini_test.go @@ -13,7 +13,7 @@ import ( "gomodel/internal/core" "gomodel/internal/llmclient" "gomodel/internal/providers" - "gomodel/internal/providers/googleauth" + "gomodel/internal/providers/googlecommon" "golang.org/x/oauth2" ) @@ -748,7 +748,7 @@ func TestVertexListModelsErrorsUseVertexProviderName(t *testing.T) { } func newVertexTestProvider(server *httptest.Server, native bool) *Provider { - tokenClient := googleauth.HTTPClient(server.Client(), oauth2.StaticTokenSource(&oauth2.Token{ + tokenClient := googlecommon.HTTPClient(server.Client(), oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: "vertex-token", TokenType: "Bearer", }), "") diff --git a/internal/providers/googleauth/googleauth.go b/internal/providers/googlecommon/auth.go similarity index 94% rename from internal/providers/googleauth/googleauth.go rename to internal/providers/googlecommon/auth.go index 78f5b06c..d7800eea 100644 --- a/internal/providers/googleauth/googleauth.go +++ b/internal/providers/googlecommon/auth.go @@ -1,4 +1,9 @@ -package googleauth +// Package googlecommon holds infrastructure shared by GoModel's Google-backed +// providers (Gemini AI Studio + Vertex AI). It currently covers authentication +// (ADC / service-account TokenSource resolution + quota project propagation +// via X-Goog-User-Project) and Vertex URL transformations between the native +// publisher endpoint and the OpenAI-compatible endpoint. +package googlecommon import ( "context" @@ -122,17 +127,6 @@ func resolveQuotaProject(creds *google.Credentials) string { return strings.TrimSpace(raw.QuotaProjectID) } -// TokenSource returns just the OAuth2 token source for the given config. It is -// a thin wrapper around FindCredentials kept for callers that do not need the -// resolved quota project. -func TokenSource(ctx context.Context, cfg Config) (oauth2.TokenSource, error) { - creds, err := FindCredentials(ctx, cfg) - if err != nil { - return nil, err - } - return creds.TokenSource, nil -} - func serviceAccountJSON(cfg Config) ([]byte, error) { if value := strings.TrimSpace(cfg.ServiceAccountJSONBase64); value != "" { decoded, err := decodeServiceAccountBase64(value) diff --git a/internal/providers/googleauth/googleauth_test.go b/internal/providers/googlecommon/auth_test.go similarity index 97% rename from internal/providers/googleauth/googleauth_test.go rename to internal/providers/googlecommon/auth_test.go index d25f9c44..fa303660 100644 --- a/internal/providers/googleauth/googleauth_test.go +++ b/internal/providers/googlecommon/auth_test.go @@ -1,4 +1,4 @@ -package googleauth +package googlecommon import ( "context" @@ -55,7 +55,7 @@ func TestServiceAccountJSONReportsOriginalBase64DecodeError(t *testing.T) { } } -func TestTokenSourceAndHTTPClientAuthSelection(t *testing.T) { +func TestFindCredentialsAndHTTPClientAuthSelection(t *testing.T) { tests := []struct { name string cfg func(t *testing.T, tokenURL string) Config @@ -110,10 +110,11 @@ func TestTokenSourceAndHTTPClientAuthSelection(t *testing.T) { })) defer tokenServer.Close() - source, err := TokenSource(context.Background(), tt.cfg(t, tokenServer.URL)) + creds, err := FindCredentials(context.Background(), tt.cfg(t, tokenServer.URL)) if err != nil { - t.Fatalf("TokenSource() error = %v", err) + t.Fatalf("FindCredentials() error = %v", err) } + source := creds.TokenSource upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got != "Bearer "+tt.wantToken { diff --git a/internal/providers/googlecommon/urls.go b/internal/providers/googlecommon/urls.go new file mode 100644 index 00000000..a9986b6a --- /dev/null +++ b/internal/providers/googlecommon/urls.go @@ -0,0 +1,66 @@ +package googlecommon + +import ( + "net/url" + "strings" +) + +// VertexBaseURLs derives Vertex AI's two base URLs (OpenAI-compatible endpoint +// + native publisher endpoint) from either an operator-supplied baseURL or the +// project/location pair. When baseURL is empty the canonical aiplatform.googleapis.com +// host is used. When baseURL is supplied: +// - if it ends in /endpoints/openapi it is treated as the OpenAI-compatible +// surface and the native URL is derived by replacing the suffix +// - if it ends in /publishers/google it is treated as the native surface and +// the OpenAI-compatible URL is derived likewise +// - otherwise both bases are set to the verbatim baseURL (caller has wired +// up a custom proxy and is responsible for routing both shapes through it) +// +// Returning two strings rather than mutating a *Config keeps this pure and +// trivially testable; callers unpack their own provider config. +func VertexBaseURLs(baseURL, project, location string) (openAICompatibleBaseURL, nativeBaseURL string) { + baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/") + if baseURL == "" { + project = strings.TrimSpace(project) + location = strings.TrimSpace(location) + root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location) + return root + "/endpoints/openapi", root + "/publishers/google" + } + if nativeBaseURL, ok := VertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok { + return baseURL, nativeBaseURL + } + if openAIBaseURL, ok := VertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok { + return openAIBaseURL, baseURL + } + return baseURL, baseURL +} + +// VertexNativeBaseURLFromOpenAICompatibleBaseURL converts an +// /endpoints/openapi base URL into the matching /publishers/google native URL. +// Returns ok=false when the input does not end with /endpoints/openapi. +func VertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) { + const suffix = "/endpoints/openapi" + if !strings.HasSuffix(baseURL, suffix) { + return "", false + } + root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") + if root == "" { + return "", false + } + return root + "/publishers/google", true +} + +// VertexOpenAICompatibleBaseURLFromNativeBaseURL converts a /publishers/google +// native base URL into the matching /endpoints/openapi OpenAI-compatible URL. +// Returns ok=false when the input does not end with /publishers/google. +func VertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) { + const suffix = "/publishers/google" + if !strings.HasSuffix(baseURL, suffix) { + return "", false + } + root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") + if root == "" { + return "", false + } + return root + "/endpoints/openapi", true +} diff --git a/internal/providers/vertex/vertex.go b/internal/providers/vertex/vertex.go index 7ffb6144..963fc13b 100644 --- a/internal/providers/vertex/vertex.go +++ b/internal/providers/vertex/vertex.go @@ -18,7 +18,7 @@ import ( "gomodel/internal/llmclient" "gomodel/internal/providers" "gomodel/internal/providers/gemini" - "gomodel/internal/providers/googleauth" + "gomodel/internal/providers/googlecommon" ) // Registration provides factory registration for the Vertex AI provider. @@ -88,7 +88,7 @@ func (p *Provider) validateConfig(providerCfg providers.ProviderConfig) { case authTypeGCPADC: return case authTypeServiceAccount: - if googleauth.HasServiceAccount(buildGoogleAuthConfig(providerCfg)) { + if googlecommon.HasServiceAccount(buildGoogleAuthConfig(providerCfg)) { return } p.configErr = fmt.Errorf("vertex AI service account auth requires service_account_file, service_account_json, or service_account_json_base64") @@ -112,7 +112,7 @@ func hasResolvedProviderValue(value string) bool { } func normalizeAuthType(providerCfg providers.ProviderConfig) string { - return googleauth.NormalizeAuthType(providerCfg.AuthType, googleauth.HasServiceAccount(buildGoogleAuthConfig(providerCfg))) + return googlecommon.NormalizeAuthType(providerCfg.AuthType, googlecommon.HasServiceAccount(buildGoogleAuthConfig(providerCfg))) } func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *http.Client) *http.Client { @@ -121,7 +121,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht } authCfg := buildGoogleAuthConfig(providerCfg) authCfg.AuthType = p.authType - creds, err := googleauth.FindCredentials(context.Background(), authCfg) + creds, err := googlecommon.FindCredentials(context.Background(), authCfg) if err != nil { p.configErr = err return base @@ -133,11 +133,11 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht if strings.TrimSpace(quotaProject) == "" { quotaProject = strings.TrimSpace(providerCfg.VertexProject) } - return googleauth.HTTPClient(base, creds.TokenSource, quotaProject) + return googlecommon.HTTPClient(base, creds.TokenSource, quotaProject) } -func buildGoogleAuthConfig(providerCfg providers.ProviderConfig) googleauth.Config { - return googleauth.Config{ +func buildGoogleAuthConfig(providerCfg providers.ProviderConfig) googlecommon.Config { + return googlecommon.Config{ AuthType: providerCfg.AuthType, ServiceAccountFile: providerCfg.ServiceAccountFile, ServiceAccountJSON: providerCfg.ServiceAccountJSON, @@ -344,54 +344,10 @@ func encodeEmbedding(values []float64, encodingFormat string) (json.RawMessage, } func vertexNativeBaseURL(providerCfg providers.ProviderConfig) string { - _, nativeBaseURL := vertexBaseURLs(providerCfg) + _, nativeBaseURL := googlecommon.VertexBaseURLs(providerCfg.BaseURL, providerCfg.VertexProject, providerCfg.VertexLocation) return nativeBaseURL } -// TODO: Share Vertex URL derivation with the Gemini Vertex path if this logic -// changes again. It is intentionally duplicated today to keep provider package -// boundaries simple. -func vertexBaseURLs(providerCfg providers.ProviderConfig) (openAICompatibleBaseURL, nativeBaseURL string) { - baseURL := strings.TrimRight(strings.TrimSpace(providerCfg.BaseURL), "/") - if baseURL == "" { - project := strings.TrimSpace(providerCfg.VertexProject) - location := strings.TrimSpace(providerCfg.VertexLocation) - root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location) - return root + "/endpoints/openapi", root + "/publishers/google" - } - if nativeBaseURL, ok := vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok { - return baseURL, nativeBaseURL - } - if openAIBaseURL, ok := vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok { - return openAIBaseURL, baseURL - } - return baseURL, baseURL -} - -func vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) { - const suffix = "/endpoints/openapi" - if !strings.HasSuffix(baseURL, suffix) { - return "", false - } - root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") - if root == "" { - return "", false - } - return root + "/publishers/google", true -} - -func vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) { - const suffix = "/publishers/google" - if !strings.HasSuffix(baseURL, suffix) { - return "", false - } - root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/") - if root == "" { - return "", false - } - return root + "/endpoints/openapi", true -} - func vertexPredictEndpoint(model string) string { model = normalizeVertexModelID(model) return "/models/" + url.PathEscape(model) + ":predict" diff --git a/internal/providers/vertex/vertex_test.go b/internal/providers/vertex/vertex_test.go index 552bc186..2f94a853 100644 --- a/internal/providers/vertex/vertex_test.go +++ b/internal/providers/vertex/vertex_test.go @@ -19,7 +19,7 @@ import ( "gomodel/internal/core" "gomodel/internal/providers" - "gomodel/internal/providers/googleauth" + "gomodel/internal/providers/googlecommon" "golang.org/x/oauth2" ) @@ -330,7 +330,7 @@ func TestVertexBaseURLs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotCompat, gotNative := vertexBaseURLs(tt.cfg) + gotCompat, gotNative := googlecommon.VertexBaseURLs(tt.cfg.BaseURL, tt.cfg.VertexProject, tt.cfg.VertexLocation) if gotCompat != tt.wantCompat { t.Fatalf("OpenAI-compatible base = %q, want %q", gotCompat, tt.wantCompat) } @@ -351,7 +351,7 @@ func testConfig() providers.ProviderConfig { } func authedTestClient(base *http.Client) *http.Client { - return googleauth.HTTPClient(base, oauth2.StaticTokenSource(&oauth2.Token{ + return googlecommon.HTTPClient(base, oauth2.StaticTokenSource(&oauth2.Token{ AccessToken: "vertex-token", TokenType: "Bearer", }), "")