Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 5 additions & 46 deletions internal/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"gomodel/internal/httpclient"
"gomodel/internal/llmclient"
"gomodel/internal/providers"
"gomodel/internal/providers/googleauth"
"gomodel/internal/providers/googlecommon"
)

// Registration provides factory registration for the Gemini provider.
Expand Down Expand Up @@ -190,7 +190,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht
if p.configErr != nil || p.authType == geminiAuthTypeAPIKey {
return base
}
creds, err := googleauth.FindCredentials(context.Background(), googleauth.Config{
creds, err := googlecommon.FindCredentials(context.Background(), googlecommon.Config{
AuthType: p.authType,
ServiceAccountFile: providerCfg.ServiceAccountFile,
ServiceAccountJSON: providerCfg.ServiceAccountJSON,
Expand All @@ -208,7 +208,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht
if strings.TrimSpace(quotaProject) == "" {
quotaProject = strings.TrimSpace(providerCfg.VertexProject)
}
return googleauth.HTTPClient(base, creds.TokenSource, quotaProject)
return googlecommon.HTTPClient(base, creds.TokenSource, quotaProject)
}

func (p *Provider) ready() error {
Expand Down Expand Up @@ -273,7 +273,7 @@ func normalizeGeminiAuthType(backend string, cfg providers.ProviderConfig) strin
switch authType {
case "":
if backend == geminiBackendVertex {
return googleauth.NormalizeAuthType(authType, googleauth.HasServiceAccount(googleauth.Config{
return googlecommon.NormalizeAuthType(authType, googlecommon.HasServiceAccount(googlecommon.Config{
ServiceAccountFile: cfg.ServiceAccountFile,
ServiceAccountJSON: cfg.ServiceAccountJSON,
ServiceAccountJSONBase64: cfg.ServiceAccountJSONBase64,
Expand Down Expand Up @@ -319,7 +319,7 @@ func useNativeAPIFromEnv() bool {

func geminiBaseURLs(providerCfg providers.ProviderConfig, backend string) (openAICompatibleBaseURL, nativeBaseURL string) {
if backend == geminiBackendVertex {
return vertexBaseURLs(providerCfg)
return googlecommon.VertexBaseURLs(providerCfg.BaseURL, providerCfg.VertexProject, providerCfg.VertexLocation)
}
configuredBaseURL := providerCfg.BaseURL
baseURL := strings.TrimRight(strings.TrimSpace(configuredBaseURL), "/")
Expand Down Expand Up @@ -349,47 +349,6 @@ func geminiModelsBaseURL(backend, nativeBaseURL string) string {
return nativeBaseURL
}

func vertexBaseURLs(providerCfg providers.ProviderConfig) (openAICompatibleBaseURL, nativeBaseURL string) {
baseURL := strings.TrimRight(strings.TrimSpace(providerCfg.BaseURL), "/")
if baseURL == "" {
project := strings.TrimSpace(providerCfg.VertexProject)
location := strings.TrimSpace(providerCfg.VertexLocation)
root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location)
return root + "/endpoints/openapi", root + "/publishers/google"
}
if nativeBaseURL, ok := vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok {
return baseURL, nativeBaseURL
}
if openAIBaseURL, ok := vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok {
return openAIBaseURL, baseURL
}
return baseURL, baseURL
}

func vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) {
const suffix = "/endpoints/openapi"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/publishers/google", true
}

func vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) {
const suffix = "/publishers/google"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/endpoints/openapi", true
}

func vertexPublisherModelsBaseURL(nativeBaseURL string) (string, bool) {
const projectsPath = "/v1/projects/"
nativeBaseURL = strings.TrimRight(strings.TrimSpace(nativeBaseURL), "/")
Expand Down
4 changes: 2 additions & 2 deletions internal/providers/gemini/gemini_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"gomodel/internal/core"
"gomodel/internal/llmclient"
"gomodel/internal/providers"
"gomodel/internal/providers/googleauth"
"gomodel/internal/providers/googlecommon"

"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -748,7 +748,7 @@ func TestVertexListModelsErrorsUseVertexProviderName(t *testing.T) {
}

func newVertexTestProvider(server *httptest.Server, native bool) *Provider {
tokenClient := googleauth.HTTPClient(server.Client(), oauth2.StaticTokenSource(&oauth2.Token{
tokenClient := googlecommon.HTTPClient(server.Client(), oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "vertex-token",
TokenType: "Bearer",
}), "")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
package googleauth
// Package googlecommon holds infrastructure shared by GoModel's Google-backed
// providers (Gemini AI Studio + Vertex AI). It currently covers authentication
// (ADC / service-account TokenSource resolution + quota project propagation
// via X-Goog-User-Project) and Vertex URL transformations between the native
// publisher endpoint and the OpenAI-compatible endpoint.
package googlecommon

import (
"context"
Expand Down Expand Up @@ -122,17 +127,6 @@ func resolveQuotaProject(creds *google.Credentials) string {
return strings.TrimSpace(raw.QuotaProjectID)
}

// TokenSource returns just the OAuth2 token source for the given config. It is
// a thin wrapper around FindCredentials kept for callers that do not need the
// resolved quota project.
func TokenSource(ctx context.Context, cfg Config) (oauth2.TokenSource, error) {
creds, err := FindCredentials(ctx, cfg)
if err != nil {
return nil, err
}
return creds.TokenSource, nil
}

func serviceAccountJSON(cfg Config) ([]byte, error) {
if value := strings.TrimSpace(cfg.ServiceAccountJSONBase64); value != "" {
decoded, err := decodeServiceAccountBase64(value)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package googleauth
package googlecommon

import (
"context"
Expand Down Expand Up @@ -55,7 +55,7 @@ func TestServiceAccountJSONReportsOriginalBase64DecodeError(t *testing.T) {
}
}

func TestTokenSourceAndHTTPClientAuthSelection(t *testing.T) {
func TestFindCredentialsAndHTTPClientAuthSelection(t *testing.T) {
tests := []struct {
name string
cfg func(t *testing.T, tokenURL string) Config
Expand Down Expand Up @@ -110,10 +110,11 @@ func TestTokenSourceAndHTTPClientAuthSelection(t *testing.T) {
}))
defer tokenServer.Close()

source, err := TokenSource(context.Background(), tt.cfg(t, tokenServer.URL))
creds, err := FindCredentials(context.Background(), tt.cfg(t, tokenServer.URL))
if err != nil {
t.Fatalf("TokenSource() error = %v", err)
t.Fatalf("FindCredentials() error = %v", err)
}
source := creds.TokenSource

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer "+tt.wantToken {
Expand Down
66 changes: 66 additions & 0 deletions internal/providers/googlecommon/urls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package googlecommon

import (
"net/url"
"strings"
)

// VertexBaseURLs derives Vertex AI's two base URLs (OpenAI-compatible endpoint
// + native publisher endpoint) from either an operator-supplied baseURL or the
// project/location pair. When baseURL is empty the canonical aiplatform.googleapis.com
// host is used. When baseURL is supplied:
// - if it ends in /endpoints/openapi it is treated as the OpenAI-compatible
// surface and the native URL is derived by replacing the suffix
// - if it ends in /publishers/google it is treated as the native surface and
// the OpenAI-compatible URL is derived likewise
// - otherwise both bases are set to the verbatim baseURL (caller has wired
// up a custom proxy and is responsible for routing both shapes through it)
//
// Returning two strings rather than mutating a *Config keeps this pure and
// trivially testable; callers unpack their own provider config.
func VertexBaseURLs(baseURL, project, location string) (openAICompatibleBaseURL, nativeBaseURL string) {
baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
if baseURL == "" {
project = strings.TrimSpace(project)
location = strings.TrimSpace(location)
root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location)
return root + "/endpoints/openapi", root + "/publishers/google"
}
if nativeBaseURL, ok := VertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok {
return baseURL, nativeBaseURL
}
if openAIBaseURL, ok := VertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok {
return openAIBaseURL, baseURL
}
return baseURL, baseURL
}

// VertexNativeBaseURLFromOpenAICompatibleBaseURL converts an
// /endpoints/openapi base URL into the matching /publishers/google native URL.
// Returns ok=false when the input does not end with /endpoints/openapi.
func VertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) {
const suffix = "/endpoints/openapi"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/publishers/google", true
}

// VertexOpenAICompatibleBaseURLFromNativeBaseURL converts a /publishers/google
// native base URL into the matching /endpoints/openapi OpenAI-compatible URL.
// Returns ok=false when the input does not end with /publishers/google.
func VertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) {
const suffix = "/publishers/google"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/endpoints/openapi", true
}
60 changes: 8 additions & 52 deletions internal/providers/vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"gomodel/internal/llmclient"
"gomodel/internal/providers"
"gomodel/internal/providers/gemini"
"gomodel/internal/providers/googleauth"
"gomodel/internal/providers/googlecommon"
)

// Registration provides factory registration for the Vertex AI provider.
Expand Down Expand Up @@ -88,7 +88,7 @@ func (p *Provider) validateConfig(providerCfg providers.ProviderConfig) {
case authTypeGCPADC:
return
case authTypeServiceAccount:
if googleauth.HasServiceAccount(buildGoogleAuthConfig(providerCfg)) {
if googlecommon.HasServiceAccount(buildGoogleAuthConfig(providerCfg)) {
return
}
p.configErr = fmt.Errorf("vertex AI service account auth requires service_account_file, service_account_json, or service_account_json_base64")
Expand All @@ -112,7 +112,7 @@ func hasResolvedProviderValue(value string) bool {
}

func normalizeAuthType(providerCfg providers.ProviderConfig) string {
return googleauth.NormalizeAuthType(providerCfg.AuthType, googleauth.HasServiceAccount(buildGoogleAuthConfig(providerCfg)))
return googlecommon.NormalizeAuthType(providerCfg.AuthType, googlecommon.HasServiceAccount(buildGoogleAuthConfig(providerCfg)))
}

func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *http.Client) *http.Client {
Expand All @@ -121,7 +121,7 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht
}
authCfg := buildGoogleAuthConfig(providerCfg)
authCfg.AuthType = p.authType
creds, err := googleauth.FindCredentials(context.Background(), authCfg)
creds, err := googlecommon.FindCredentials(context.Background(), authCfg)
if err != nil {
p.configErr = err
return base
Expand All @@ -133,11 +133,11 @@ func (p *Provider) authHTTPClient(providerCfg providers.ProviderConfig, base *ht
if strings.TrimSpace(quotaProject) == "" {
quotaProject = strings.TrimSpace(providerCfg.VertexProject)
}
return googleauth.HTTPClient(base, creds.TokenSource, quotaProject)
return googlecommon.HTTPClient(base, creds.TokenSource, quotaProject)
}

func buildGoogleAuthConfig(providerCfg providers.ProviderConfig) googleauth.Config {
return googleauth.Config{
func buildGoogleAuthConfig(providerCfg providers.ProviderConfig) googlecommon.Config {
return googlecommon.Config{
AuthType: providerCfg.AuthType,
ServiceAccountFile: providerCfg.ServiceAccountFile,
ServiceAccountJSON: providerCfg.ServiceAccountJSON,
Expand Down Expand Up @@ -344,54 +344,10 @@ func encodeEmbedding(values []float64, encodingFormat string) (json.RawMessage,
}

func vertexNativeBaseURL(providerCfg providers.ProviderConfig) string {
_, nativeBaseURL := vertexBaseURLs(providerCfg)
_, nativeBaseURL := googlecommon.VertexBaseURLs(providerCfg.BaseURL, providerCfg.VertexProject, providerCfg.VertexLocation)
return nativeBaseURL
}

// TODO: Share Vertex URL derivation with the Gemini Vertex path if this logic
// changes again. It is intentionally duplicated today to keep provider package
// boundaries simple.
func vertexBaseURLs(providerCfg providers.ProviderConfig) (openAICompatibleBaseURL, nativeBaseURL string) {
baseURL := strings.TrimRight(strings.TrimSpace(providerCfg.BaseURL), "/")
if baseURL == "" {
project := strings.TrimSpace(providerCfg.VertexProject)
location := strings.TrimSpace(providerCfg.VertexLocation)
root := "https://aiplatform.googleapis.com/v1/projects/" + url.PathEscape(project) + "/locations/" + url.PathEscape(location)
return root + "/endpoints/openapi", root + "/publishers/google"
}
if nativeBaseURL, ok := vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL); ok {
return baseURL, nativeBaseURL
}
if openAIBaseURL, ok := vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL); ok {
return openAIBaseURL, baseURL
}
return baseURL, baseURL
}

func vertexNativeBaseURLFromOpenAICompatibleBaseURL(baseURL string) (string, bool) {
const suffix = "/endpoints/openapi"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/publishers/google", true
}

func vertexOpenAICompatibleBaseURLFromNativeBaseURL(baseURL string) (string, bool) {
const suffix = "/publishers/google"
if !strings.HasSuffix(baseURL, suffix) {
return "", false
}
root := strings.TrimRight(strings.TrimSuffix(baseURL, suffix), "/")
if root == "" {
return "", false
}
return root + "/endpoints/openapi", true
}

func vertexPredictEndpoint(model string) string {
model = normalizeVertexModelID(model)
return "/models/" + url.PathEscape(model) + ":predict"
Expand Down
6 changes: 3 additions & 3 deletions internal/providers/vertex/vertex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

"gomodel/internal/core"
"gomodel/internal/providers"
"gomodel/internal/providers/googleauth"
"gomodel/internal/providers/googlecommon"

"golang.org/x/oauth2"
)
Expand Down Expand Up @@ -330,7 +330,7 @@ func TestVertexBaseURLs(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotCompat, gotNative := vertexBaseURLs(tt.cfg)
gotCompat, gotNative := googlecommon.VertexBaseURLs(tt.cfg.BaseURL, tt.cfg.VertexProject, tt.cfg.VertexLocation)
if gotCompat != tt.wantCompat {
t.Fatalf("OpenAI-compatible base = %q, want %q", gotCompat, tt.wantCompat)
}
Expand All @@ -351,7 +351,7 @@ func testConfig() providers.ProviderConfig {
}

func authedTestClient(base *http.Client) *http.Client {
return googleauth.HTTPClient(base, oauth2.StaticTokenSource(&oauth2.Token{
return googlecommon.HTTPClient(base, oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "vertex-token",
TokenType: "Bearer",
}), "")
Expand Down