diff --git a/fastapi_websocket_pubsub/event_broadcaster.py b/fastapi_websocket_pubsub/event_broadcaster.py index 4aa2d55..c47f6b2 100644 --- a/fastapi_websocket_pubsub/event_broadcaster.py +++ b/fastapi_websocket_pubsub/event_broadcaster.py @@ -9,6 +9,11 @@ from .logger import get_logger from .util import pydantic_serialize +try: + from importlib.metadata import version as get_version +except ImportError: + from importlib_metadata import version as get_version + logger = get_logger("EventBroadcaster") @@ -246,6 +251,8 @@ async def start_reader_task(self): self._broadcast_url ) await self.listening_broadcast_channel.connect() + # Add driver info for Redis connections + self._add_driver_info_to_broadcast(self.listening_broadcast_channel) except Exception as e: logger.error( f"Failed to connect to broadcast channel for reading incoming events: {e}" @@ -306,4 +313,70 @@ def cleanup(task): finally: if self.listening_broadcast_channel is not None: await self.listening_broadcast_channel.disconnect() - self.listening_broadcast_channel = None + + def _add_driver_info_to_broadcast(self, broadcast_channel: Broadcast) -> None: + """Add driver identification to Redis connection. + + Uses DriverInfo class if available, or falls back to + lib_name/lib_version for older versions. + + Args: + broadcast_channel: The Broadcast instance to add driver info to + """ + # Get fastapi-websocket-pubsub version + try: + pubsub_version = get_version("fastapi-websocket-pubsub") + except Exception: + pubsub_version = "unknown" + + # Get the Redis client from the backend + redis_client = self._get_redis_client_from_broadcast(broadcast_channel) + if redis_client is None: + return + + # Get connection pool from the redis client + connection_pool: Any = getattr(redis_client, "connection_pool", None) + if connection_pool is None: + return + + # Try to use DriverInfo class + try: + from redis import DriverInfo + + driver_info = DriverInfo().add_upstream_driver( + "fastapi-websocket-pubsub", pubsub_version + ) + connection_pool.connection_kwargs["driver_info"] = driver_info + except (ImportError, AttributeError): + # Fallback: use lib_name/lib_version + # Format: lib_name='redis-py(fastapi-websocket-pubsub_v{version})' + connection_pool.connection_kwargs["lib_name"] = ( + f"redis-py(fastapi-websocket-pubsub_v{pubsub_version})" + ) + # lib_version should be the redis client version + try: + import redis + + redis_version = redis.__version__ + except (ImportError, AttributeError): + redis_version = "unknown" + connection_pool.connection_kwargs["lib_version"] = redis_version + + def _get_redis_client_from_broadcast(self, broadcast_channel: Broadcast) -> Any: + """Extract the Redis client from the broadcaster backend. + + Args: + broadcast_channel: The Broadcast instance + + Returns: + The Redis client instance or None if not available. + """ + try: + backend = broadcast_channel._backend + # Check if it's a Redis backend by class name + if "Redis" not in type(backend).__name__: + return None + # Try _conn (current broadcaster) or _pub_conn (older versions) + return getattr(backend, "_conn", None) or getattr(backend, "_pub_conn", None) + except Exception: + return None diff --git a/tests/test_driver_info.py b/tests/test_driver_info.py new file mode 100644 index 0000000..736280b --- /dev/null +++ b/tests/test_driver_info.py @@ -0,0 +1,184 @@ +"""Tests for Redis DriverInfo support in EventBroadcaster.""" + +import sys +from unittest.mock import MagicMock, patch +import pytest + +from fastapi_websocket_pubsub import EventBroadcaster +from fastapi_websocket_pubsub.event_notifier import EventNotifier +from broadcaster import Broadcast + + +class MockRedisBackend: + """Mock backend with 'Redis' in class name for type checking.""" + pass + + +class MockPostgresBackend: + """Mock backend without 'Redis' in class name.""" + pass + + +class TestDriverInfo: + """Test suite for Redis driver info functionality.""" + + @pytest.mark.asyncio + async def test_add_driver_info_with_driver_info_class(self): + """Test _add_driver_info_to_broadcast with modern redis-py (DriverInfo class available).""" + # Create a mock notifier + mock_notifier = MagicMock(spec=EventNotifier) + + # Create EventBroadcaster instance + broadcaster = EventBroadcaster("redis://localhost:6379", mock_notifier) + + # Create a mock Broadcast instance + mock_broadcast = MagicMock(spec=Broadcast) + + # Mock the backend and Redis client + mock_redis_client = MagicMock() + mock_pool = MagicMock() + mock_pool.connection_kwargs = {} + mock_redis_client.connection_pool = mock_pool + + # Use MockRedisBackend so type(backend).__name__ contains "Redis" + mock_backend = MockRedisBackend() + mock_backend._conn = mock_redis_client + mock_broadcast._backend = mock_backend + + # Create a mock DriverInfo instance with add_upstream_driver method + mock_driver_info_instance = MagicMock() + mock_driver_info_instance.add_upstream_driver = MagicMock( + return_value=mock_driver_info_instance + ) + + # Create a mock DriverInfo class + mock_driver_info_class = MagicMock(return_value=mock_driver_info_instance) + + # Create a mock redis module with DriverInfo + mock_redis_module = MagicMock() + mock_redis_module.DriverInfo = mock_driver_info_class + + with patch( + "fastapi_websocket_pubsub.event_broadcaster.get_version", + return_value="0.3.9", + ): + # Temporarily add our mock redis module + original_redis = sys.modules.get("redis") + sys.modules["redis"] = mock_redis_module + + try: + broadcaster._add_driver_info_to_broadcast(mock_broadcast) + + # Verify driver_info was set in connection_kwargs + assert "driver_info" in mock_pool.connection_kwargs + + # Verify add_upstream_driver was called with correct arguments + mock_driver_info_instance.add_upstream_driver.assert_called_once_with( + "fastapi-websocket-pubsub", "0.3.9" + ) + finally: + # Restore original redis module + if original_redis: + sys.modules["redis"] = original_redis + else: + sys.modules.pop("redis", None) + + @pytest.mark.asyncio + async def test_add_driver_info_fallback_old_redis(self): + """Test _add_driver_info_to_broadcast fallback for older redis-py versions (no DriverInfo).""" + mock_notifier = MagicMock(spec=EventNotifier) + broadcaster = EventBroadcaster("redis://localhost:6379", mock_notifier) + + mock_broadcast = MagicMock(spec=Broadcast) + mock_redis_client = MagicMock() + mock_pool = MagicMock() + mock_pool.connection_kwargs = {} + mock_redis_client.connection_pool = mock_pool + + # Use MockRedisBackend so type(backend).__name__ contains "Redis" + mock_backend = MockRedisBackend() + mock_backend._conn = mock_redis_client + mock_broadcast._backend = mock_backend + + # Create a mock redis module WITHOUT DriverInfo but WITH __version__ + mock_redis_module = MagicMock(spec=["__version__"]) + mock_redis_module.__version__ = "3.5.3" + + with patch( + "fastapi_websocket_pubsub.event_broadcaster.get_version", + return_value="0.3.9", + ): + # Temporarily replace redis module + original_redis = sys.modules.get("redis") + sys.modules["redis"] = mock_redis_module + + try: + broadcaster._add_driver_info_to_broadcast(mock_broadcast) + + # For older redis-py versions without DriverInfo, should fall back to lib_name and lib_version + assert "driver_info" not in mock_pool.connection_kwargs + assert ( + mock_pool.connection_kwargs["lib_name"] + == "redis-py(fastapi-websocket-pubsub_v0.3.9)" + ) + assert mock_pool.connection_kwargs["lib_version"] == "3.5.3" + finally: + # Restore original redis module + if original_redis: + sys.modules["redis"] = original_redis + else: + sys.modules.pop("redis", None) + + @pytest.mark.asyncio + async def test_add_driver_info_no_backend(self): + """Test _add_driver_info_to_broadcast when backend is not available.""" + mock_notifier = MagicMock(spec=EventNotifier) + broadcaster = EventBroadcaster("redis://localhost:6379", mock_notifier) + + mock_broadcast = MagicMock(spec=Broadcast) + mock_broadcast._backend = None + + # Should not raise an exception + broadcaster._add_driver_info_to_broadcast(mock_broadcast) + + @pytest.mark.asyncio + async def test_add_driver_info_no_connection_pool(self): + """Test _add_driver_info_to_broadcast when redis client has no connection_pool.""" + mock_notifier = MagicMock(spec=EventNotifier) + broadcaster = EventBroadcaster("redis://localhost:6379", mock_notifier) + + mock_broadcast = MagicMock(spec=Broadcast) + mock_redis_client = MagicMock() + mock_redis_client.connection_pool = None + + # Use MockRedisBackend so type(backend).__name__ contains "Redis" + mock_backend = MockRedisBackend() + mock_backend._conn = mock_redis_client + mock_broadcast._backend = mock_backend + + # Should not raise an exception + broadcaster._add_driver_info_to_broadcast(mock_broadcast) + + @pytest.mark.asyncio + async def test_add_driver_info_skips_non_redis(self): + """Test that _add_driver_info_to_broadcast skips non-Redis backend types.""" + mock_notifier = MagicMock(spec=EventNotifier) + broadcaster = EventBroadcaster("postgres://localhost:5432/db", mock_notifier) + + mock_broadcast = MagicMock(spec=Broadcast) + mock_redis_client = MagicMock() + mock_pool = MagicMock() + mock_pool.connection_kwargs = {} + mock_redis_client.connection_pool = mock_pool + + # Use MockPostgresBackend so type(backend).__name__ does NOT contain "Redis" + mock_backend = MockPostgresBackend() + mock_backend._conn = mock_redis_client + mock_broadcast._backend = mock_backend + + # Should not add driver info for non-Redis backend types + broadcaster._add_driver_info_to_broadcast(mock_broadcast) + + # Verify no driver info was added + assert "driver_info" not in mock_pool.connection_kwargs + assert "lib_name" not in mock_pool.connection_kwargs