diff --git a/internal/core/interfaces.go b/internal/core/interfaces.go index 3dfc05b8..eb87f635 100644 --- a/internal/core/interfaces.go +++ b/internal/core/interfaces.go @@ -194,4 +194,19 @@ type ModelLookup interface { // ModelCount returns the number of registered models ModelCount() int + + // GetProviderName maps a model selector back to the concrete configured + // provider instance name. Implementations that have no such mapping return + // an empty string. Same shape as the optional ProviderNameResolver + // interface used elsewhere for provider-side type assertions. + GetProviderName(model string) string + + // GetProviderNameForType maps a provider type such as "openai" to the + // concrete configured instance name chosen for routing, e.g. + // "openai-primary". Returns empty when no mapping exists. + GetProviderNameForType(providerType string) string + + // GetProviderTypeForName maps a concrete configured instance name back to + // its provider type. Returns empty when no mapping exists. + GetProviderTypeForName(providerName string) string } diff --git a/internal/providers/router.go b/internal/providers/router.go index 4ea165f3..d5df805b 100644 --- a/internal/providers/router.go +++ b/internal/providers/router.go @@ -112,12 +112,7 @@ func (r *Router) resolveUnqualifiedSelector(selector core.ModelSelector) (core.M if selector.Provider != "" || strings.TrimSpace(selector.Model) == "" { return core.ModelSelector{}, false } - - named, ok := r.lookup.(core.ProviderNameResolver) - if !ok { - return core.ModelSelector{}, false - } - providerName := strings.TrimSpace(named.GetProviderName(selector.Model)) + providerName := strings.TrimSpace(r.lookup.GetProviderName(selector.Model)) if providerName == "" { return core.ModelSelector{}, false } @@ -613,19 +608,13 @@ func (r *Router) GetProviderName(model string) string { if selector.Provider != "" { return selector.Provider } - if named, ok := r.lookup.(core.ProviderNameResolver); ok { - return named.GetProviderName(selector.QualifiedModel()) - } - return "" + return r.lookup.GetProviderName(selector.QualifiedModel()) } // GetProviderNameForType returns the concrete configured provider instance name // chosen for a provider-typed route. func (r *Router) GetProviderNameForType(providerType string) string { - if named, ok := r.lookup.(core.ProviderTypeNameResolver); ok { - return strings.TrimSpace(named.GetProviderNameForType(providerType)) - } - return "" + return strings.TrimSpace(r.lookup.GetProviderNameForType(providerType)) } // GetProviderTypeForName returns the provider type for a concrete configured @@ -635,20 +624,7 @@ func (r *Router) GetProviderTypeForName(providerName string) string { if providerName == "" { return "" } - if typed, ok := r.lookup.(core.ProviderNameTypeResolver); ok { - return strings.TrimSpace(typed.GetProviderTypeForName(providerName)) - } - if models, ok := r.lookup.(modelWithProviderLister); ok { - for _, entry := range models.ListModelsWithProvider() { - if strings.TrimSpace(entry.ProviderName) != providerName { - continue - } - if providerType := strings.TrimSpace(entry.ProviderType); providerType != "" { - return providerType - } - } - } - return "" + return strings.TrimSpace(r.lookup.GetProviderTypeForName(providerName)) } func (r *Router) providerByType(providerType string) core.Provider { diff --git a/internal/providers/router_test.go b/internal/providers/router_test.go index efad078c..fc8b8523 100644 --- a/internal/providers/router_test.go +++ b/internal/providers/router_test.go @@ -104,6 +104,13 @@ func (m *mockModelLookup) ModelCount() int { return len(m.models) } +// The mock keeps no provider-name <-> type mapping, so the three resolver +// methods always return empty. Tests that need provider-name routing use the +// real ModelRegistry via newTestRegistryWithModels instead of this mock. +func (m *mockModelLookup) GetProviderName(_ string) string { return "" } +func (m *mockModelLookup) GetProviderNameForType(_ string) string { return "" } +func (m *mockModelLookup) GetProviderTypeForName(_ string) string { return "" } + // mockProvider is a simple mock implementation of core.Provider for testing type mockProvider struct { name string