diff --git a/Cargo.lock b/Cargo.lock index 27871d57b..6b340b92e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8025,6 +8025,7 @@ dependencies = [ "experimentation_platform", "fred", "frontend", + "futures-util", "idgenerator", "inventory", "json-subscriber", @@ -8041,6 +8042,7 @@ dependencies = [ "superposition_derives", "superposition_macros", "superposition_types", + "tokio", "tracing", "tracing-actix-web", "tracing-subscriber", diff --git a/clients/python/provider/examples/sse_example.py b/clients/python/provider/examples/sse_example.py new file mode 100644 index 000000000..652db2202 --- /dev/null +++ b/clients/python/provider/examples/sse_example.py @@ -0,0 +1,59 @@ +import asyncio +import logging +import os +import sys + +_PYTHON_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(0, os.path.join(_PYTHON_DIR, "sdk")) +sys.path.insert(0, os.path.join(_PYTHON_DIR, "bindings")) + +from openfeature.evaluation_context import EvaluationContext +from superposition_provider import LocalResolutionProvider, HttpDataSource +from superposition_provider.types import SuperpositionOptions, SseStrategy + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s") +logging.getLogger("superposition_sdk").setLevel(logging.CRITICAL) +logging.getLogger("superposition_provider").setLevel(logging.WARNING) +log = logging.getLogger(__name__) + + +ENDPOINT = os.environ.get("SUPERPOSITION_ENDPOINT", "http://localhost:8080") +TOKEN = os.environ.get("SUPERPOSITION_TOKEN", "token") +ORG = os.environ.get("SUPERPOSITION_ORG_ID", "localorg") +WORKSPACE = os.environ.get("SUPERPOSITION_WORKSPACE", "dev") + + +async def main(): + loop = asyncio.get_event_loop() + _orig = loop.default_exception_handler + loop.set_exception_handler( + lambda l, ctx: None if ctx.get("message") == "Unclosed client session" else _orig(ctx) + ) + + options = SuperpositionOptions(endpoint=ENDPOINT, token=TOKEN, org_id=ORG, workspace_id=WORKSPACE) + + def on_config_change(before, after): + for key, value in after.items(): + if before.get(key) != value: + log.info(f"[UPDATE] {key}: {before.get(key)!r} -> {value!r}") + + provider = LocalResolutionProvider( + primary_source=HttpDataSource(options), + refresh_strategy=SseStrategy(superposition_options=options, reconnect_delay=5), + on_config_change=on_config_change, + ) + + await provider.initialize(EvaluationContext()) + log.info(f"Initial config: {provider.resolve_all_features(EvaluationContext())}") + log.info("Listening for SSE config changes (Ctrl-C to stop)") + + try: + await asyncio.Event().wait() + except KeyboardInterrupt: + pass + finally: + await provider.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/clients/python/provider/superposition_provider/__init__.py b/clients/python/provider/superposition_provider/__init__.py index 3ee99e8c4..e17031687 100644 --- a/clients/python/provider/superposition_provider/__init__.py +++ b/clients/python/provider/superposition_provider/__init__.py @@ -4,7 +4,7 @@ Provides OpenFeature-compliant feature flag providers with support for: - Local resolution with caching (LocalResolutionProvider) - Remote evaluation without caching (SuperpositionAPIProvider) -- Configurable refresh strategies (Polling, OnDemand, Watch, Manual) +- Configurable refresh strategies (Polling, OnDemand, Watch, SSE, Manual) - File-based and HTTP-based data sources - Full FFI integration for performance """ @@ -29,6 +29,7 @@ PollingStrategy, OnDemandStrategy, WatchStrategy, + SseStrategy, ManualStrategy, RefreshStrategy as RefreshStrategyType, ConfigurationOptions, @@ -59,6 +60,7 @@ "PollingStrategy", "OnDemandStrategy", "WatchStrategy", + "SseStrategy", "ManualStrategy", "RefreshStrategyType", "ConfigurationOptions", diff --git a/clients/python/provider/superposition_provider/http_data_source.py b/clients/python/provider/superposition_provider/http_data_source.py index 7b6e3974b..eb2a978e3 100644 --- a/clients/python/provider/superposition_provider/http_data_source.py +++ b/clients/python/provider/superposition_provider/http_data_source.py @@ -197,6 +197,13 @@ def supports_experiments(self) -> bool: return True async def close(self) -> None: - """Close the HTTP client.""" + """Close the HTTP client and its underlying aiohttp session.""" if self.client: + try: + http_client = getattr(self.client._config, "http_client", None) + session = getattr(http_client, "_session", None) + if session and not session.closed: + await session.close() + except Exception: + pass self.client = None diff --git a/clients/python/provider/superposition_provider/local_provider.py b/clients/python/provider/superposition_provider/local_provider.py index a8db676dc..69525f55e 100644 --- a/clients/python/provider/superposition_provider/local_provider.py +++ b/clients/python/provider/superposition_provider/local_provider.py @@ -10,7 +10,7 @@ import json import weakref from datetime import datetime, timezone -from typing import Dict, List, Optional, Any, Tuple, Union, Sequence, Mapping +from typing import Callable, Dict, List, Optional, Any, Tuple, Union, Sequence, Mapping from openfeature.provider import ( AbstractProvider, @@ -26,7 +26,7 @@ from . import FetchResponse from .data_source import SuperpositionDataSource, ConfigData, ExperimentData from .interfaces import AllFeatureProvider, FeatureExperimentMeta -from .types import RefreshStrategy, OnDemandStrategy, WatchStrategy, PollingStrategy, ManualStrategy, default_on_demand_strategy +from .types import RefreshStrategy, OnDemandStrategy, WatchStrategy, PollingStrategy, ManualStrategy, SseStrategy, default_on_demand_strategy logger = logging.getLogger(__name__) @@ -45,6 +45,7 @@ def __init__( primary_source: SuperpositionDataSource, fallback_source: Optional[SuperpositionDataSource] = None, refresh_strategy: RefreshStrategy = default_on_demand_strategy(), + on_config_change: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, ): """Initialize local resolution provider. @@ -56,6 +57,7 @@ def __init__( self.primary_source = primary_source self.fallback_source = fallback_source self.refresh_strategy = refresh_strategy + self.on_config_change = on_config_change self.metadata = Metadata(name="LocalResolutionProvider") self.status = ProviderStatus.NOT_READY @@ -68,6 +70,8 @@ def __init__( # Background task for refresh strategy self._background_task: Optional[asyncio.Task] = None + # Set once the SSE connection is established (SseStrategy only) + self._sse_connected_event: Optional[asyncio.Event] = None async def initialize(self, context: EvaluationContext): """Initialize the provider. @@ -358,7 +362,21 @@ async def refresh(self) -> None: Useful for MANUAL refresh strategy. """ - await asyncio.gather(self._fetch_and_cache_config(), self._fetch_and_cache_experiments()) + before = self.resolve_all_features(EvaluationContext()) if self.on_config_change and self.ffi_cache else None + + results = await asyncio.gather( + self._fetch_and_cache_config(), + self._fetch_and_cache_experiments(), + return_exceptions=True, + ) + for result in results: + if isinstance(result, Exception): + logger.warning(f"Error during refresh: {result}") + + if self.on_config_change and before is not None: + after = self.resolve_all_features(EvaluationContext()) + if before != after: + self.on_config_change(before, after) # --- Private helpers --- @@ -574,7 +592,75 @@ async def _watch_loop() -> None: case WatchStrategy(): self._background_task = asyncio.create_task(_watch_loop()) case PollingStrategy(): - self._background_task = asyncio.create_task(_polling_loop()) + self._background_task = asyncio.create_task(self._polling_loop()) + case SseStrategy(): + import weakref + strategy: SseStrategy = self.refresh_strategy + options = strategy.superposition_options + debounce_s = strategy.debounce_ms / 1000 + reconnect_delay = max(strategy.reconnect_delay, 1) + sse_url = f"{options.endpoint.rstrip('/')}/{options.org_id}/{options.workspace_id}/stream" + sse_headers = {"Authorization": f"Bearer {options.token}"} + if options.org_id: + sse_headers["x-org-id"] = options.org_id + weak_self = weakref.ref(self) + connected_event = asyncio.Event() + self._sse_connected_event = connected_event + + async def _sse_loop(): + import aiohttp + logger.info(f"Starting SSE refresh (url={sse_url})") + try: + async with aiohttp.ClientSession() as session: + while True: + if weak_self() is None: + logger.info("Provider garbage collected, stopping SSE loop.") + return + try: + async with session.get( + sse_url, + headers=sse_headers, + timeout=aiohttp.ClientTimeout(total=None, sock_read=30), + ) as resp: + if resp.status != 200: + logger.warning(f"SSE endpoint returned {resp.status}, retrying in {reconnect_delay}s") + await asyncio.sleep(reconnect_delay) + continue + logger.info("SSE connection established") + connected_event.set() + self_ref = weak_self() + if self_ref is not None: + try: + await self_ref.refresh() + except Exception as e: + logger.warning(f"Reconnect refresh failed: {e}") + debounce_task = None + async for line_bytes in resp.content: + self_ref = weak_self() + if self_ref is None: + logger.info("Provider garbage collected, stopping SSE loop.") + return + line = line_bytes.decode("utf-8", errors="replace").strip() + logger.debug(f"SSE raw line: {line!r}") + if not line or line.startswith(":"): + continue + if line.startswith("event:") or line.startswith("data:"): + async def _do_refresh(ref=self_ref): + await asyncio.sleep(debounce_s) + try: + await ref.refresh() + except Exception as e: + logger.warning(f"SSE-triggered refresh failed: {e}") + if debounce_task and not debounce_task.done(): + debounce_task.cancel() + debounce_task = asyncio.create_task(_do_refresh()) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + logger.warning(f"SSE connection error: {e}, retrying in {reconnect_delay}s") + await asyncio.sleep(reconnect_delay) + except asyncio.CancelledError: + logger.info("SSE loop cancelled") + + self._background_task = asyncio.create_task(_sse_loop()) case ManualStrategy(): logger.debug("MANUAL strategy: caller must invoke refresh()") case OnDemandStrategy(): diff --git a/clients/python/provider/superposition_provider/types.py b/clients/python/provider/superposition_provider/types.py index 6f89399dc..04f4f6546 100644 --- a/clients/python/provider/superposition_provider/types.py +++ b/clients/python/provider/superposition_provider/types.py @@ -71,6 +71,19 @@ class WatchStrategy: def default_watch_strategy(): return WatchStrategy(500) +@dataclass +class SseStrategy: + """SSE-based refresh strategy. + + Connects to the server's SSE endpoint and refreshes when a change event + is received. Reconnects automatically on connection failure. + + Requires SuperpositionOptions to build the SSE endpoint URL and authenticate. + """ + superposition_options: SuperpositionOptions + reconnect_delay: int = 5 # seconds between reconnect attempts + debounce_ms: int = 500 # debounce rapid successive events + @dataclass class ManualStrategy: """Manual refresh strategy. @@ -81,7 +94,7 @@ class ManualStrategy: # Union type for all refresh strategies -RefreshStrategy = Union[PollingStrategy, OnDemandStrategy, WatchStrategy, ManualStrategy] +RefreshStrategy = Union[PollingStrategy, OnDemandStrategy, WatchStrategy, SseStrategy, ManualStrategy] # ============================================================================ diff --git a/crates/context_aware_config/src/api/context/handlers.rs b/crates/context_aware_config/src/api/context/handlers.rs index 38060ef48..3760da472 100644 --- a/crates/context_aware_config/src/api/context/handlers.rs +++ b/crates/context_aware_config/src/api/context/handlers.rs @@ -15,7 +15,7 @@ use diesel::{ use serde_json::{Map, Value}; use service_utils::{ helpers::{ - WebhookData, execute_webhook_call, fetch_dimensions_info_map, parse_config_tags, + WebhookData, fetch_dimensions_info_map, notify_change, parse_config_tags, }, middlewares::auth_z::{Action as AuthZAction, AuthZ}, service::types::{ @@ -186,7 +186,7 @@ async fn create_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -322,7 +322,7 @@ async fn update_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -467,7 +467,7 @@ async fn move_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -754,7 +754,7 @@ async fn delete_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::NoContent() @@ -1129,7 +1129,7 @@ async fn bulk_operations_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut resp_builder = if webhook_status { HttpResponse::Ok() @@ -1245,7 +1245,7 @@ async fn weight_recompute_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() diff --git a/crates/context_aware_config/src/api/default_config/handlers.rs b/crates/context_aware_config/src/api/default_config/handlers.rs index 19a06965e..6141c83a7 100644 --- a/crates/context_aware_config/src/api/default_config/handlers.rs +++ b/crates/context_aware_config/src/api/default_config/handlers.rs @@ -12,7 +12,7 @@ use diesel::{ use jsonschema::ValidationError; use serde_json::Value; use service_utils::{ - helpers::{WebhookData, execute_webhook_call, parse_config_tags}, + helpers::{WebhookData, notify_change, parse_config_tags}, service::types::{ AppHeader, AppState, CustomHeaders, DbConnection, EncryptionKey, SchemaName, WorkspaceContext, @@ -188,7 +188,7 @@ async fn create_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -348,7 +348,7 @@ async fn update_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -586,7 +586,7 @@ async fn delete_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() diff --git a/crates/context_aware_config/src/api/dimension/handlers.rs b/crates/context_aware_config/src/api/dimension/handlers.rs index 8e064005b..29c55abc2 100644 --- a/crates/context_aware_config/src/api/dimension/handlers.rs +++ b/crates/context_aware_config/src/api/dimension/handlers.rs @@ -9,7 +9,7 @@ use diesel::{ }; use serde_json::Value; use service_utils::{ - helpers::{WebhookData, execute_webhook_call, parse_config_tags}, + helpers::{WebhookData, notify_change, parse_config_tags}, service::types::{ AppHeader, AppState, CustomHeaders, DbConnection, WorkspaceContext, }, @@ -261,7 +261,7 @@ async fn create_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Created() @@ -499,7 +499,7 @@ async fn update_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -698,7 +698,7 @@ async fn delete_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() diff --git a/crates/experimentation_platform/src/api/experiments/handlers.rs b/crates/experimentation_platform/src/api/experiments/handlers.rs index cd50283ab..417ba580f 100644 --- a/crates/experimentation_platform/src/api/experiments/handlers.rs +++ b/crates/experimentation_platform/src/api/experiments/handlers.rs @@ -25,8 +25,8 @@ use serde_json::{Map, Value}; use service_utils::{ db::run_query, helpers::{ - WebhookData, construct_request_headers, execute_webhook_call, - fetch_dimensions_info_map, generate_snowflake_id, is_not_modified, request, + WebhookData, construct_request_headers, fetch_dimensions_info_map, + generate_snowflake_id, is_not_modified, notify_change, request, }, middlewares::auth_z::{Action as AuthZAction, AuthZ}, redis::{ @@ -413,7 +413,7 @@ async fn create_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -490,7 +490,7 @@ async fn conclude_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -765,7 +765,7 @@ async fn discard_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -1499,7 +1499,7 @@ async fn ramp_handler( action: Action::Update, }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -1849,7 +1849,7 @@ async fn update_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -1909,7 +1909,7 @@ async fn pause_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() @@ -2005,7 +2005,7 @@ async fn resume_handler( }; let webhook_status = - execute_webhook_call(data, &workspace_context, &state, &mut conn).await; + notify_change(data, &workspace_context, &state, &mut conn).await; let mut http_resp = if webhook_status { HttpResponse::Ok() diff --git a/crates/service_utils/src/helpers.rs b/crates/service_utils/src/helpers.rs index c109830ee..486286596 100644 --- a/crates/service_utils/src/helpers.rs +++ b/crates/service_utils/src/helpers.rs @@ -460,6 +460,24 @@ where } } +/// Calls `execute_webhook_call` and also broadcasts an SSE "config changed" +/// signal so that connected SDK clients refresh immediately. +pub async fn notify_change( + data: WebhookData, + workspace_context: &WorkspaceContext, + state: &Data, + conn: &mut DBConnection, +) -> bool +where + T: Serialize, +{ + // Broadcast SSE signal (fire-and-forget; ok if no subscribers). + let sender = state.get_sse_sender(&workspace_context.schema_name); + let _ = sender.send(()); + + execute_webhook_call(data, workspace_context, state, conn).await +} + pub fn fetch_dimensions_info_map( conn: &mut DBConnection, schema_name: &SchemaName, diff --git a/crates/service_utils/src/service/types.rs b/crates/service_utils/src/service/types.rs index b777a2ae1..8fb9013ec 100644 --- a/crates/service_utils/src/service/types.rs +++ b/crates/service_utils/src/service/types.rs @@ -1,10 +1,12 @@ use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, future::{Ready, ready}, str::FromStr, sync::{Arc, Mutex}, }; +use tokio::sync::watch; + use actix_web::{Error, FromRequest, HttpMessage, HttpResponseBuilder, error, web::Data}; use chrono::{DateTime, Utc}; use derive_more::{Deref, DerefMut}; @@ -97,6 +99,20 @@ pub struct AppState { pub redis: Option, pub http_client: reqwest::Client, pub master_encryption_key: Option, + pub sse_broadcaster: Mutex>>, +} + +impl AppState { + pub fn get_sse_sender(&self, schema_name: &str) -> watch::Sender<()> { + let mut map = self.sse_broadcaster.lock().expect("sse_broadcaster lock poisoned"); + map.entry(schema_name.to_string()) + .or_insert_with(|| watch::channel(()).0) + .clone() + } + + pub fn subscribe_sse(&self, schema_name: &str) -> watch::Receiver<()> { + self.get_sse_sender(schema_name).subscribe() + } } impl FromStr for AppEnv { diff --git a/crates/superposition/Cargo.toml b/crates/superposition/Cargo.toml index 26fa4bada..67dd3973d 100644 --- a/crates/superposition/Cargo.toml +++ b/crates/superposition/Cargo.toml @@ -30,6 +30,8 @@ rs-snowflake = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } service_utils = { workspace = true } +futures-util = { workspace = true } +tokio = { workspace = true } superposition_derives = { workspace = true } superposition_macros = { workspace = true } superposition_types = { workspace = true, features = [ diff --git a/crates/superposition/src/app_state.rs b/crates/superposition/src/app_state.rs index 0083ac256..46de09975 100644 --- a/crates/superposition/src/app_state.rs +++ b/crates/superposition/src/app_state.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, sync::{Arc, Mutex}, time::Duration, }; @@ -110,5 +110,6 @@ pub async fn get( redis: redis_pool, http_client: reqwest::Client::new(), master_encryption_key, + sse_broadcaster: Mutex::new(HashMap::new()), } } diff --git a/crates/superposition/src/main.rs b/crates/superposition/src/main.rs index 44d32f407..69242783a 100644 --- a/crates/superposition/src/main.rs +++ b/crates/superposition/src/main.rs @@ -3,6 +3,7 @@ mod app_state; mod log_span; mod organisation; mod resolve; +mod stream; mod webhooks; mod workspace; @@ -298,6 +299,11 @@ impl ScopeExt for Scope { .wrap(OrgWorkspaceMiddlewareFactory::new(true, true)) .service(resolve::endpoints()), ) + .service( + scope("/stream") + .wrap(OrgWorkspaceMiddlewareFactory::new(true, true)) + .service(stream::endpoints()), + ) .service( scope("/authz/workspace") .wrap(OrgWorkspaceMiddlewareFactory::new(true, true)) diff --git a/crates/superposition/src/stream.rs b/crates/superposition/src/stream.rs new file mode 100644 index 000000000..e87b7d56c --- /dev/null +++ b/crates/superposition/src/stream.rs @@ -0,0 +1,2 @@ +mod handlers; +pub use handlers::endpoints; diff --git a/crates/superposition/src/stream/handlers.rs b/crates/superposition/src/stream/handlers.rs new file mode 100644 index 000000000..b1e1c07c1 --- /dev/null +++ b/crates/superposition/src/stream/handlers.rs @@ -0,0 +1,43 @@ +use actix_web::{HttpResponse, Scope, web::Data}; +use futures_util::stream::{self, StreamExt}; +use service_utils::service::types::{AppState, WorkspaceContext}; +use tokio::time::{Duration, interval}; + +pub fn endpoints() -> Scope { + Scope::new("").service(sse_stream) +} + +#[actix_web::get("")] +async fn sse_stream( + workspace_context: WorkspaceContext, + state: Data, +) -> HttpResponse { + let schema_name = workspace_context.schema_name.0.clone(); + let mut rx = state.subscribe_sse(&schema_name); + let mut keepalive = interval(Duration::from_secs(15)); + keepalive.tick().await; // skip the immediate first tick + + let event_stream = stream::unfold( + (rx, keepalive), + |(mut rx, mut keepalive)| async move { + tokio::select! { + result = rx.changed() => match result { + Ok(()) => { + let payload = "event: config_change\ndata: {}\n\n"; + Some((Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(payload)), (rx, keepalive))) + } + Err(_) => None, + }, + _ = keepalive.tick() => { + Some((Ok(actix_web::web::Bytes::from(": keepalive\n\n")), (rx, keepalive))) + } + } + }, + ); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .insert_header(("X-Accel-Buffering", "no")) + .streaming(event_stream) +} diff --git a/crates/superposition_provider/src/client.rs b/crates/superposition_provider/src/client.rs index 560de9e63..f01dc1e0c 100644 --- a/crates/superposition_provider/src/client.rs +++ b/crates/superposition_provider/src/client.rs @@ -98,6 +98,9 @@ impl CacConfig { RefreshStrategy::Watch(_) => { info!("Using Watch refresh strategy"); } + RefreshStrategy::Sse(_) => { + info!("Using SSE refresh strategy"); + } RefreshStrategy::Manual => { info!("Using Manual refresh strategy"); } @@ -360,6 +363,9 @@ impl ExperimentationConfig { RefreshStrategy::Watch(_) => { info!("Using Watch refresh strategy for experiments"); } + RefreshStrategy::Sse(_) => { + info!("Using SSE refresh strategy for experiments"); + } RefreshStrategy::Manual => { info!("Using Manual refresh strategy for experiments"); } diff --git a/crates/superposition_provider/src/local_provider.rs b/crates/superposition_provider/src/local_provider.rs index cbba4c0d1..3ab107825 100644 --- a/crates/superposition_provider/src/local_provider.rs +++ b/crates/superposition_provider/src/local_provider.rs @@ -1,6 +1,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use derive_more::{Deref, DerefMut}; @@ -187,6 +188,12 @@ impl LocalResolutionProvider { } } } + RefreshStrategy::Sse(sse_strategy) => { + log::info!("LocalResolutionProvider: starting SSE strategy"); + let task = self.start_sse(sse_strategy.clone()).await; + let mut background_task = self.background_task.write().await; + *background_task = Some(task); + } RefreshStrategy::Manual => { log::info!("LocalResolutionProvider: using Manual refresh strategy"); } @@ -337,11 +344,14 @@ impl LocalResolutionProvider { } async fn start_polling(&self, interval: u64) -> JoinHandle<()> { - let provider = self.clone(); + let weak = Arc::downgrade(&self.0); tokio::spawn(async move { loop { sleep(Duration::from_secs(interval)).await; - let _ = provider.do_refresh().await; + match weak.upgrade() { + Some(p) => { let _ = LocalResolutionProvider(p).do_refresh().await; } + None => return, + } } }) } @@ -351,28 +361,147 @@ impl LocalResolutionProvider { mut watch_stream: crate::types::WatchStream, debounce_ms: u64, ) -> JoinHandle<()> { - let provider = self.clone(); + let weak = Arc::downgrade(&self.0); + let debounce = Duration::from_millis(debounce_ms); tokio::spawn(async move { + let mut debounce_task: Option> = None; + loop { match watch_stream.receiver.recv().await { Ok(()) => { - // Debounce: wait, then drain any queued events - sleep(Duration::from_millis(debounce_ms)).await; - while watch_stream.receiver.try_recv().is_ok() {} - let _ = provider.do_refresh().await; + if let Some(prev) = debounce_task.take() { + prev.abort(); + } + let weak_clone = weak.clone(); + debounce_task = Some(tokio::spawn(async move { + sleep(debounce).await; + match weak_clone.upgrade() { + Some(p) => { let _ = LocalResolutionProvider(p).do_refresh().await; } + None => {} + } + })); } Err(e) => { - log::error!( - "LocalResolutionProvider: watch channel error: {}", - e - ); + log::error!("LocalResolutionProvider: watch channel error: {}", e); + return; } } } }) } + async fn start_sse(&self, strategy: crate::types::SseStrategy) -> JoinHandle<()> { + let weak = Arc::downgrade(&self.0); + let opts = strategy.superposition_options; + let reconnect_delay = Duration::from_secs(strategy.reconnect_delay.unwrap_or(5)); + let debounce = Duration::from_millis(strategy.debounce_ms.unwrap_or(500)); + let sse_url = format!( + "{}/{}/{}/stream", + opts.endpoint.trim_end_matches('/'), + opts.org_id, + opts.workspace_id, + ); + + tokio::spawn(async move { + let client = reqwest::Client::new(); + log::info!("LocalResolutionProvider: starting SSE loop (url={})", sse_url); + + loop { + let provider = match weak.upgrade() { + Some(p) => LocalResolutionProvider(p), + None => { + log::info!("LocalResolutionProvider: provider dropped, stopping SSE loop"); + return; + } + }; + + let resp = client + .get(&sse_url) + .header("Authorization", format!("Bearer {}", opts.token)) + .header("x-org-id", &opts.org_id) + .send() + .await; + + let resp = match resp { + Ok(r) if r.status().is_success() => r, + Ok(r) => { + log::warn!("LocalResolutionProvider: SSE endpoint returned {}, retrying", r.status()); + sleep(reconnect_delay).await; + continue; + } + Err(e) => { + log::warn!("LocalResolutionProvider: SSE connect error: {}, retrying", e); + sleep(reconnect_delay).await; + continue; + } + }; + + log::info!("LocalResolutionProvider: SSE connection established"); + if let Err(e) = provider.do_refresh().await { + log::warn!("LocalResolutionProvider: reconnect refresh failed: {}", e); + } + + let mut resp = resp; + let mut debounce_task: Option> = None; + + loop { + // 30s read timeout — server sends keepalives every 15s so under + // normal conditions this never fires. On a silent dead connection + // (NAT drop, etc.) it detects the failure and triggers a reconnect. + let chunk = tokio::time::timeout( + std::time::Duration::from_secs(30), + resp.chunk(), + ) + .await; + + match chunk { + Err(_) => { + log::warn!("LocalResolutionProvider: SSE read timeout, reconnecting"); + break; + } + Ok(Err(e)) => { + log::warn!("LocalResolutionProvider: SSE read error: {}, reconnecting", e); + break; + } + Ok(Ok(None)) => { + log::warn!("LocalResolutionProvider: SSE stream ended, reconnecting"); + break; + } + Ok(Ok(Some(bytes))) => { + let text = String::from_utf8_lossy(&bytes); + for line in text.lines() { + let line = line.trim(); + log::debug!("LocalResolutionProvider: SSE raw line: {:?}", line); + if line.is_empty() || line.starts_with(':') { + continue; + } + if line.starts_with("event:") || line.starts_with("data:") { + if let Some(prev) = debounce_task.take() { + prev.abort(); + } + let p = provider.clone(); + debounce_task = Some(tokio::spawn(async move { + sleep(debounce).await; + if let Err(e) = p.do_refresh().await { + log::warn!("LocalResolutionProvider: SSE-triggered refresh failed: {}", e); + } + })); + } + } + } + } + } + + if let Some(task) = debounce_task.take() { + task.abort(); + } + + sleep(reconnect_delay).await; + } + }) + } + async fn ensure_fresh_data(&self) -> Result<()> { if let RefreshStrategy::OnDemand(on_demand) = &self.refresh_strategy { let ttl = on_demand.ttl; diff --git a/crates/superposition_provider/src/types.rs b/crates/superposition_provider/src/types.rs index 292e2b8a6..8c2967584 100644 --- a/crates/superposition_provider/src/types.rs +++ b/crates/superposition_provider/src/types.rs @@ -120,6 +120,26 @@ impl Default for WatchStrategy { } } +/// SSE-based refresh strategy. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SseStrategy { + pub superposition_options: SuperpositionOptions, + /// Seconds to wait before reconnecting after a connection failure (default: 5). + pub reconnect_delay: Option, + /// Debounce rapid successive events in milliseconds (default: 500). + pub debounce_ms: Option, +} + +impl SseStrategy { + pub fn new(superposition_options: SuperpositionOptions) -> Self { + Self { + superposition_options, + reconnect_delay: Some(5), + debounce_ms: Some(500), + } + } +} + /// A stream of change notifications from a data source. pub struct WatchStream { pub receiver: tokio::sync::broadcast::Receiver<()>, @@ -130,6 +150,7 @@ pub enum RefreshStrategy { Polling(PollingStrategy), OnDemand(OnDemandStrategy), Watch(WatchStrategy), + Sse(SseStrategy), Manual, }