diff --git a/src/unipoll_api/actions/workspace.py b/src/unipoll_api/actions/workspace.py index 5e5f31b..98eb07c 100644 --- a/src/unipoll_api/actions/workspace.py +++ b/src/unipoll_api/actions/workspace.py @@ -1,9 +1,10 @@ +import asyncio from bson import DBRef from beanie.odm.bulk import BulkWriter from unipoll_api import AccountManager from . import plugins, group as GroupActions, policy as PolicyActions, poll as PollActions, members as MembersActions from unipoll_api.documents import Workspace, Account, Policy, Member -from unipoll_api.utils import Permissions +from unipoll_api.utils import Permissions, events from unipoll_api.schemas import WorkspaceSchemas from unipoll_api.exceptions import WorkspaceExceptions # from unipoll_api.dependencies import get_member_by_account @@ -89,6 +90,12 @@ async def update_workspace(workspace: Workspace, # Save the updated workspace if save_changes: await Workspace.save(workspace) + + # Log the event + await workspace.log_event(data={"message": "Workspace updated"}) + asyncio.create_task(events.notify_members(workspace, {"message": "Workspace updated"})) + # BackgroundTasks.add_task(events.notify_members, workspace, {"message": "Workspace updated"}) + # Return the updated workspace return WorkspaceSchemas.Workspace(**workspace.model_dump(include={'id', 'name', 'description'})) diff --git a/src/unipoll_api/app.py b/src/unipoll_api/app.py index 125f8cb..f0a5d7f 100644 --- a/src/unipoll_api/app.py +++ b/src/unipoll_api/app.py @@ -1,4 +1,9 @@ -from fastapi import FastAPI +from datetime import datetime +import json +import uvicorn +import os +import argparse +from fastapi import FastAPI, APIRouter from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware from beanie import init_beanie @@ -8,6 +13,9 @@ from unipoll_api.config import get_settings +# Application start time +start_time = datetime.now() + # Apply setting from configuration file settings = get_settings() @@ -42,9 +50,9 @@ async def on_startup() -> None: # Simplify operation IDs so that generated API clients have simpler function names # Each route will have its operation ID set to the method name - for route in app.routes: - if isinstance(route, APIRoute): - route.operation_id = route.name + # for route in app.routes: + # if isinstance(route, APIRoute): + # route.operation_id = route.name await init_beanie( database=mainDB, # type: ignore diff --git a/src/unipoll_api/config.py b/src/unipoll_api/config.py index 60d8599..e3f16c5 100644 --- a/src/unipoll_api/config.py +++ b/src/unipoll_api/config.py @@ -28,6 +28,8 @@ class Settings(BaseSettings): # type: ignore port: int = 9000 reload: bool = True model_config = SettingsConfigDict(env_file=".env") + redis_host: str = Field(default="localhost", title="Redis Host", description="The host of the Redis database.") + redis_port: int = Field(default=6379, title="Redis Port", description="The port of the Redis database.") # plugins: list = ["timer"] plugins: list = ["test_plugin"] diff --git a/src/unipoll_api/dependencies.py b/src/unipoll_api/dependencies.py index 4a7cf94..885ed5e 100644 --- a/src/unipoll_api/dependencies.py +++ b/src/unipoll_api/dependencies.py @@ -59,7 +59,6 @@ async def websocket_auth(session: Annotated[str | None, Cookie()] = None, strategy=Depends(get_database_strategy) ) -> Account: user = None - if token: max_age = datetime.now(timezone.utc) - timedelta(seconds=strategy.lifetime_seconds) token_data = await token_db.get_by_token(token, max_age) diff --git a/src/unipoll_api/documents.py b/src/unipoll_api/documents.py index 72e8c68..6551ded 100644 --- a/src/unipoll_api/documents.py +++ b/src/unipoll_api/documents.py @@ -1,4 +1,5 @@ # from typing import ForwardRef, NewType, TypeAlias, Optional +from datetime import datetime from typing import Literal from bson import DBRef from beanie import Document as BeanieDocument @@ -9,7 +10,9 @@ Insert, Link, PydanticObjectId, -) # BackLink + TimeSeriesConfig, + Granularity, +) from fastapi_users_db_beanie import BeanieBaseUser from pydantic import Field from unipoll_api.utils import colored_dbg as Debug @@ -90,6 +93,10 @@ async def remove_policy_by_holder( if save: await self.save(link_rule=WriteRules.WRITE) # type: ignore + async def log_event(self, data: dict) -> "Event": + new_event = await Event(resource_id=str(self.id), data=data).create() # type: ignore + return new_event + class Account(BeanieBaseUser, Document): # type: ignore id: ResourceID = Field(default_factory=ResourceID, alias="_id") @@ -233,3 +240,21 @@ class Member(Document): workspace: BackLink[Workspace] = Field(original_field="members") # type: ignore groups: list[BackLink[Group]] = Field(original_field="members") # type: ignore policies: list[Link[Policy]] = [] + + +# https://docs.mongodb.com/manual/core/timeseries-collections +class Event(Document): + ts: datetime = Field(default_factory=datetime.now) + # resource: BackLink[Resource] = Field(original_field="event_log") + resource_id: str = Field(default_factory=str) + data: dict + + # @after_event(Insert) + # def + + class Settings: + timeseries = TimeSeriesConfig( + time_field="ts", + meta_field="resource_id", + # expire_after_seconds=60 * 60 * 24 # 24 hours + ) diff --git a/src/unipoll_api/mongo_db.py b/src/unipoll_api/mongo_db.py index 61b280c..201d100 100644 --- a/src/unipoll_api/mongo_db.py +++ b/src/unipoll_api/mongo_db.py @@ -20,5 +20,6 @@ Documents.Group, Documents.Workspace, Documents.Policy, - Documents.Poll + Documents.Poll, + Documents.Event ] diff --git a/src/unipoll_api/redis.py b/src/unipoll_api/redis.py new file mode 100644 index 0000000..6e96400 --- /dev/null +++ b/src/unipoll_api/redis.py @@ -0,0 +1,47 @@ +# import asyncio +import json +import redis.exceptions +import redis.asyncio +from redis.asyncio.client import Redis + +from unipoll_api.config import get_settings + +settings = get_settings() + + +PUSH_NOTIFICATIONS_CHANNEL = "PUSH_NOTIFICATIONS_CHANNEL" + + +connection: Redis = redis.asyncio.from_url( + f"redis://{settings.redis_host}:{settings.redis_port}", + encoding="utf8", + decode_responses=True, +) + + +async def publish_message(data: dict): + try: + await connection.publish(PUSH_NOTIFICATIONS_CHANNEL, json.dumps(data)) + except redis.exceptions.ConnectionError as e: + print("Connection error:", e) + except Exception as e: + print("An unexpected error occurred:", e) + + +async def listen_to_channel(user_id: str): + # Create message listener and subscribe on the event source channel + try: + async with connection.pubsub() as listener: + await listener.subscribe(PUSH_NOTIFICATIONS_CHANNEL) + # Create a message generator + while True: + message = await listener.get_message() + if message is None: + continue + if message.get("type") == "message": + message = json.loads(message["data"]) + # Checking, if the user is recipient of the message + if user_id == message.get("recipient_id"): + yield {"data": json.dumps(message)} + except redis.exceptions.ConnectionError as e: + print("Connection error:", e) diff --git a/src/unipoll_api/routes/streams.py b/src/unipoll_api/routes/streams.py index 1025ce9..73e4433 100644 --- a/src/unipoll_api/routes/streams.py +++ b/src/unipoll_api/routes/streams.py @@ -1,13 +1,79 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Depends, HTTPException +from fastapi import Body from sse_starlette.sse import EventSourceResponse -from unipoll_api.utils.streams import update_generator - +from datetime import datetime +from unipoll_api.redis import listen_to_channel, publish_message +from unipoll_api.documents import Account, ResourceID, Workspace # Event +from unipoll_api.utils.events import get_updates, get_event_stream +from unipoll_api.dependencies import get_current_active_user +from unipoll_api.exceptions import ResourceExceptions router = APIRouter() -@router.get("/updates", - response_class=EventSourceResponse) -async def event_log(): - updates = update_generator() - return EventSourceResponse(updates) +# For testing purposes only +# Endpoint to get all events for a resource +# Accepts a query parameter "since" to get all events after a certain time +@router.get("/updates/{resource_id}") +async def event_log(resource_id: ResourceID, + since: str = ""): + try: + if since == "": + from unipoll_api.app import start_time + time = start_time + else: + time = datetime.fromisoformat(since) + return await get_updates(resource_id, time) + except Exception as e: + print(e) + return HTTPException(status_code=404, detail="Resource not found") + + +# For testing purposes only +# Endpoint to log an event for a resource(workspace) +@router.post("/workspace/{workspace_id}/log") +async def generate_event(workspace_id: ResourceID, + event: dict = Body(...)): + try: + workspace = await Workspace.get(workspace_id) + if not workspace: + raise ResourceExceptions.ResourceNotFound("Workspace", workspace_id) + new_event = await workspace.log_event(data={"message": event}) + return new_event + except Exception as e: + print(e) + + +# For testing purposes only +# Endpoint to get new events, that occur after this request +@router.get("/resource/{resource_id}/subscribe") +async def mongodb_subscribe(resource_id: ResourceID): + try: + return EventSourceResponse(get_event_stream(resource_id)) + except Exception as e: + print(e) + + +# Endpoint to push notifications to a user +@router.post("/redis/push") +async def redis_push(user: Account = Depends(get_current_active_user), + message: dict = Body(...)): + try: + data = { + "recipient_id": str(user.id), + "timestamp": datetime.now().isoformat(), + "message": message + } + await publish_message(data) + except Exception as e: + print(e) + + +# Endpoint to user notifications +# @router.get("/redis/subscribe") +@router.get("/subscribe") +async def redis_subscribe(user: Account = Depends(get_current_active_user)): + try: + return EventSourceResponse(listen_to_channel(str(user.id))) + except Exception as e: + print(e) diff --git a/src/unipoll_api/utils/events.py b/src/unipoll_api/utils/events.py new file mode 100644 index 0000000..aebcf68 --- /dev/null +++ b/src/unipoll_api/utils/events.py @@ -0,0 +1,36 @@ +import asyncio +from datetime import datetime +from unipoll_api.documents import ResourceID, Event, Resource +from unipoll_api.redis import publish_message +from . import colored_dbg as Debug + + +async def get_updates(resource_id: ResourceID, since: datetime): + events = await Event.find(Event.resource_id == str(resource_id), Event.ts > since).to_list() + return events + + +async def get_event_stream(resource_id: ResourceID): + try: + time = datetime.now() + while True: + events = await get_updates(resource_id, time) + if events: + time = events[-1].ts + yield events + except asyncio.CancelledError as e: + Debug.info("Disconnected from client (via refresh/close)") + raise e + + +async def notify_members(resource: Resource, message: dict): + try: + timestamp = datetime.now() + for member in resource.members: + # print(member) + data = {"recipient_id": str(member.account.id), + "timestamp": str(timestamp), + "message": message} + await publish_message(data) + except Exception as e: + print(e)