diff --git a/changelog.d/19794.feature b/changelog.d/19794.feature new file mode 100644 index 00000000000..d6504d3882c --- /dev/null +++ b/changelog.d/19794.feature @@ -0,0 +1 @@ +[MSC4140: Cancellable delayed events](https://github.com/matrix-org/matrix-spec-proposals/pull/4140): Allow authentication on delayed event management endpoints (such as `/restart`) to bypass ratelimits based on the client IP address. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index d028d65fe33..4338aa9d82f 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1971,7 +1971,7 @@ rc_presence: *(object)* Ratelimiting settings for delayed event management. -This is a ratelimiting option that ratelimits attempts to restart, cancel, or view delayed events based on the sending client's account and device ID. +This is a ratelimiting option that ratelimits attempts to restart, cancel, or view delayed events based on the sending client's account, or its source IP when unauthenticated. Attempts to create or send delayed events are ratelimited not by this setting, but by `rc_message`. diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index 8b8d57b9bf1..9e88e449842 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -2244,8 +2244,8 @@ properties: This is a ratelimiting option that ratelimits attempts to restart, cancel, - or view delayed events based on the sending client's account and device - ID. + or view delayed events based on the sending client's account, + or its source IP when unauthenticated. Attempts to create or send delayed events are ratelimited not by this diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index 4a9f646d4db..9880601c9ff 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -61,13 +61,13 @@ def __init__(self, hs: "HomeServer"): self._storage_controllers = hs.get_storage_controllers() self._config = hs.config self._clock = hs.get_clock() + self._auth = hs.get_auth() self._event_creation_handler = hs.get_event_creation_handler() self._room_member_handler = hs.get_room_member_handler() self._request_ratelimiter = hs.get_request_ratelimiter() - # Ratelimiter for management of existing delayed events, - # keyed by the sending user ID & device ID. + # Ratelimiters for management of existing delayed events self._delayed_event_mgmt_ratelimiter = Ratelimiter( store=self._store, clock=self._clock, @@ -413,9 +413,7 @@ async def cancel(self, request: SynapseRequest, delay_id: str) -> None: NotFoundError: if no matching delayed event could be found. """ assert self._is_master - await self._delayed_event_mgmt_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._mgmt_ratelimit(request) await make_deferred_yieldable(self._initialized_from_db) next_send_ts = await self._store.cancel_delayed_event(delay_id) @@ -430,9 +428,7 @@ async def restart(self, request: SynapseRequest, delay_id: str) -> None: Raises: NotFoundError: if no matching delayed event could be found. """ - await self._delayed_event_mgmt_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._mgmt_ratelimit(request) # Note: We don't need to wait on `self._initialized_from_db` here as the # events that deals with are already marked as processed. @@ -456,9 +452,7 @@ async def send(self, request: SynapseRequest, delay_id: str) -> None: NotFoundError: if no matching delayed event could be found. """ assert self._is_master - await self._delayed_event_mgmt_ratelimiter.ratelimit( - None, request.getClientAddress().host - ) + await self._mgmt_ratelimit(request) await make_deferred_yieldable(self._initialized_from_db) event, next_send_ts = await self._store.process_target_delayed_event(delay_id) @@ -468,6 +462,15 @@ async def send(self, request: SynapseRequest, delay_id: str) -> None: await self._send_event(event) + async def _mgmt_ratelimit(self, request: SynapseRequest) -> None: + if self._auth.has_access_token(request): + requester = await self._auth.get_user_by_req(request) + key = None + else: + requester = None + key = request.getClientAddress().host + await self._delayed_event_mgmt_ratelimiter.ratelimit(requester, key) + async def _send_on_timeout(self) -> None: self._next_delayed_event_call = None @@ -527,10 +530,7 @@ def _schedule_next_at(self, next_send_ts: Timestamp) -> None: async def get_all_for_user(self, requester: Requester) -> list[JsonDict]: """Return all pending delayed events requested by the given user.""" - await self._delayed_event_mgmt_ratelimiter.ratelimit( - requester, - (requester.user.to_string(), requester.device_id), - ) + await self._delayed_event_mgmt_ratelimiter.ratelimit(requester) return await self._store.get_all_delayed_events_for_user( requester.user.localpart ) diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index da904ce1f51..c3bfdf7c8d4 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -317,7 +317,7 @@ def test_cancel_delayed_state_event(self, action_in_path: bool) -> None: ) def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None: delay_ids = [] - for _ in range(2): + for _ in range(3): channel = self.make_request( "POST", _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), @@ -329,12 +329,39 @@ def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None: assert delay_id is not None delay_ids.append(delay_id) - channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path) + delay_id = delay_ids.pop(0) + channel = self._update_delayed_event(delay_id, "cancel", action_in_path) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path) + delay_id = delay_ids.pop(0) + channel = self._update_delayed_event(delay_id, "cancel", action_in_path) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Using auth should bypass ratelimit applied against source IP + channel = self._update_delayed_event( + delay_id, "cancel", action_in_path, self.user1_access_token + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + delay_id = delay_ids.pop(0) + channel = self._update_delayed_event( + delay_id, "cancel", action_in_path, self.user1_access_token + ) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) + ) + + # Test that the request isn't ratelimited anymore. + channel = self._update_delayed_event( + delay_id, "cancel", action_in_path, self.user1_access_token + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + @parameterized.expand( ( (content_property_value, action_in_path) @@ -475,7 +502,7 @@ def test_restart_delayed_state_event(self, action_in_path: bool) -> None: ) def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None: delay_ids = [] - for _ in range(2): + for _ in range(3): channel = self.make_request( "POST", _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), @@ -487,16 +514,39 @@ def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None: assert delay_id is not None delay_ids.append(delay_id) + delay_id = delay_ids.pop(0) + channel = self._update_delayed_event(delay_id, "restart", action_in_path) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + delay_id = delay_ids.pop(0) + channel = self._update_delayed_event(delay_id, "restart", action_in_path) + self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + + # Using auth should bypass ratelimit applied against source IP channel = self._update_delayed_event( - delay_ids.pop(0), "restart", action_in_path + delay_id, "restart", action_in_path, self.user1_access_token ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + delay_id = delay_ids.pop(0) channel = self._update_delayed_event( - delay_ids.pop(0), "restart", action_in_path + delay_id, "restart", action_in_path, self.user1_access_token ) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastores().main.set_ratelimit_for_user( + self.user1_user_id, 0, 0 + ) + ) + + # Test that the request isn't ratelimited anymore. + channel = self._update_delayed_event( + delay_id, "restart", action_in_path, self.user1_access_token + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + def test_delayed_state_is_not_cancelled_by_new_state_from_same_user( self, ) -> None: @@ -612,7 +662,11 @@ def _get_delayed_event_content(self, event: JsonDict) -> JsonDict: return content def _update_delayed_event( - self, delay_id: str, action: str, action_in_path: bool + self, + delay_id: str, + action: str, + action_in_path: bool, + access_token: str | None = None, ) -> FakeChannel: path = f"{PATH_PREFIX}/{delay_id}" body = {} @@ -620,7 +674,7 @@ def _update_delayed_event( path += f"/{action}" else: body["action"] = action - return self.make_request("POST", path, body) + return self.make_request("POST", path, body, access_token) def _find_sent_delayed_event( self, access_token: str, delay_id: str, should_find: bool