diff --git a/src/robusta/core/sinks/robusta/dal/supabase_dal.py b/src/robusta/core/sinks/robusta/dal/supabase_dal.py index 3939f95b8..ff017aca1 100644 --- a/src/robusta/core/sinks/robusta/dal/supabase_dal.py +++ b/src/robusta/core/sinks/robusta/dal/supabase_dal.py @@ -2,8 +2,12 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Dict, List, Optional +import os +import threading +from typing import Any, Dict, List, Optional, Tuple +from uuid import uuid4 +from cachetools import TTLCache import requests from postgrest._sync.request_builder import SyncQueryRequestBuilder from postgrest.base_request_builder import BaseFilterRequestBuilder @@ -49,7 +53,7 @@ ACCOUNT_RESOURCE_TABLE = "AccountResource" ACCOUNT_RESOURCE_STATUS_TABLE = "AccountResourceStatus" OPENSHIFT_GROUPS_TABLE = "OpenshiftGroups" - +SESSION_TOKENS_TABLE = "AuthTokens" class SupabaseDal(AccountResourceFetcher): def __init__( @@ -77,11 +81,14 @@ def __init__( self.patch_postgrest_execute() self.email = email self.password = password - self.sign_in() + self.user_id = self.sign_in() self.client.auth.on_auth_state_change(self.__update_token_patch) self.sink_name = sink_name self.persist_events = persist_events self.signing_key = signing_key + ttl = int(os.environ.get("SAAS_SESSION_TOKEN_TTL_SEC", "82800")) # 23 hours + self.token_cache = TTLCache(maxsize=1, ttl=ttl) + self.lock = threading.Lock() def patch_postgrest_execute(self): # This is somewhat hacky. @@ -532,11 +539,12 @@ def publish_helm_releases(self, helm_releases: List[HelmRelease]): logging.error(f"Failed to persist helm_releases {helm_releases} error: {e}") raise - def sign_in(self): + def sign_in(self) -> str: logging.info("Supabase dal login") res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password}) self.client.auth.set_session(res.session.access_token, res.session.refresh_token) self.client.postgrest.auth(res.session.access_token) + return res.user.id def to_db_cluster_status(self, data: ClusterStatus) -> Dict[str, Any]: db_cluster_status = data.dict() @@ -753,3 +761,25 @@ def set_cluster_active(self, active: bool) -> None: ) except Exception as e: logging.error(f"Failed to set cluster status active=False error: {e}") + + def get_session_token(self) -> str: + with self.lock: + session_token = self.token_cache.get("session_token") + if not session_token: + session_token = self.create_session_token() + self.token_cache["session_token"] = session_token + + return session_token + + def create_session_token(self) -> str: + token = str(uuid4()) + self.client.table(SESSION_TOKENS_TABLE).insert( + { + "account_id": self.account_id, + "user_id": self.user_id, + "token": token, + "type": "RUNNER", + }, + returning=ReturnMethod.minimal, + ).execute() + return token diff --git a/src/robusta/integrations/receiver.py b/src/robusta/integrations/receiver.py index acb0ee78d..fdfb31bf6 100644 --- a/src/robusta/integrations/receiver.py +++ b/src/robusta/integrations/receiver.py @@ -76,13 +76,14 @@ class SlackActionsMessage(BaseModel): class ActionRequestReceiver: - def __init__(self, event_handler: PlaybooksEventHandler): + def __init__(self, event_handler: PlaybooksEventHandler, auth_token: str): self.event_handler = event_handler self.active = True self.account_id = self.event_handler.get_global_config().get("account_id") self.cluster_name = self.event_handler.get_global_config().get("cluster_name") self.auth_provider = AuthProvider() self.healthy = False + self.auth_token = auth_token self.ws = websocket.WebSocketApp( WEBSOCKET_RELAY_ADDRESS, @@ -291,6 +292,7 @@ def on_open(self, ws): "account_id": account_id, "cluster_name": cluster_name, "version": RUNNER_VERSION, + "token": self.auth_token, } logging.info(f"connecting to server as account_id={account_id}; cluster_name={cluster_name}") ws.send(json.dumps(open_payload)) diff --git a/src/robusta/model/config.py b/src/robusta/model/config.py index 0dbcf8497..ebcf6a654 100644 --- a/src/robusta/model/config.py +++ b/src/robusta/model/config.py @@ -9,6 +9,7 @@ from robusta.core.pubsub.event_emitter import EventEmitter from robusta.core.pubsub.event_subscriber import EventHandler from robusta.core.pubsub.events_pubsub import EventsPubSub +from robusta.core.sinks.robusta.robusta_sink import RobustaSink from robusta.core.sinks.robusta.robusta_sink_params import RobustaSinkConfigWrapper, RobustaSinkParams from robusta.core.sinks.sink_base import SinkBase from robusta.core.sinks.sink_config import SinkConfigBase @@ -35,6 +36,9 @@ def get_sink_by_name(self, sink_name: str) -> Optional[SinkBase]: def get_all(self) -> Dict[str, SinkBase]: return self.sinks + + def get_robusta_sinks(self) -> List[RobustaSink]: + return [sink for sink in self.sinks.values() if isinstance(sink, RobustaSink)] @classmethod def construct_new_sinks( @@ -159,7 +163,7 @@ class Registry: _playbooks: PlaybooksRegistry = PlaybooksRegistry() _sinks: SinksRegistry = None _scheduler = None - _receiver: ActionRequestReceiver = None + _receiver: Optional[ActionRequestReceiver] = None _global_config = dict() _alert_relabel_config: List[AlertRelabel] = [] _telemetry: Telemetry = Telemetry( @@ -201,7 +205,7 @@ def get_scheduler(self) -> PlaybooksSchedulerManager: def set_receiver(self, receiver: ActionRequestReceiver): self._receiver = receiver - def get_receiver(self) -> ActionRequestReceiver: + def get_receiver(self) -> Optional[ActionRequestReceiver]: return self._receiver def get_telemetry(self) -> Telemetry: diff --git a/src/robusta/runner/config_loader.py b/src/robusta/runner/config_loader.py index ff66ba09c..7c575b7f5 100644 --- a/src/robusta/runner/config_loader.py +++ b/src/robusta/runner/config_loader.py @@ -86,7 +86,7 @@ def __reload_scheduler(self, playbooks_registry: PlaybooksRegistry): def __reload_receiver(self): receiver = self.registry.get_receiver() if not receiver: # no existing receiver, just start one - self.registry.set_receiver(ActionRequestReceiver(self.event_handler)) + self.__create_receiver() return current_account_id = self.event_handler.get_global_config().get("account_id") @@ -95,8 +95,21 @@ def __reload_receiver(self): if current_account_id != receiver.account_id or current_cluster_name != receiver.cluster_name: # need to re-create the receiver receiver.stop() - self.registry.set_receiver(ActionRequestReceiver(self.event_handler)) + self.__create_receiver() + def __create_receiver(self): + robusta_sinks = self.registry.get_sinks().get_robusta_sinks() + if not robusta_sinks: + logging.info("No robusta sinks found, skipping receiver creation") + return + + robusta_sink = robusta_sinks[0] + token = robusta_sink.dal.get_session_token() + + receiver = ActionRequestReceiver(self.event_handler, token) + self.registry.set_receiver(receiver) + return receiver + @staticmethod def __get_package_name_from_pyproject(local_path: str) -> str: with open(os.path.join(local_path, "pyproject.toml"), "r") as pyproj_toml: @@ -235,8 +248,6 @@ def __reload_playbook_packages(self, change_name): # This needs to be set before the robusta sink is created since a cluster status is sent on creation self.registry.set_light_actions(runner_config.light_actions if runner_config.light_actions else []) - self.__reload_receiver() - (sinks_registry, playbooks_registry) = self.__prepare_runtime_config( runner_config, self.registry.get_sinks(), @@ -250,6 +261,8 @@ def __reload_playbook_packages(self, change_name): self.registry.set_actions(action_registry) self.registry.set_playbooks(playbooks_registry) self.registry.set_sinks(sinks_registry) + self.__reload_receiver() + telemetry = self.registry.get_telemetry() telemetry.playbooks_count = len(runner_config.active_playbooks) if runner_config.active_playbooks else 0