diff --git a/openfeature/provider/_registry.py b/openfeature/provider/_registry.py index d3656c60..bf8fa9a8 100644 --- a/openfeature/provider/_registry.py +++ b/openfeature/provider/_registry.py @@ -30,9 +30,12 @@ def set_provider(self, domain: str, provider: FeatureProvider) -> None: if domain in providers: old_provider = providers[domain] del providers[domain] - if old_provider not in providers.values(): + if ( + old_provider != self._default_provider + and old_provider not in providers.values() + ): self._shutdown_provider(old_provider) - if provider not in providers.values(): + if provider != self._default_provider and provider not in providers.values(): self._initialize_provider(provider) providers[domain] = provider @@ -44,10 +47,15 @@ def get_provider(self, domain: str | None) -> FeatureProvider: def set_default_provider(self, provider: FeatureProvider) -> None: if provider is None: raise GeneralError(error_message="No provider") - if self._default_provider: + if ( + self._default_provider + and self._default_provider not in self._providers.values() + ): self._shutdown_provider(self._default_provider) self._default_provider = provider - self._initialize_provider(provider) + + if self._default_provider not in self._providers.values(): + self._initialize_provider(provider) def get_default_provider(self) -> FeatureProvider: return self._default_provider @@ -94,7 +102,7 @@ def _shutdown_provider(self, provider: FeatureProvider) -> None: try: if hasattr(provider, "shutdown"): provider.shutdown() - self._provider_status[provider] = ProviderStatus.NOT_READY + del self._provider_status[provider] except Exception as err: self.dispatch_event( provider, diff --git a/tests/provider/test_registry.py b/tests/provider/test_registry.py index 22613329..b5e10503 100644 --- a/tests/provider/test_registry.py +++ b/tests/provider/test_registry.py @@ -3,6 +3,7 @@ import pytest from openfeature.exception import GeneralError +from openfeature.provider import ProviderStatus from openfeature.provider._registry import ProviderRegistry from openfeature.provider.no_op_provider import NoOpProvider @@ -105,3 +106,125 @@ def test_setting_default_provider_initializes_it(): registry.set_default_provider(provider) provider.initialize.assert_called_once() + + +def test_registering_provider_as_default_then_domain_only_initializes_once(): + """Test that registering the same provider as default and for a domain only initializes it once.""" + + registry = ProviderRegistry() + provider = Mock() + + registry.set_default_provider(provider) + registry.set_provider("domain", provider) + + provider.initialize.assert_called_once() + + +def test_registering_provider_as_domain_then_default_only_initializes_once(): + """Test that registering the same provider as default and for a domain only initializes it once.""" + + registry = ProviderRegistry() + provider = Mock() + + registry.set_provider("domain", provider) + registry.set_default_provider(provider) + + provider.initialize.assert_called_once() + + +def test_replacing_provider_used_as_default_does_not_shutdown(): + """Test that replacing a provider that is also the default does not shut it down twice.""" + + registry = ProviderRegistry() + provider1 = Mock() + provider2 = Mock() + + registry.set_default_provider(provider1) + registry.set_provider("domain", provider1) + + registry.set_provider("domain", provider2) + + provider1.shutdown.assert_not_called() + provider2.shutdown.assert_not_called() + + +def test_replacing_default_provider_used_as_domain_does_not_shutdown(): + """Test that replacing a default provider that is also used for a domain does not shut it down twice.""" + + registry = ProviderRegistry() + provider1 = Mock() + provider2 = Mock() + + registry.set_provider("domain", provider1) + registry.set_default_provider(provider1) + + registry.set_default_provider(provider2) + + provider1.shutdown.assert_not_called() + provider2.shutdown.assert_not_called() + + +def test_shutting_down_registry_shuts_down_providers_once(): + """Test that shutting down the registry shuts down each provider only once.""" + + registry = ProviderRegistry() + provider1 = Mock() + provider2 = Mock() + + registry.set_default_provider(provider1) + registry.set_provider("domain1", provider1) + + registry.set_provider("domain2a", provider2) + registry.set_provider("domain2b", provider2) + + registry.shutdown() + + provider1.shutdown.assert_called_once() + provider2.shutdown.assert_called_once() + + +def test_initializing_provider_sets_status_ready(): + """Test that initializing a provider sets its status to READY.""" + + registry = ProviderRegistry() + provider = Mock() + + assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY + + registry.set_provider("domain", provider) + + provider.initialize.assert_called_once() + assert registry.get_provider_status(provider) == ProviderStatus.READY + + +def test_shutting_down_provider_sets_status_not_ready(): + """Test that shutting down a provider sets its status to NOT_READY.""" + + registry = ProviderRegistry() + provider = Mock() + + registry.set_provider("domain", provider) + assert registry.get_provider_status(provider) == ProviderStatus.READY + + registry.shutdown() + assert registry.get_provider_status(provider) == ProviderStatus.NOT_READY + + +def test_clearing_registry_resets_providers_and_default(): + """Test that clearing the registry resets all providers and the default provider.""" + + registry = ProviderRegistry() + provider = Mock() + + registry.set_provider("domain", provider) + registry.set_default_provider(provider) + + registry.clear_providers() + + default_provider = registry.get_default_provider() + assert isinstance(default_provider, NoOpProvider) + assert registry.get_provider("domain") is default_provider + assert registry.get_provider_status(default_provider) == ProviderStatus.READY + + provider.initialize.assert_called_once() + provider.shutdown.assert_called_once()