diff --git a/.env.template b/.env.template index 9f701f91..1fb0864f 100644 --- a/.env.template +++ b/.env.template @@ -114,6 +114,8 @@ # fallback (default): use configured models only when upstream /models fails, is nil, or is empty. # allowlist: expose only the configured models for providers that define a list, and skip their upstream /models calls. # CONFIGURED_PROVIDER_MODELS_MODE=fallback +# Model ID format for GET /v1/models: "qualified" (provider/model), "unqualified" (model only), or "both" +# MODELS_ENDPOINT_ID_FORMAT=qualified # Examples: OPENROUTER_MODELS=..., OPENROUTER_EU_MODELS=..., AZURE_MODELS=..., VLLM_MODELS=... # Fallback & Workflow Configuration diff --git a/config/config.example.yaml b/config/config.example.yaml index 24ea3bed..b85cf301 100644 --- a/config/config.example.yaml +++ b/config/config.example.yaml @@ -19,6 +19,7 @@ models: enabled_by_default: true # env: MODELS_ENABLED_BY_DEFAULT; when false, models stay unavailable until an override allows one or more user paths overrides_enabled: true # env: MODEL_OVERRIDES_ENABLED; load/enforce persisted model overrides and enable dashboard editing configured_provider_models_mode: "fallback" # env: CONFIGURED_PROVIDER_MODELS_MODE; "fallback" uses configured lists only when upstream /models is unavailable/empty, "allowlist" exposes only configured models and skips upstream /models for configured lists + models_endpoint_id_format: "qualified" # env: MODELS_ENDPOINT_ID_FORMAT; "qualified" (provider/model), "unqualified" (model only), or "both" cache: model: diff --git a/config/config.go b/config/config.go index 9c6412c1..f3739587 100644 --- a/config/config.go +++ b/config/config.go @@ -62,6 +62,7 @@ func buildDefaultConfig() *Config { OverridesEnabled: true, KeepOnlyAliasesAtModelsEndpoint: false, ConfiguredProviderModelsMode: ConfiguredProviderModelsModeFallback, + ModelsEndpointIDFormat: ModelsEndpointIDFormatQualified, }, Cache: CacheConfig{ Model: ModelCacheConfig{ diff --git a/config/models.go 
b/config/models.go
index aa8b2c29..bf8b0296 100644
--- a/config/models.go
+++ b/config/models.go
@@ -23,6 +23,11 @@ type ModelsConfig struct {
 	// provider *_MODELS env vars affect the provider model inventory.
 	// Supported values: "fallback", "allowlist". Default: "fallback".
 	ConfiguredProviderModelsMode ConfiguredProviderModelsMode `yaml:"configured_provider_models_mode" env:"CONFIGURED_PROVIDER_MODELS_MODE"`
+
+	// ModelsEndpointIDFormat controls the model ID format returned by GET /v1/models.
+	// Supported values: "qualified" (provider/model), "unqualified" (model), "both".
+	// Default: "qualified".
+	ModelsEndpointIDFormat ModelsEndpointIDFormat `yaml:"models_endpoint_id_format" env:"MODELS_ENDPOINT_ID_FORMAT"`
 }
 
 // ConfiguredProviderModelsMode controls how explicitly configured provider
@@ -57,3 +62,38 @@ func ResolveConfiguredProviderModelsMode(mode ConfiguredProviderModelsMode) Conf
 	}
 	return mode
 }
+
+// ModelsEndpointIDFormat selects how model IDs are rendered in the GET /v1/models response.
+type ModelsEndpointIDFormat string
+
+// Supported ModelsEndpointIDFormat values.
+const (
+	ModelsEndpointIDFormatQualified   ModelsEndpointIDFormat = "qualified"
+	ModelsEndpointIDFormatUnqualified ModelsEndpointIDFormat = "unqualified"
+	ModelsEndpointIDFormatBoth        ModelsEndpointIDFormat = "both"
+)
+
+// Valid reports whether f, once normalized, names a supported format.
+// An empty value is not valid.
+func (f ModelsEndpointIDFormat) Valid() bool {
+	canonical := NormalizeModelsEndpointIDFormat(f)
+	return canonical == ModelsEndpointIDFormatQualified ||
+		canonical == ModelsEndpointIDFormatUnqualified ||
+		canonical == ModelsEndpointIDFormatBoth
+}
+
+// NormalizeModelsEndpointIDFormat strips surrounding whitespace from f and lower-cases it.
+func NormalizeModelsEndpointIDFormat(f ModelsEndpointIDFormat) ModelsEndpointIDFormat {
+	trimmed := strings.TrimSpace(string(f))
+	return ModelsEndpointIDFormat(strings.ToLower(trimmed))
+}
+
+// ResolveModelsEndpointIDFormat normalizes f and substitutes
+// ModelsEndpointIDFormatQualified when the normalized value is empty.
+func ResolveModelsEndpointIDFormat(f ModelsEndpointIDFormat) ModelsEndpointIDFormat {
+	f = NormalizeModelsEndpointIDFormat(f)
+	if f == "" {
+		return ModelsEndpointIDFormatQualified
+	}
+	return f
+}
diff --git a/internal/app/app.go b/internal/app/app.go
index 2274be25..067b14c9 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -107,6 +107,7 @@ func New(ctx context.Context, cfg Config) (*App, error) {
 		return nil, fmt.Errorf("failed to initialize providers: %w", err)
 	}
 	app.providers = providerResult
+	app.providers.Router.SetModelsEndpointIDFormat(appCfg.Models.ModelsEndpointIDFormat)
 
 	// Initialize audit logging
 	auditResult, err := auditlog.New(ctx, appCfg)
diff --git a/internal/providers/registry.go b/internal/providers/registry.go
index d9cb28e1..5fe6b34d 100644
--- a/internal/providers/registry.go
+++ b/internal/providers/registry.go
@@ -344,7 +344,82 @@ func (r *ModelRegistry) ListPublicModels() []core.Model {
 	return result
 }
 
+// ListModelsWithFormat returns the registry's models with IDs rendered in the
+// requested format. The format is resolved first, so an empty value selects
+// the qualified default, and any value other than "unqualified" or "both"
+// yields the same list as ListPublicModels.
+func (r *ModelRegistry) ListModelsWithFormat(format config.ModelsEndpointIDFormat) []core.Model {
+	switch config.ResolveModelsEndpointIDFormat(format) {
+	case config.ModelsEndpointIDFormatUnqualified:
+		return r.listModelsUnqualified()
+	case config.ModelsEndpointIDFormatBoth:
+		return r.listModelsBoth()
+	default:
+		return r.ListPublicModels()
+	}
+}
+
+// listModelsUnqualified returns one entry per distinct bare model ID, sorted by
+// ID, with OwnedBy set to the name of the provider the entry was taken from.
+//
+// Providers are visited in sorted name order so that, when several providers
+// expose the same model ID, the surviving entry always comes from the
+// alphabetically first provider instead of depending on map iteration order.
+//
+// NOTE(review): this walks modelsByProvider directly; confirm that any
+// filtering ListPublicModels applies is not also needed here.
+func (r *ModelRegistry) listModelsUnqualified() []core.Model {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	providerNames := make([]string, 0, len(r.modelsByProvider))
+	for name := range r.modelsByProvider {
+		providerNames = append(providerNames, name)
+	}
+	sort.Strings(providerNames)
+
+	seen := make(map[string]struct{}, len(r.models))
+	result := make([]core.Model, 0, len(r.models))
+	for _, providerName := range providerNames {
+		for modelID, info := range r.modelsByProvider[providerName] {
+			if _, exists := seen[modelID]; exists {
+				continue
+			}
+			seen[modelID] = struct{}{}
+			model := info.Model
+			model.ID = modelID
+			model.OwnedBy = providerName
+			result = append(result, model)
+		}
+	}
+	sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
+	return result
+}
+
+// listModelsBoth returns the list from ListPublicModels (the qualified IDs)
+// plus every unqualified entry whose ID does not collide with one already
+// present, with the combined result sorted by ID.
+func (r *ModelRegistry) listModelsBoth() []core.Model {
+	qualified := r.ListPublicModels()
+	unqualified := r.listModelsUnqualified()
+
+	seen := make(map[string]struct{}, len(qualified))
+	result := make([]core.Model, 0, len(qualified)+len(unqualified))
+	for _, m := range qualified {
+		seen[m.ID] = struct{}{}
+		result = append(result, m)
+	}
+	for _, m := range unqualified {
+		if _, exists := seen[m.ID]; exists {
+			continue
+		}
+		result = append(result, m)
+	}
+	sort.Slice(result, func(i, j int) bool { return result[i].ID < result[j].ID })
+	return result
+}
+
 // ModelCount returns the number of registered models
 func (r *ModelRegistry) ModelCount() int {
 	r.mu.RLock()
 	defer r.mu.RUnlock()
diff --git a/internal/providers/router.go b/internal/providers/router.go
index d5df805b..d5dc3dba 100644
--- a/internal/providers/router.go
+++ b/internal/providers/router.go
@@ -11,6 +11,7 @@ import (
 	"sort"
 	"strings"
 
+	"gomodel/config"
 	"gomodel/internal/core"
 )
 
@@ -21,7 +22,8 @@ var ErrRegistryNotInitialized = fmt.Errorf("model registry has no models: ensure
 // It
uses a dynamic model-to-provider mapping that is populated at startup
 // by fetching available models from each provider's /models endpoint.
 type Router struct {
-	lookup core.ModelLookup
+	lookup                 core.ModelLookup
+	modelsEndpointIDFormat config.ModelsEndpointIDFormat
 }
 
 type providerTypeRegistry interface {
@@ -64,10 +66,22 @@ func NewRouter(lookup core.ModelLookup) (*Router, error) {
 		return nil, fmt.Errorf("lookup cannot be nil")
 	}
 	return &Router{
-		lookup: lookup,
+		lookup:                 lookup,
+		modelsEndpointIDFormat: config.ModelsEndpointIDFormatQualified,
 	}, nil
 }
 
+// SetModelsEndpointIDFormat sets the model ID format that ListModels requests
+// from format-aware lookups when serving GET /v1/models. The value is passed
+// through config.ResolveModelsEndpointIDFormat, so an empty format selects the
+// qualified default.
+//
+// The field is written without synchronization; call this during setup,
+// before the router starts serving requests.
+func (r *Router) SetModelsEndpointIDFormat(format config.ModelsEndpointIDFormat) {
+	r.modelsEndpointIDFormat = config.ResolveModelsEndpointIDFormat(format)
+}
+
 // checkReady verifies the lookup has models available.
 // Returns ErrRegistryNotInitialized if no models are loaded.
 func (r *Router) checkReady() error {
@@ -527,7 +541,13 @@ func (r *Router) ListModels(_ context.Context) (*core.ModelsResponse, error) {
 		return nil, registryUnavailableError(err)
 	}
 	var models []core.Model
-	if public, ok := r.lookup.(publicModelLister); ok {
+	// Prefer a lookup that can render the configured ID format; otherwise fall
+	// back to the public listing, then to the plain lookup listing.
+	if formatted, ok := r.lookup.(interface {
+		ListModelsWithFormat(config.ModelsEndpointIDFormat) []core.Model
+	}); ok {
+		models = formatted.ListModelsWithFormat(r.modelsEndpointIDFormat)
+	} else if public, ok := r.lookup.(publicModelLister); ok {
 		models = public.ListPublicModels()
 	} else {
 		models = r.lookup.ListModels()
diff --git a/internal/providers/router_test.go b/internal/providers/router_test.go
index fc8b8523..201cd4bc 100644
--- a/internal/providers/router_test.go
+++ b/internal/providers/router_test.go
@@ -10,6 +10,7 @@ import (
 	"strings"
 	"testing"
 
+	"gomodel/config"
 	"gomodel/internal/core"
 )
 
@@ -841,6 +842,58 @@ func TestRouterListModels(t *testing.T) {
 	}
 }
 
+// TestRouterListModelsWithFormat verifies that ListModels honors the format set
+// through SetModelsEndpointIDFormat, including the empty-value default and
+// normalization of case and surrounding whitespace.
+func TestRouterListModelsWithFormat(t *testing.T) {
+	registry := newTestRegistryWithModels(
+		registryModelEntry{provider: &mockProvider{}, providerName: "anthropic", providerType: "anthropic", modelID: "claude-sonnet-4-6"},
+		registryModelEntry{provider: &mockProvider{}, providerName: "openai", providerType: "openai", modelID: "gpt-4o"},
+	)
+
+	tests := []struct {
+		name            string
+		format          config.ModelsEndpointIDFormat
+		wantQualified   int
+		wantUnqualified int
+	}{
+		{name: "qualified", format: config.ModelsEndpointIDFormatQualified, wantQualified: 2},
+		{name: "unqualified", format: config.ModelsEndpointIDFormatUnqualified, wantUnqualified: 2},
+		{name: "both", format: config.ModelsEndpointIDFormatBoth, wantQualified: 2, wantUnqualified: 2},
+		{name: "empty defaults to qualified", format: "", wantQualified: 2},
+		{name: "normalizes case and whitespace", format: " Both ", wantQualified: 2, wantUnqualified: 2},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			router, err := NewRouter(registry)
+			if err != nil {
+				t.Fatalf("NewRouter: unexpected error: %v", err)
+			}
+			router.SetModelsEndpointIDFormat(tt.format)
+
+			resp, err := router.ListModels(context.Background())
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+
+			// A qualified ID has the provider/model shape, so "/" tells the two apart.
+			var qualified, unqualified int
+			for _, m := range resp.Data {
+				if strings.Contains(m.ID, "/") {
+					qualified++
+				} else {
+					unqualified++
+				}
+			}
+			if qualified != tt.wantQualified || unqualified != tt.wantUnqualified {
+				t.Errorf("got %d qualified + %d unqualified IDs, want %d + %d",
+					qualified, unqualified, tt.wantQualified, tt.wantUnqualified)
+			}
+		})
+	}
+}
+
 func TestRouterGetProviderType(t *testing.T) {
 	lookup := newMockLookup()
 	lookup.addModel("gpt-4o", &mockProvider{}, "openai")