diff --git a/queueless_backend/queue_tracker/tests.py b/queueless_backend/queue_tracker/tests.py index 9043c39..0aa64cd 100644 --- a/queueless_backend/queue_tracker/tests.py +++ b/queueless_backend/queue_tracker/tests.py @@ -1,7 +1,8 @@ from datetime import timedelta +from unittest.mock import patch from django.core.cache import cache -from django.test import SimpleTestCase, TestCase +from django.test import SimpleTestCase, TestCase, override_settings from django.utils import timezone from rest_framework import status from rest_framework.test import APIClient @@ -380,3 +381,66 @@ def test_rejoin_after_cancel(self): ) self.assertEqual(response_join_success.status_code, status.HTTP_201_CREATED) self.assertEqual(response_join_success.data["queue_number"], 20) + + +@override_settings( + CACHES={ + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "throttling-test-cache", + } + }, + REST_FRAMEWORK={ + "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.ScopedRateThrottle"], + "DEFAULT_THROTTLE_RATES": {"join": "1/minute", "burst": "1/minute"}, + }, +) +class QueueThrottlingTests(TestCase): + def setUp(self): + self.client = APIClient() + self.institution = Institution.objects.create( + name="Throttled Office", + institution_type=Institution.InstitutionType.GOVERNMENT, + status=Institution.Status.OPEN, + is_active=True, + ) + cache.clear() + + def test_join_queue_throttling(self): + # We patch ScopedRateThrottle.THROTTLE_RATES because DRF loads it once + with patch( + "rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES", + {"join": "1/minute", "burst": "1/minute"}, + ): + # Join limit is pinned to 1/minute + response = self.client.post( + "/api/queue/join/", + {"institution_id": self.institution.id, "queue_number": 101}, + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + # 2nd request should be throttled + response = self.client.post( + "/api/queue/join/", + {"institution_id": self.institution.id, "queue_number": 102}, + ) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_status_polling_throttling(self): + with patch( + "rest_framework.throttling.ScopedRateThrottle.THROTTLE_RATES", + {"join": "1/minute", "burst": "1/minute"}, + ): + entry = QueueEntry.objects.create( + institution=self.institution, + queue_number=10, + current_serving_number=5, + status=QueueEntryStatus.WAITING, + ) + # Burst limit is pinned to 1/minute + response = self.client.get(f"/api/queue/entries/{entry.session_id}/status/") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # 2nd request should be throttled + response = self.client.get(f"/api/queue/entries/{entry.session_id}/status/") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/queueless_backend/queue_tracker/views.py b/queueless_backend/queue_tracker/views.py index 7bf640b..841086a 100644 --- a/queueless_backend/queue_tracker/views.py +++ b/queueless_backend/queue_tracker/views.py @@ -26,6 +26,7 @@ class QueueJoinView(APIView): permission_classes = [permissions.AllowAny] + throttle_scope = "join" def post(self, request): serializer = QueueJoinSerializer(data=request.data) @@ -129,6 +130,7 @@ def post(self, request): class QueueEntryStatusView(APIView): permission_classes = [permissions.AllowAny] + throttle_scope = "burst" def get(self, request, session_id): try: @@ -162,6 +164,7 @@ def get(self, request, session_id): class QueueEntryCheckInView(APIView): permission_classes = [permissions.AllowAny] + throttle_scope = "burst" def patch(self, request, session_id): entry, error = check_in_serving_entry(session_id) @@ -182,6 +185,7 @@ def patch(self, request, session_id): class QueueEntryCancelView(APIView): permission_classes = [permissions.AllowAny] + throttle_scope = "burst" def post(self, request, session_id): entry, error = cancel_queue_entry(session_id) diff --git a/queueless_backend/queueless_backend/settings.py b/queueless_backend/queueless_backend/settings.py index 13f41fa..c7c32d3 100644 --- a/queueless_backend/queueless_backend/settings.py +++ b/queueless_backend/queueless_backend/settings.py @@ -179,6 +179,28 @@ def env_bool(name: str, default: bool = False) -> bool: }, } DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" +try: + _drf_num_proxies = os.getenv("DRF_NUM_PROXIES") + DRF_NUM_PROXIES = int(_drf_num_proxies) if _drf_num_proxies is not None else None + if DRF_NUM_PROXIES is not None and DRF_NUM_PROXIES < 0: + raise ImproperlyConfigured("DRF_NUM_PROXIES must be a non-negative integer.") +except ValueError as exc: + raise ImproperlyConfigured("DRF_NUM_PROXIES must be an integer.") from exc + +REST_FRAMEWORK = { + "DEFAULT_THROTTLE_CLASSES": [ + "rest_framework.throttling.AnonRateThrottle", + "rest_framework.throttling.UserRateThrottle", + "rest_framework.throttling.ScopedRateThrottle", + ], + "DEFAULT_THROTTLE_RATES": { + "anon": os.getenv("DRF_THROTTLE_RATE_ANON", "20000/day"), + "user": os.getenv("DRF_THROTTLE_RATE_USER", "100000/day"), + "burst": os.getenv("DRF_THROTTLE_RATE_BURST", "60/minute"), + "join": os.getenv("DRF_THROTTLE_RATE_JOIN", "5/minute"), + }, + "NUM_PROXIES": DRF_NUM_PROXIES, +} # CORS configuration from environment variables. CORS_ALLOW_ALL_ORIGINS = env_bool("CORS_ALLOW_ALL_ORIGINS", default=False)