From 9a9b1066011e39f73cbfe963aede90feb9bf61ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damian=20Sowi=C5=84ski?= Date: Wed, 4 Feb 2026 14:40:00 +0100 Subject: [PATCH] feat: Add django-guardian --- .../management/commands/ensuresuperuser.py | 3 +- server/account/models.py | 42 ++- server/account/pipelines.py | 1 + server/account/serializers.py | 4 +- server/account/services/__init__.py | 3 +- server/account/services/user.py | 109 ++++++++ .../tests/test_views/test_service_accounts.py | 245 ++++++++++++++++++ server/account/tests/test_views/test_users.py | 205 +++++++++++++++ server/account/views/service_accounts.py | 24 +- server/account/views/users.py | 37 ++- .../agent/conversation/tests/test_manager.py | 3 +- server/agent/tests/test_views.py | 5 +- server/agent/views.py | 16 +- server/catalog/models/data_set.py | 3 + .../test_views/test_data_set_detail_view.py | 94 +++++++ ...ta_set_document_source_view_permissions.py | 136 ++++++++++ .../test_data_set_list_view_permissions.py | 104 ++++++++ ...ata_set_product_source_view_permissions.py | 133 ++++++++++ .../test_views/test_data_set_user_view.py | 143 ++++++++++ server/catalog/utils.py | 160 +++++++++++- server/catalog/views.py | 130 ++++++---- server/conftest.py | 106 +++++++- server/pecl/settings.py | 6 +- server/pyproject.dev.toml | 3 +- server/pyproject.toml | 2 + server/sample.env | 5 + 26 files changed, 1619 insertions(+), 103 deletions(-) create mode 100644 server/account/services/user.py create mode 100644 server/account/tests/test_views/test_service_accounts.py create mode 100644 server/account/tests/test_views/test_users.py create mode 100644 server/catalog/tests/test_views/test_data_set_detail_view.py create mode 100644 server/catalog/tests/test_views/test_data_set_document_source_view_permissions.py create mode 100644 server/catalog/tests/test_views/test_data_set_list_view_permissions.py create mode 100644 server/catalog/tests/test_views/test_data_set_product_source_view_permissions.py create mode 100644 server/catalog/tests/test_views/test_data_set_user_view.py diff --git a/server/account/management/commands/ensuresuperuser.py b/server/account/management/commands/ensuresuperuser.py index d3d1e586..08ec989f 100644 --- a/server/account/management/commands/ensuresuperuser.py +++ b/server/account/management/commands/ensuresuperuser.py @@ -1,6 +1,7 @@ from django.core.management.base import BaseCommand from account.models import User +from account.services import UserService class Command(BaseCommand): @@ -14,4 +15,4 @@ def add_arguments(self, parser): def handle(self, *args, **options): if not User.objects.exists(): - User.objects.create_superuser(email=options["email"], password=options["password"]) + UserService.create_superuser(email=options["email"], password=options["password"]) diff --git a/server/account/models.py b/server/account/models.py index e6c69235..bcc152b6 100644 --- a/server/account/models.py +++ b/server/account/models.py @@ -4,32 +4,22 @@ class UserManager(BaseUserManager): - def _create_user(self, email, password=None, **extra_fields): - if not email: - raise ValueError("The email field must be set") - - email = self.normalize_email(email) - user = self.model(email=email, **extra_fields) - if extra_fields.get("is_service_account", False): - user.set_unusable_password() - else: - user.set_password(password) - user.save(using=self._db) - return user - - def create_user(self, username=None, email=None, password=None, **extra_fields): - extra_fields.setdefault("is_staff", False) - extra_fields.setdefault("is_superuser", False) - return self._create_user(email, password, **extra_fields) - - def create_superuser(self, username=None, email=None, password=None, **extra_fields): - extra_fields.setdefault("is_staff", True) - extra_fields.setdefault("is_superuser", True) - return self._create_user(email, password, **extra_fields) - - def create_service_account(self, email, **extra_fields): - extra_fields.setdefault("is_service_account", True) - return self._create_user(email, password=None, **extra_fields) + def create_user(self, username=None, email=None, password=None, role=None, **extra_fields): + from account.services.user import UserService + + return UserService.create_user(username=username, email=email, password=password, role=role, **extra_fields) + + def create_superuser(self, username=None, email=None, password=None, role=None, **extra_fields): + from account.services.user import UserService + + return UserService.create_superuser( + username=username, email=email, password=password, role=role, **extra_fields + ) + + def create_service_account(self, email, role=None, **extra_fields): + from account.services.user import UserService + + return UserService.create_service_account(email=email, role=role, **extra_fields) class User(AbstractUser): diff --git a/server/account/pipelines.py b/server/account/pipelines.py index cd955482..253dc280 100644 --- a/server/account/pipelines.py +++ b/server/account/pipelines.py @@ -2,6 +2,7 @@ from django.conf import settings from social_core.exceptions import AuthForbidden + from utils.functions import import_from_string logger = logging.getLogger(__name__) diff --git a/server/account/serializers.py b/server/account/serializers.py index 4d0f0d37..42366824 100644 --- a/server/account/serializers.py +++ b/server/account/serializers.py @@ -2,7 +2,7 @@ from account.models import User -from .services import ServiceAccountNameService +from .services import ServiceAccountNameService, UserService class AccountSerializer(serializers.Serializer): @@ -40,7 +40,7 @@ def create(self, validated_data): is_active = validated_data.get("is_active") is_staff = validated_data.get("is_staff") dataset_ids = validated_data.get("data_set_ids", []) - service_account = User.objects.create_service_account(email=email, is_active=is_active, is_staff=is_staff) + service_account = UserService.create_service_account(email=email, is_active=is_active, is_staff=is_staff) service_account.data_sets.add(*dataset_ids) return service_account diff --git a/server/account/services/__init__.py b/server/account/services/__init__.py index 275703f0..20113312 100644 --- a/server/account/services/__init__.py +++ b/server/account/services/__init__.py @@ -1,4 +1,5 @@ from .service_accounts import ServiceAccountNameService from .sso_provider import SSOProviderService +from .user import UserService -__all__ = ["ServiceAccountNameService", "SSOProviderService"] +__all__ = ["ServiceAccountNameService", "SSOProviderService", "UserService"] diff --git a/server/account/services/user.py b/server/account/services/user.py new file mode 100644 index 00000000..3ba3a190 --- /dev/null +++ b/server/account/services/user.py @@ -0,0 +1,109 @@ +import logging +from typing import Optional, Type + +from django.contrib.auth import get_user_model +from django.db import models +from guardian.shortcuts import assign_perm, remove_perm + +from catalog.utils import AdminRole, BaseRole, UserRole + +User = get_user_model() +logger = logging.getLogger(__name__) + + +class UserService: + @staticmethod + def _create_user_base(email, password=None, **extra_fields): + if not email: + raise ValueError("The email field must be set") + + from django.contrib.auth.base_user import BaseUserManager + + manager = BaseUserManager() + email = manager.normalize_email(email) + user = User(email=email, **extra_fields) + + if extra_fields.get("is_service_account", False): + user.set_unusable_password() + else: + if password: + user.set_password(password) + + user.save() + return user + + @staticmethod + def create_user( + username=None, + email=None, + password=None, + role: Type[BaseRole] = UserRole, + is_staff: bool = False, + is_superuser: bool = False, + **extra_fields, + ) -> User: + extra_fields.setdefault("is_staff", is_staff) + extra_fields.setdefault("is_superuser", is_superuser) + + user = UserService._create_user_base(email, password, **extra_fields) + UserService.assign_role(user, role) + + return user + + @staticmethod + def create_superuser( + username=None, email=None, password=None, role: Type[BaseRole] = AdminRole, **extra_fields + ) -> User: + extra_fields.setdefault("is_staff", True) + extra_fields.setdefault("is_superuser", True) + + user = UserService._create_user_base(email, password, **extra_fields) + + UserService.assign_role(user, role) + + return user + + @staticmethod + def create_service_account(email, role: Optional[Type[BaseRole]] = None, **extra_fields) -> User: + extra_fields.setdefault("is_service_account", True) + + user = UserService._create_user_base(email, password=None, **extra_fields) + + if role is None: + if not extra_fields.get("is_staff", False): + role = AdminRole + else: + role = UserRole + + if role is not None: + UserService.assign_role(user, role) + + return user + + @staticmethod + def assign_role(user: User, role: Type[BaseRole], obj: Optional[models.Model] = None) -> None: + try: + role.assign(user, obj) + except Exception as e: + logger.error(f"Failed to assign {role.__name__} role to user {user.id}: {str(e)}", exc_info=True) + raise + + @staticmethod + def assign_permission(user: User, permission: str, obj: Optional[models.Model] = None) -> None: + try: + assign_perm(permission, user, obj) + except Exception as e: + logger.error(f"Failed to assign permission '{permission}' to user {user.id}: {str(e)}", exc_info=True) + raise + + @staticmethod + def remove_permission(user: User, permission: str, obj: Optional[models.Model] = None) -> None: + try: + remove_perm(permission, user, obj) + except Exception as e: + logger.error(f"Failed to remove permission '{permission}' from user {user.id}: {str(e)}", exc_info=True) + raise + + @staticmethod + def assign_role_to_object(user: User, role: Type[BaseRole], obj: models.Model) -> None: + UserService.assign_role(user, role, obj) diff --git a/server/account/tests/test_views/test_service_accounts.py b/server/account/tests/test_views/test_service_accounts.py new file mode 100644 index 00000000..24294f67 --- /dev/null +++ b/server/account/tests/test_views/test_service_accounts.py @@ -0,0 +1,245 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status +from rest_framework.authtoken.models import Token + +from account.models import User + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(): + return reverse("service_account_list") + + +@pytest.fixture +def detail_url(service_account): + return f"/api/service_accounts/{service_account.id}" + + +@pytest.fixture +def reset_token_url(service_account): + return reverse("reset_token", kwargs={"id": service_account.id}) + + +@pytest.fixture +def check_name_url(): + return reverse("check_service_name") + + +@pytest.fixture +def service_account(): + return baker.make(User, is_service_account=True) + + +@pytest.fixture +def payload(): + return { + "name": "test-service", + "is_active": True, + "is_staff": False, + } + + +@pytest.fixture +def update_payload(): + return { + "name": "updated-service", + "is_active": False, + "is_staff": True, + } + + +@pytest.fixture +def check_name_payload(): + return {"name": "test-service-name"} + + +class TestServiceAccountListViewGet: + def test_user_with_view_permission_can_list_service_accounts(self, api_client_with_user_view_permission, url): + baker.make(User, is_service_account=True, _quantity=2) + + response = api_client_with_user_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) >= 2 + + def test_user_without_permission_cannot_list_service_accounts(self, api_client, url): + baker.make(User, is_service_account=True, _quantity=2) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_only_service_accounts_are_listed(self, api_client_with_user_view_permission, url): + baker.make(User, is_service_account=True, _quantity=2) + baker.make(User, is_service_account=False, _quantity=2) + + response = api_client_with_user_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert all(user.get("is_service_account", True) for user in response.data["results"]) + + +class TestServiceAccountListViewPost: + def test_user_with_add_permission_can_create_service_account( + self, api_client_with_user_add_permission, url, payload + ): + response = api_client_with_user_add_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert "token" in response.data + assert User.objects.filter(email__contains="test-service").exists() + + def test_user_without_permission_cannot_create_service_account(self, api_client, url, payload): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not User.objects.filter(email__contains="test-service").exists() + + def test_user_with_only_view_permission_cannot_create_service_account( + self, api_client_with_user_view_permission, url, payload + ): + response = api_client_with_user_view_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not User.objects.filter(email__contains="test-service").exists() + + def test_admin_with_global_permission_can_create_service_account(self, admin_api_client, url, payload): + response = admin_api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert "token" in response.data + assert User.objects.filter(email__contains="test-service").exists() + + +class TestServiceAccountViewPatch: + def test_user_with_change_permission_can_update_service_account( + self, api_client_with_user_change_permission, detail_url, update_payload, service_account + ): + response = api_client_with_user_change_permission.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + service_account.refresh_from_db() + assert service_account.email == f"updated-service@{service_account.email.split('@')[1]}" + assert service_account.is_active is False + assert service_account.is_staff is True + + def test_user_without_permission_cannot_update_service_account( + self, api_client, detail_url, update_payload, service_account + ): + original_email = service_account.email + + response = api_client.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + service_account.refresh_from_db() + assert service_account.email == original_email + + def test_user_with_only_view_permission_cannot_update_service_account( + self, api_client_with_user_view_permission, detail_url, update_payload, service_account + ): + original_email = service_account.email + + response = api_client_with_user_view_permission.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + service_account.refresh_from_db() + assert service_account.email == original_email + + def test_user_cannot_update_regular_user_as_service_account(self, api_client_with_user_change_permission): + regular_user = baker.make(User, is_service_account=False) + url = f"/api/service_accounts/{regular_user.id}" + payload = {"name": "test-service"} + + response = api_client_with_user_change_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_admin_with_global_permission_can_update_service_account( + self, admin_api_client, detail_url, update_payload, service_account + ): + response = admin_api_client.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + + +class TestResetTokenViewPost: + def test_user_with_change_permission_can_reset_token( + self, api_client_with_user_change_permission, reset_token_url, service_account + ): + old_token, _ = Token.objects.get_or_create(user=service_account) + old_token_key = old_token.key + + response = api_client_with_user_change_permission.post(reset_token_url) + + assert response.status_code == status.HTTP_200_OK + assert "token" in response.data + assert response.data["token"] != old_token_key + new_token = Token.objects.get(user=service_account) + assert new_token.key == response.data["token"] + + def test_user_without_permission_cannot_reset_token(self, api_client, reset_token_url, service_account): + old_token, _ = Token.objects.get_or_create(user=service_account) + old_token_key = old_token.key + + response = api_client.post(reset_token_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + token = Token.objects.get(user=service_account) + assert token.key == old_token_key + + def test_user_with_only_view_permission_cannot_reset_token( + self, api_client_with_user_view_permission, reset_token_url, service_account + ): + old_token, _ = Token.objects.get_or_create(user=service_account) + old_token_key = old_token.key + + response = api_client_with_user_view_permission.post(reset_token_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + token = Token.objects.get(user=service_account) + assert token.key == old_token_key + + def test_user_cannot_reset_token_for_regular_user(self, api_client_with_user_change_permission): + regular_user = baker.make(User, is_service_account=False) + url = reverse("reset_token", kwargs={"id": regular_user.id}) + + response = api_client_with_user_change_permission.post(url) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_admin_with_global_permission_can_reset_token(self, admin_api_client, reset_token_url, service_account): + old_token, _ = Token.objects.get_or_create(user=service_account) + old_token_key = old_token.key + + response = admin_api_client.post(reset_token_url) + + assert response.status_code == status.HTTP_200_OK + assert "token" in response.data + assert response.data["token"] != old_token_key + + +class TestCheckServiceNameViewPost: + def test_user_with_view_permission_can_check_service_name( + self, api_client_with_user_view_permission, check_name_url, check_name_payload + ): + response = api_client_with_user_view_permission.post(check_name_url, check_name_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + assert "is_available" in response.data + + def test_user_without_permission_cannot_check_service_name(self, api_client, check_name_url, check_name_payload): + response = api_client.post(check_name_url, check_name_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_admin_with_global_permission_can_check_service_name( + self, admin_api_client, check_name_url, check_name_payload + ): + response = admin_api_client.post(check_name_url, check_name_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + assert "is_available" in response.data diff --git a/server/account/tests/test_views/test_users.py b/server/account/tests/test_views/test_users.py new file mode 100644 index 00000000..98bc6f16 --- /dev/null +++ b/server/account/tests/test_views/test_users.py @@ -0,0 +1,205 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from account.models import User + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(): + return reverse("user_list") + + +@pytest.fixture +def detail_url(user): + return reverse("user_details", kwargs={"id": user.id}) + + +@pytest.fixture +def password_url(user): + return reverse("user_password", kwargs={"id": user.id}) + + +@pytest.fixture +def payload(): + return { + "email": "newuser@example.com", + "password": "testpass123", + "is_active": True, + "is_staff": False, + } + + +@pytest.fixture +def update_payload(): + return { + "email": "updated@example.com", + "is_active": False, + "is_staff": True, + } + + +@pytest.fixture +def password_payload(): + return {"password": "newpassword123"} + + +class TestUserListViewGet: + def test_user_with_view_permission_can_list_users(self, api_client_with_user_view_permission, url): + baker.make(User, is_service_account=False, _quantity=2) + + response = api_client_with_user_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) >= 2 + + def test_user_without_permission_cannot_list_users(self, api_client, url): + baker.make(User, is_service_account=False, _quantity=2) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_service_accounts_are_excluded_from_list(self, api_client_with_user_view_permission, url): + baker.make(User, is_service_account=False, _quantity=2) + baker.make(User, is_service_account=True, _quantity=2) + + response = api_client_with_user_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert all(not user.get("is_service_account", False) for user in response.data["results"]) + + +class TestUserListViewPost: + def test_user_with_add_permission_can_create_user(self, api_client_with_user_add_permission, url, payload): + response = api_client_with_user_add_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert User.objects.filter(email="newuser@example.com").exists() + + def test_user_without_permission_cannot_create_user(self, api_client, url, payload): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not User.objects.filter(email="newuser@example.com").exists() + + def test_user_with_only_view_permission_cannot_create_user( + self, api_client_with_user_view_permission, url, payload + ): + response = api_client_with_user_view_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not User.objects.filter(email="newuser@example.com").exists() + + def test_admin_with_global_permission_can_create_user(self, admin_api_client, url, payload): + response = admin_api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert User.objects.filter(email="newuser@example.com").exists() + + +class TestUserViewPatch: + def test_user_with_change_permission_can_update_user( + self, api_client_with_user_change_permission, detail_url, update_payload, user + ): + response = api_client_with_user_change_permission.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.email == "updated@example.com" + assert user.is_active is False + assert user.is_staff is True + + def test_user_without_permission_cannot_update_user(self, api_client, detail_url, update_payload, user): + original_email = user.email + + response = api_client.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + user.refresh_from_db() + assert user.email == original_email + + def test_user_with_only_view_permission_cannot_update_user( + self, api_client_with_user_view_permission, detail_url, update_payload, user + ): + original_email = user.email + + response = api_client_with_user_view_permission.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + user.refresh_from_db() + assert user.email == original_email + + def test_user_cannot_update_service_account(self, api_client_with_user_change_permission): + service_account = baker.make(User, is_service_account=True) + url = reverse("user_details", kwargs={"id": service_account.id}) + payload = {"email": "updated@example.com"} + + response = api_client_with_user_change_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_admin_with_global_permission_can_update_user(self, admin_api_client, detail_url, update_payload, user): + response = admin_api_client.patch(detail_url, update_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.email == "updated@example.com" + + +class TestUserPasswordViewPatch: + def test_user_with_change_permission_can_update_password( + self, api_client_with_user_change_permission, password_url, password_payload, user + ): + old_password_hash = user.password + + response = api_client_with_user_change_permission.patch(password_url, password_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.password != old_password_hash + assert user.check_password("newpassword123") + + def test_user_without_permission_cannot_update_password(self, api_client, password_url, password_payload, user): + old_password_hash = user.password + + response = api_client.patch(password_url, password_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + user.refresh_from_db() + assert user.password == old_password_hash + + def test_user_with_only_view_permission_cannot_update_password( + self, api_client_with_user_view_permission, password_url, password_payload, user + ): + old_password_hash = user.password + + response = api_client_with_user_view_permission.patch(password_url, password_payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + user.refresh_from_db() + assert user.password == old_password_hash + + def test_user_cannot_update_service_account_password(self, api_client_with_user_change_permission): + service_account = baker.make(User, is_service_account=True) + url = reverse("user_password", kwargs={"id": service_account.id}) + payload = {"password": "newpassword123"} + + response = api_client_with_user_change_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_admin_with_global_permission_can_update_password( + self, admin_api_client, password_url, password_payload, user + ): + old_password_hash = user.password + + response = admin_api_client.patch(password_url, password_payload, format="json") + + assert response.status_code == status.HTTP_200_OK + user.refresh_from_db() + assert user.password != old_password_hash + assert user.check_password("newpassword123") diff --git a/server/account/views/service_accounts.py b/server/account/views/service_accounts.py index c3914801..2658dd0a 100644 --- a/server/account/views/service_accounts.py +++ b/server/account/views/service_accounts.py @@ -1,9 +1,11 @@ +from django.shortcuts import get_object_or_404 +from django.utils.decorators import method_decorator from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework import status from rest_framework.authtoken.models import Token from rest_framework.generics import ListAPIView -from rest_framework.permissions import IsAdminUser +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView @@ -15,11 +17,15 @@ TokenResponseSerializer, ) from account.services.service_accounts import ServiceAccountNameService +from catalog.utils import ModelPermissions, permission_required_with_global_perms + +_user_perms = ModelPermissions(User) class CheckServiceNameView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] + @method_decorator(permission_required_with_global_perms(_user_perms.view)) @swagger_auto_schema( operation_description="Check if a service account name is available", request_body=openapi.Schema( @@ -37,8 +43,9 @@ def post(self, request): class ResetTokenView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] + @method_decorator(permission_required_with_global_perms(_user_perms.change, (User, "pk", "id"))) @swagger_auto_schema( operation_description="Revoke and regenerate a token for a service account", manual_parameters=[ @@ -47,7 +54,7 @@ class ResetTokenView(APIView): responses={200: TokenResponseSerializer}, ) def post(self, request, id): - service_account = User.objects.get(id=id, is_service_account=True) + service_account = get_object_or_404(User, id=id, is_service_account=True) Token.objects.filter(user=service_account).delete() token, created = Token.objects.get_or_create(user=service_account) serializer = TokenResponseSerializer({"token": token.key}) @@ -55,16 +62,18 @@ def post(self, request, id): class ServiceAccountListView(ListAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = ServiceAccountSerializer queryset = User.objects.filter(is_service_account=True).order_by("-date_joined") + @method_decorator(permission_required_with_global_perms(_user_perms.view)) @swagger_auto_schema( operation_description="Get list of service accounts", ) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) + @method_decorator(permission_required_with_global_perms(_user_perms.add)) @swagger_auto_schema( operation_description="Create a new service account", request_body=CreateUpdateServiceAccountSerializer ) @@ -80,8 +89,9 @@ def post(self, request): class ServiceAccountView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] + @method_decorator(permission_required_with_global_perms(_user_perms.change, (User, "pk", "id"))) @swagger_auto_schema( operation_description="Update a service account", request_body=CreateUpdateServiceAccountSerializer, @@ -90,7 +100,7 @@ class ServiceAccountView(APIView): ], ) def patch(self, request, id): - service_account = User.objects.get(id=id, is_service_account=True) + service_account = get_object_or_404(User, id=id, is_service_account=True) serializer = CreateUpdateServiceAccountSerializer(service_account, data=request.data, partial=True) serializer.is_valid(raise_exception=True) serializer.save() diff --git a/server/account/views/users.py b/server/account/views/users.py index 6805ec52..0ff14067 100644 --- a/server/account/views/users.py +++ b/server/account/views/users.py @@ -1,23 +1,36 @@ +from django.utils.decorators import method_decorator from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema -from rest_framework.generics import ListCreateAPIView, UpdateAPIView -from rest_framework.permissions import IsAdminUser +from rest_framework.generics import ListCreateAPIView, UpdateAPIView, get_object_or_404 +from rest_framework.permissions import IsAuthenticated from account.models import User from account.serializers import UserSerializer, UserUpdatePasswordSerializer, UserUpdateSerializer +from account.services import UserService +from catalog.utils import ModelPermissions, permission_required_with_global_perms + +_user_perms = ModelPermissions(User) class UserListView(ListCreateAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = UserSerializer + @method_decorator(permission_required_with_global_perms(_user_perms.view)) @swagger_auto_schema(operation_description="List all users", manual_parameters=[]) + def get(self, request, *args, **kwargs): + return super().get(request, *args, **kwargs) + def get_queryset(self): return User.objects.filter(is_service_account=False) + @method_decorator(permission_required_with_global_perms(_user_perms.add)) @swagger_auto_schema(operation_description="Create a new user", request_body=UserSerializer) + def post(self, request, *args, **kwargs): + return super().post(request, *args, **kwargs) + def perform_create(self, serializer): - User.objects.create_user( + UserService.create_user( email=serializer.validated_data["email"], password=serializer.validated_data["password"], is_active=serializer.validated_data["is_active"], @@ -26,7 +39,7 @@ def perform_create(self, serializer): class UserView(UpdateAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = UserUpdateSerializer @swagger_auto_schema( @@ -37,7 +50,11 @@ class UserView(UpdateAPIView): ], ) def get_object(self): - return User.objects.get(id=self.kwargs["id"], is_service_account=False) + return get_object_or_404(User, id=self.kwargs["id"], is_service_account=False) + + @method_decorator(permission_required_with_global_perms(_user_perms.change, (User, "pk", "id"))) + def patch(self, request, *args, **kwargs): + return super().patch(request, *args, **kwargs) def perform_update(self, serializer): user = self.get_object() @@ -48,7 +65,7 @@ def perform_update(self, serializer): class UserPasswordView(UpdateAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = UserUpdatePasswordSerializer @swagger_auto_schema( @@ -59,7 +76,11 @@ class UserPasswordView(UpdateAPIView): ], ) def get_object(self): - return User.objects.get(id=self.kwargs["id"], is_service_account=False) + return get_object_or_404(User, id=self.kwargs["id"], is_service_account=False) + + @method_decorator(permission_required_with_global_perms(_user_perms.change, (User, "pk", "id"))) + def patch(self, request, *args, **kwargs): + return super().patch(request, *args, **kwargs) def perform_update(self, serializer): user = self.get_object() diff --git a/server/agent/conversation/tests/test_manager.py b/server/agent/conversation/tests/test_manager.py index bb43c8a9..cfcdf906 100644 --- a/server/agent/conversation/tests/test_manager.py +++ b/server/agent/conversation/tests/test_manager.py @@ -6,6 +6,7 @@ from django.core.exceptions import ObjectDoesNotExist from model_bakery import baker +from account.services import UserService from agent.conversation.manager import ConversationManager from agent.models import Agent, Conversation, Message from catalog.models import DataSet @@ -132,7 +133,7 @@ def test_get_conversation_with_invalid_conversation_id(self): def test_get_conversation_with_wrong_user_ownership(self): """Test conversation retrieval with conversation owned by different user.""" # Given - other_user = get_user_model().objects.create_user(email="other@example.com", password="testpass123") + other_user = UserService.create_user(email="other@example.com", password="testpass123") other_data_set = DataSet.objects.create(name="Other DataSet") other_user.data_sets.add(other_data_set) diff --git a/server/agent/tests/test_views.py b/server/agent/tests/test_views.py index 512a849f..398c84a8 100644 --- a/server/agent/tests/test_views.py +++ b/server/agent/tests/test_views.py @@ -12,9 +12,11 @@ from rest_framework.test import APIClient from account.models import User +from account.services import UserService from agent.models import Conversation from agent.models.agent import Agent from catalog.models import DataSet +from catalog.utils import AdminRole pytestmark = pytest.mark.django_db @@ -181,8 +183,7 @@ def test_get_returns_ordered_by_created_at(self, api_client, url, dataset_instan assert response.data[1]["id"] == newer.id def test_get_returns_corrupted_agents_to_admin(self, user, api_client, url, dataset_instance): - user.is_staff = True - user.save() + UserService.assign_role(user, AdminRole, dataset_instance) agent_1 = baker.make(Agent, dataset=dataset_instance) agent_2 = baker.make(Agent, dataset=dataset_instance, corrupted=True) diff --git a/server/agent/views.py b/server/agent/views.py index 245a256b..b33d8210 100644 --- a/server/agent/views.py +++ b/server/agent/views.py @@ -41,6 +41,9 @@ ) from agent.tasks import process_file_upload_task, respond_to_user_message_task from catalog.models import DataSet +from catalog.utils import ModelPermissions + +_ds_perms = ModelPermissions(DataSet) class GetTaskStatus(APIView): @@ -292,10 +295,17 @@ class AgentView(APIView): responses={200: AgentListSerializer(many=True)}, ) def get(self, request): - if request.user.is_staff: - queryset = Agent.objects.all().order_by("created_at") - else: + hide_corrupted_agents = not request.user.has_perm(_ds_perms.change) + if hide_corrupted_agents and request.GET.get("dataset"): + try: + dataset = DataSet.objects.get(pk=request.GET["dataset"]) + hide_corrupted_agents = not request.user.has_perm(_ds_perms.change, dataset) + except DataSet.DoesNotExist: + pass + if hide_corrupted_agents: queryset = Agent.objects.filter(corrupted=False).order_by("created_at") + else: + queryset = Agent.objects.all().order_by("created_at") filterset = AgentFilter(request.GET, queryset=queryset) if not filterset.is_valid(): return Response(filterset.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/server/catalog/models/data_set.py b/server/catalog/models/data_set.py index 1e777f84..d33e1c72 100644 --- a/server/catalog/models/data_set.py +++ b/server/catalog/models/data_set.py @@ -19,3 +19,6 @@ class Meta: "List of various data sets. One data set may be the whole company's content such as blog " "posts, or some part of it: a data set may be represent a brand or department." ) + permissions = [ + ("manage_dataset_users", "Can manage users of a dataset"), + ] diff --git a/server/catalog/tests/test_views/test_data_set_detail_view.py b/server/catalog/tests/test_views/test_data_set_detail_view.py new file mode 100644 index 00000000..1d37afac --- /dev/null +++ b/server/catalog/tests/test_views/test_data_set_detail_view.py @@ -0,0 +1,94 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from catalog.models import DataSet + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(data_set): + return reverse("data_set_detail", kwargs={"data_set_id": data_set.id}) + + +class TestDataSetDetailViewGet: + def test_user_with_view_permission_can_access_dataset(self, api_client_with_view_permission, url, data_set): + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert response.data["id"] == data_set.id + + def test_user_without_permission_cannot_access_dataset(self, api_client, url, data_set): + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_user_with_global_view_permission_can_access_any_dataset(self, admin_api_client, data_set): + other_data_set = baker.make(DataSet) + url = reverse("data_set_detail", kwargs={"data_set_id": other_data_set.id}) + + response = admin_api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert response.data["id"] == other_data_set.id + + def test_user_with_object_level_permission_cannot_access_other_datasets( + self, api_client_with_view_permission, data_set + ): + other_data_set = baker.make(DataSet) + url = reverse("data_set_detail", kwargs={"data_set_id": other_data_set.id}) + + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetDetailViewPatch: + @pytest.fixture + def payload(self): + return {"name": "Updated DataSet Name"} + + def test_user_with_change_permission_can_update_dataset( + self, api_client_with_change_permission, url, payload, data_set + ): + response = api_client_with_change_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_200_OK + assert response.data["name"] == "Updated DataSet Name" + data_set.refresh_from_db() + assert data_set.name == "Updated DataSet Name" + + def test_user_with_view_only_permission_cannot_update_dataset( + self, api_client_with_view_permission, url, payload, data_set + ): + response = api_client_with_view_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + data_set.refresh_from_db() + assert data_set.name != "Updated DataSet Name" + + def test_user_without_permission_cannot_update_dataset(self, api_client, url, payload, data_set): + response = api_client.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_user_with_object_level_change_permission_cannot_update_other_datasets( + self, api_client_with_change_permission, payload, data_set + ): + other_data_set = baker.make(DataSet) + url = reverse("data_set_detail", kwargs={"data_set_id": other_data_set.id}) + + response = api_client_with_change_permission.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_admin_with_global_permission_can_update_any_dataset(self, admin_api_client, payload, data_set): + other_data_set = baker.make(DataSet) + url = reverse("data_set_detail", kwargs={"data_set_id": other_data_set.id}) + + response = admin_api_client.patch(url, payload, format="json") + + assert response.status_code == status.HTTP_200_OK + assert response.data["name"] == "Updated DataSet Name" diff --git a/server/catalog/tests/test_views/test_data_set_document_source_view_permissions.py b/server/catalog/tests/test_views/test_data_set_document_source_view_permissions.py new file mode 100644 index 00000000..cefb60c8 --- /dev/null +++ b/server/catalog/tests/test_views/test_data_set_document_source_view_permissions.py @@ -0,0 +1,136 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from catalog.models import DataSet, DocumentSource + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(data_set): + return reverse("data_set_document_source_list", kwargs={"data_set_id": data_set.id}) + + +@pytest.fixture +def detail_url(data_set, document_source): + return reverse( + "data_set_document_source_details", + kwargs={"data_set_id": data_set.id, "document_source_id": document_source.id}, + ) + + +@pytest.fixture +def payload(): + return {"plugin_name": "Test Source", "config": {}} + + +class TestDataSetDocumentSourceListViewGet: + def test_user_with_view_permission_can_list_sources(self, api_client_with_view_permission, url, data_set): + baker.make(DocumentSource, data_set=data_set, _quantity=2) + + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 2 + + def test_user_without_permission_cannot_list_sources(self, api_client, url, data_set): + baker.make(DocumentSource, data_set=data_set, _quantity=2) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_user_with_view_permission_only_sees_sources_from_their_dataset( + self, api_client_with_view_permission, data_set + ): + other_dataset = baker.make(DataSet) + baker.make(DocumentSource, data_set=data_set, _quantity=2) + baker.make(DocumentSource, data_set=other_dataset, _quantity=3) + url = reverse("data_set_document_source_list", kwargs={"data_set_id": data_set.id}) + + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 2 + + +class TestDataSetDocumentSourceListViewPost: + def test_user_with_change_permission_can_create_source( + self, api_client_with_change_permission, url, payload, data_set + ): + response = api_client_with_change_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert DocumentSource.objects.filter(data_set=data_set, plugin_name="Test Source").exists() + + def test_user_without_change_permission_cannot_create_source(self, api_client, url, payload, data_set): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not DocumentSource.objects.filter(data_set=data_set, plugin_name="Test Source").exists() + + def test_user_with_view_only_permission_cannot_create_source( + self, api_client_with_view_permission, url, payload, data_set + ): + response = api_client_with_view_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetDocumentSourceViewGet: + def test_user_with_view_permission_can_get_source( + self, api_client_with_view_permission, detail_url, document_source + ): + response = api_client_with_view_permission.get(detail_url) + + assert response.status_code == status.HTTP_200_OK + assert response.data["id"] == document_source.id + + def test_user_without_permission_cannot_get_source(self, api_client, detail_url, document_source): + response = api_client.get(detail_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetDocumentSourceViewPatch: + @pytest.fixture + def payload(self): + return {"plugin_name": "Updated Source"} + + def test_user_with_change_permission_can_update_source( + self, api_client_with_change_permission, detail_url, payload, document_source + ): + response = api_client_with_change_permission.patch(detail_url, payload, format="json") + + assert response.status_code == status.HTTP_200_OK + document_source.refresh_from_db() + assert document_source.plugin_name == "Updated Source" + + def test_user_without_change_permission_cannot_update_source( + self, api_client, detail_url, payload, document_source + ): + response = api_client.patch(detail_url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetDocumentSourceViewDelete: + def test_user_with_delete_permission_can_delete_source( + self, api_client_with_change_permission, detail_url, document_source + ): + source_id = document_source.id + + response = api_client_with_change_permission.delete(detail_url) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not DocumentSource.objects.filter(id=source_id).exists() + + def test_user_without_delete_permission_cannot_delete_source(self, api_client, detail_url, document_source): + source_id = document_source.id + + response = api_client.delete(detail_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert DocumentSource.objects.filter(id=source_id).exists() diff --git a/server/catalog/tests/test_views/test_data_set_list_view_permissions.py b/server/catalog/tests/test_views/test_data_set_list_view_permissions.py new file mode 100644 index 00000000..db56b11c --- /dev/null +++ b/server/catalog/tests/test_views/test_data_set_list_view_permissions.py @@ -0,0 +1,104 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from account.services import UserService +from catalog.models import DataSet +from catalog.utils import AdminRole, UserRole + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(): + return reverse("data_set_list") + + +@pytest.fixture +def payload(): + return {"name": "New DataSet"} + + +class TestDataSetListViewGet: + def test_user_with_view_permission_sees_only_accessible_datasets(self, api_client, user, url): + accessible_dataset = baker.make(DataSet) + inaccessible_dataset = baker.make(DataSet) + UserService.assign_role(user, UserRole, accessible_dataset) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + dataset_ids = [ds["id"] for ds in response.data["results"]] + assert accessible_dataset.id in dataset_ids + assert inaccessible_dataset.id not in dataset_ids + + def test_user_without_permissions_sees_no_datasets(self, api_client, url): + baker.make(DataSet, _quantity=3) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 0 + + def test_admin_with_global_permission_sees_all_datasets(self, admin_api_client, url): + baker.make(DataSet, _quantity=3) + + response = admin_api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 3 + + def test_user_with_multiple_dataset_permissions_sees_all(self, api_client, user, url): + dataset1 = baker.make(DataSet) + dataset2 = baker.make(DataSet) + dataset3 = baker.make(DataSet) + UserService.assign_role(user, UserRole, dataset1) + UserService.assign_role(user, UserRole, dataset2) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + dataset_ids = [ds["id"] for ds in response.data["results"]] + assert dataset1.id in dataset_ids + assert dataset2.id in dataset_ids + assert dataset3.id not in dataset_ids + + +class TestDataSetListViewPost: + def test_user_with_add_permission_can_create_dataset(self, api_client, user, url, payload): + UserService.assign_role(user, AdminRole) + + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert DataSet.objects.filter(name="New DataSet").exists() + + def test_user_without_add_permission_cannot_create_dataset(self, api_client, url, payload): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not DataSet.objects.filter(name="New DataSet").exists() + + def test_user_with_only_view_permission_cannot_create_dataset(self, api_client, user, url, payload): + UserService.assign_role(user, UserRole) + + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not DataSet.objects.filter(name="New DataSet").exists() + + def test_admin_with_global_add_permission_can_create_dataset(self, admin_api_client, url, payload): + response = admin_api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert DataSet.objects.filter(name="New DataSet").exists() + + def test_created_dataset_includes_creator_in_users(self, api_client, user, url, payload): + UserService.assign_role(user, AdminRole) + + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + dataset = DataSet.objects.get(name="New DataSet") + assert user in dataset.users.all() diff --git a/server/catalog/tests/test_views/test_data_set_product_source_view_permissions.py b/server/catalog/tests/test_views/test_data_set_product_source_view_permissions.py new file mode 100644 index 00000000..541033fe --- /dev/null +++ b/server/catalog/tests/test_views/test_data_set_product_source_view_permissions.py @@ -0,0 +1,133 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from catalog.models import DataSet, ProductSource + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(data_set): + return reverse("data_set_product_source_list", kwargs={"data_set_id": data_set.id}) + + +@pytest.fixture +def detail_url(data_set, product_source): + return reverse( + "data_set_product_source_details", kwargs={"data_set_id": data_set.id, "product_source_id": product_source.id} + ) + + +@pytest.fixture +def payload(): + return {"plugin_name": "Test Source", "config": {}} + + +class TestDataSetProductSourceListViewGet: + def test_user_with_view_permission_can_list_sources(self, api_client_with_view_permission, url, data_set): + baker.make(ProductSource, data_set=data_set, _quantity=2) + + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 2 + + def test_user_without_permission_cannot_list_sources(self, api_client, url, data_set): + baker.make(ProductSource, data_set=data_set, _quantity=2) + + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_user_with_view_permission_only_sees_sources_from_their_dataset( + self, api_client_with_view_permission, data_set + ): + other_dataset = baker.make(DataSet) + baker.make(ProductSource, data_set=data_set, _quantity=2) + baker.make(ProductSource, data_set=other_dataset, _quantity=3) + url = reverse("data_set_product_source_list", kwargs={"data_set_id": data_set.id}) + + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 2 + + +class TestDataSetProductSourceListViewPost: + def test_user_with_change_permission_can_create_source( + self, api_client_with_change_permission, url, payload, data_set + ): + response = api_client_with_change_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert ProductSource.objects.filter(data_set=data_set, plugin_name="Test Source").exists() + + def test_user_without_change_permission_cannot_create_source(self, api_client, url, payload, data_set): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not ProductSource.objects.filter(data_set=data_set, plugin_name="Test Source").exists() + + def test_user_with_view_only_permission_cannot_create_source( + self, api_client_with_view_permission, url, payload, data_set + ): + response = api_client_with_view_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetProductSourceViewGet: + def test_user_with_view_permission_can_get_source( + self, api_client_with_view_permission, detail_url, product_source + ): + response = api_client_with_view_permission.get(detail_url) + + assert response.status_code == status.HTTP_200_OK + assert response.data["id"] == product_source.id + + def test_user_without_permission_cannot_get_source(self, api_client, detail_url, product_source): + response = api_client.get(detail_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetProductSourceViewPatch: + @pytest.fixture + def payload(self): + return {"plugin_name": "Updated Source"} + + def test_user_with_change_permission_can_update_source( + self, api_client_with_change_permission, detail_url, payload, product_source + ): + response = api_client_with_change_permission.patch(detail_url, payload, format="json") + + assert response.status_code == status.HTTP_200_OK + product_source.refresh_from_db() + assert product_source.plugin_name == "Updated Source" + + def test_user_without_change_permission_cannot_update_source(self, api_client, detail_url, payload, product_source): + response = api_client.patch(detail_url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + + +class TestDataSetProductSourceViewDelete: + def test_user_with_delete_permission_can_delete_source( + self, api_client_with_change_permission, detail_url, product_source + ): + source_id = product_source.id + + response = api_client_with_change_permission.delete(detail_url) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not ProductSource.objects.filter(id=source_id).exists() + + def test_user_without_delete_permission_cannot_delete_source(self, api_client, detail_url, product_source): + source_id = product_source.id + + response = api_client.delete(detail_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert ProductSource.objects.filter(id=source_id).exists() diff --git a/server/catalog/tests/test_views/test_data_set_user_view.py b/server/catalog/tests/test_views/test_data_set_user_view.py new file mode 100644 index 00000000..f449bf40 --- /dev/null +++ b/server/catalog/tests/test_views/test_data_set_user_view.py @@ -0,0 +1,143 @@ +import pytest +from django.urls import reverse +from model_bakery import baker +from rest_framework import status + +from account.models import User +from account.services import UserService +from catalog.utils import AdminRole + +pytestmark = pytest.mark.django_db + + +@pytest.fixture +def url(data_set): + return reverse("data_set_user_list", kwargs={"data_set_id": data_set.id}) + + +@pytest.fixture +def delete_url(data_set, user): + return reverse("data_set_user_details", kwargs={"data_set_id": data_set.id, "user_id": user.id}) + + +class TestDataSetUserListViewGet: + def test_user_with_manage_permission_can_list_users(self, api_client_with_manage_permission, url, data_set): + user1 = baker.make(User) + user2 = baker.make(User) + data_set.users.add(user1, user2) + + response = api_client_with_manage_permission.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) == 2 + + def test_user_without_manage_permission_cannot_list_users(self, api_client, url, data_set): + response = api_client.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_user_with_view_permission_cannot_list_users(self, api_client_with_view_permission, url, data_set): + response = api_client_with_view_permission.get(url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_admin_with_global_permission_can_list_users(self, admin_api_client, admin_user, url, data_set): + UserService.assign_role(admin_user, AdminRole, data_set) + user1 = baker.make(User) + data_set.users.add(user1) + + response = admin_api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert len(response.data["results"]) >= 1 + + +class TestDataSetUserListViewPost: + @pytest.fixture + def new_user(self): + return baker.make(User) + + @pytest.fixture + def payload(self, new_user): + return {"user_id": new_user.id} + + def test_user_with_manage_permission_can_add_user( + self, api_client_with_manage_permission, url, payload, data_set, new_user + ): + response = api_client_with_manage_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert data_set.users.filter(id=new_user.id).exists() + + def test_user_without_manage_permission_cannot_add_user(self, api_client, url, payload, data_set, new_user): + response = api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert not data_set.users.filter(id=new_user.id).exists() + + def test_staff_user_gets_admin_role_on_dataset(self, api_client_with_manage_permission, url, data_set): + staff_user = baker.make(User, is_staff=True) + payload = {"user_id": staff_user.id} + + response = api_client_with_manage_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + from guardian.shortcuts import get_user_perms + + perms = [str(p) for p in get_user_perms(staff_user, data_set)] + assert "view_dataset" in perms + assert "change_dataset" in perms + + def test_regular_user_gets_user_role_on_dataset(self, api_client_with_manage_permission, url, data_set): + regular_user = baker.make(User, is_staff=False) + payload = {"user_id": regular_user.id} + + response = api_client_with_manage_permission.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + from guardian.shortcuts import get_user_perms + + perms = [str(p) for p in get_user_perms(regular_user, data_set)] + assert "view_dataset" in perms + assert "change_dataset" not in perms + + def test_admin_with_global_permission_can_add_user( + self, admin_api_client, admin_user, url, payload, data_set, new_user + ): + UserService.assign_role(admin_user, AdminRole, data_set) + + response = admin_api_client.post(url, payload, format="json") + + assert response.status_code == status.HTTP_201_CREATED + assert data_set.users.filter(id=new_user.id).exists() + + +class TestDataSetUserViewDelete: + def test_user_with_manage_permission_can_remove_user( + self, api_client_with_manage_permission, delete_url, data_set, user + ): + data_set.users.add(user) + + response = api_client_with_manage_permission.delete(delete_url) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not data_set.users.filter(id=user.id).exists() + + def test_user_without_manage_permission_cannot_remove_user(self, api_client, delete_url, data_set, user): + data_set.users.add(user) + + response = api_client.delete(delete_url) + + assert response.status_code == status.HTTP_403_FORBIDDEN + assert data_set.users.filter(id=user.id).exists() + + def test_admin_with_global_permission_can_remove_user( + self, admin_api_client, admin_user, delete_url, data_set, user + ): + UserService.assign_role(admin_user, AdminRole, data_set) + data_set.users.add(user) + + response = admin_api_client.delete(delete_url) + + assert response.status_code == status.HTTP_204_NO_CONTENT + assert not data_set.users.filter(id=user.id).exists() diff --git a/server/catalog/utils.py b/server/catalog/utils.py index a83208b1..82753f3b 100644 --- a/server/catalog/utils.py +++ b/server/catalog/utils.py @@ -1,6 +1,10 @@ -from typing import Type +from typing import Literal, Optional, Type +from django.contrib.auth import get_user_model +from django.db import models from drf_yasg import openapi +from guardian.decorators import permission_required_or_403 +from guardian.shortcuts import assign_perm from pydantic import BaseModel from pydantic import ValidationError as PydanticValidationError from rest_framework import serializers @@ -9,6 +13,160 @@ from sync.base import SourcePluginRegistry +User = get_user_model() + + +def get_model_permission(model: Type[models.Model], action: Literal["view", "change", "delete", "add"]) -> str: + app_label = model._meta.app_label + model_name = model._meta.model_name + return f"{app_label}.{action}_{model_name}" + + +class ModelPermissions: + """ + Helper class to access all permissions (built-in and custom) for a model. + + Usage: + perms = ModelPermissions(DataSet) + # Built-in permissions: + # perms.view -> "catalog.view_dataset" + # perms.add -> "catalog.add_dataset" + # perms.change -> "catalog.change_dataset" + # perms.delete -> "catalog.delete_dataset" + # Custom permissions (from Meta.permissions): + # perms.manage_dataset_users -> "catalog.manage_dataset_users" + """ + + def __init__(self, model: Type[models.Model]): + self.model = model + app_label = model._meta.app_label + + # Built-in permissions + self.view = get_model_permission(model, "view") + self.add = get_model_permission(model, "add") + self.change = get_model_permission(model, "change") + self.delete = get_model_permission(model, "delete") + + # Custom permissions from Meta.permissions + custom_permissions = getattr(model._meta, "permissions", []) + for codename, _ in custom_permissions: + setattr(self, codename, f"{app_label}.{codename}") + + def get_custom_permission(self, codename: str) -> str: + """Get a custom permission string for the model (for dynamic access).""" + app_label = self.model._meta.app_label + return f"{app_label}.{codename}" + + +def is_builtin_permission_action(action: str) -> bool: + return action in ["view", "change", "delete", "add"] + + +def build_permission_string(model: Type[models.Model], action: str) -> str: + app_label = model._meta.app_label + + if is_builtin_permission_action(action): + model_name = model._meta.model_name + return f"{app_label}.{action}_{model_name}" + else: + return f"{app_label}.{action}" + + +def permission_required_with_global_perms(perm, *args, **kwargs): + kwargs.setdefault("accept_global_perms", True) + return permission_required_or_403(perm, *args, **kwargs) + + +class BaseRole: + PERMISSIONS: dict[str, list[str]] = {} + + @classmethod + def _get_model_by_string(cls, model_string: str) -> Type[models.Model]: + from django.apps import apps + + app_label, model_name = model_string.split(".") + return apps.get_model(app_label, model_name) + + @classmethod + def get_permissions_for_model(cls, model: Type[models.Model]) -> list[str]: + model_string = f"{model._meta.app_label}.{model.__name__}" + + role_permissions = cls.PERMISSIONS.get(model_string, []) + + if not role_permissions: + return [] + + permissions = [] + for action in role_permissions: + permissions.append(build_permission_string(model, action)) + + return permissions + + @classmethod + def assign(cls, user: User, obj: Optional[models.Model] = None) -> None: + from django.contrib.auth.models import Permission + from django.contrib.contenttypes.models import ContentType + + if obj is None: + for model_string, _ in cls.PERMISSIONS.items(): + model = cls._get_model_by_string(model_string) + permissions = cls.get_permissions_for_model(model) + + for perm in permissions: + try: + assign_perm(perm, user) + except Permission.DoesNotExist: + codename = perm.split(".")[-1] + content_type = ContentType.objects.get_for_model(model) + Permission.objects.get_or_create( + codename=codename, + content_type=content_type, + defaults={"name": f"Can {codename.replace('_', ' ')}"}, + ) + assign_perm(perm, user) + else: + model = type(obj) + permissions = cls.get_permissions_for_model(model) + + for perm in permissions: + try: + assign_perm(perm, user, obj) + except Permission.DoesNotExist: + codename = perm.split(".")[-1] + content_type = ContentType.objects.get_for_model(model) + Permission.objects.get_or_create( + codename=codename, + content_type=content_type, + defaults={"name": f"Can {codename.replace('_', ' ')}"}, + ) + assign_perm(perm, user, obj) + + +class AdminRole(BaseRole): + PERMISSIONS = { + "catalog.DataSet": [ + "view", + "change", + "delete", + "add", + "manage_dataset_users", + ], + "account.User": [ + "view", + "change", + "delete", + "add", + ], + } + + +class UserRole(BaseRole): + PERMISSIONS = { + "catalog.DataSet": [ + "view", + ], + } + class PydanticModelField(BasePydanticModelField): def __init__(self, *, config_field_name: str, plugin_registry_class: Type[SourcePluginRegistry], **kwargs): diff --git a/server/catalog/views.py b/server/catalog/views.py index 4589663d..6347c4c0 100644 --- a/server/catalog/views.py +++ b/server/catalog/views.py @@ -1,10 +1,12 @@ from django.conf import settings from django.db import transaction from django.db.models import Count +from django.utils.decorators import method_decorator from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema +from guardian.shortcuts import get_objects_for_user from rest_framework import status -from rest_framework.generics import GenericAPIView, ListAPIView, ListCreateAPIView, RetrieveAPIView +from rest_framework.generics import GenericAPIView, ListAPIView, ListCreateAPIView, RetrieveAPIView, get_object_or_404 from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import IsAdminUser, IsAuthenticated from rest_framework.response import Response @@ -12,6 +14,7 @@ from account.models import User from account.serializers import UserSerializer +from account.services.user import UserService from agent.core.registries.embeddings import EmbeddingProviderRegistry from agent.core.registries.language_models import LanguageModelRegistry from agent.services import AgentService @@ -39,6 +42,9 @@ ProductSourceSerializer, SyncResponseSerializer, ) +from .utils import AdminRole, ModelPermissions, UserRole, permission_required_with_global_perms + +_ds_perms = ModelPermissions(DataSet) class SyncAllSourcesView(APIView): @@ -67,19 +73,14 @@ def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) def get_queryset(self): - if self.request.user and self.request.user.is_staff: - return DataSet.objects.all() - - return DataSet.objects.filter(users=self.request.user) + return get_objects_for_user(self.request.user, _ds_perms.view, DataSet) @swagger_auto_schema(operation_description="Create a new data set", request_body=DataSetCreateSerializer) + @method_decorator(permission_required_with_global_perms(_ds_perms.add)) def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) def perform_create(self, serializer): - if not self.request.user.is_staff: - self.permission_denied(self.request) - with transaction.atomic(): preconfigure_agents = serializer.validated_data.pop("preconfigure_agents") data_set = serializer.save() @@ -102,6 +103,7 @@ class DataSetDetailView(RetrieveAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) @@ -114,6 +116,7 @@ def get(self, request, *args, **kwargs): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def patch(self, request, *args, **kwargs): instance = self.get_object() @@ -133,7 +136,7 @@ def patch(self, request, *args, **kwargs): class DataSetUserListView(ListCreateAPIView): serializer_class = UserSerializer - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="List users in a data set", @@ -143,6 +146,9 @@ class DataSetUserListView(ListCreateAPIView): ) ], ) + @method_decorator( + permission_required_with_global_perms(_ds_perms.manage_dataset_users, (DataSet, "pk", "data_set_id")) + ) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) @@ -156,17 +162,27 @@ def get_queryset(self): properties={"user_id": openapi.Schema(type=openapi.TYPE_INTEGER, description="ID of the user")}, ), ) + @method_decorator( + permission_required_with_global_perms(_ds_perms.manage_dataset_users, (DataSet, "pk", "data_set_id")) + ) def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) def create(self, *args, **kwargs): user = User.objects.get(id=self.request.data["user_id"]) - DataSet.objects.get(id=self.kwargs["data_set_id"]).users.add(user) + data_set = DataSet.objects.get(id=self.kwargs["data_set_id"]) + data_set.users.add(user) + + if user.is_staff: + UserService.assign_role(user, AdminRole, data_set) + else: + UserService.assign_role(user, UserRole, data_set) + return Response({}, status=status.HTTP_201_CREATED) class DataSetUserView(GenericAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Remove a user from a data set", @@ -177,18 +193,22 @@ class DataSetUserView(GenericAPIView): openapi.Parameter("user_id", openapi.IN_PATH, description="ID of the user", type=openapi.TYPE_INTEGER), ], ) + @method_decorator( + permission_required_with_global_perms(_ds_perms.manage_dataset_users, (DataSet, "pk", "data_set_id")) + ) def delete(self, *args, **kwargs): DataSet.objects.get(id=self.kwargs["data_set_id"]).users.remove(self.kwargs["user_id"]) return Response({}, status=status.HTTP_204_NO_CONTENT) class SyncDataSetAllSourcesView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Sync all sources in a data set", responses={200: SyncResponseSerializer}, ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def post(self, request, *args, **kwargs): task = sync_data_set_all_sources.apply_async(args=[kwargs["data_set_id"]]) serializer = SyncResponseSerializer({"task_id": task.id}) @@ -197,7 +217,7 @@ def post(self, request, *args, **kwargs): class DataSetProductSourceListView(ListCreateAPIView): serializer_class = ProductSourceSerializer - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="List product sources in a data set", @@ -207,6 +227,7 @@ class DataSetProductSourceListView(ListCreateAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) @@ -216,6 +237,7 @@ def get_queryset(self): @swagger_auto_schema( operation_description="Create a new product source in a data set", request_body=ProductSourceSerializer ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) @@ -237,7 +259,7 @@ def perform_create(self, serializer): class DataSetProductSourceView(GenericAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = ProductSourceSerializer @swagger_auto_schema( @@ -251,14 +273,16 @@ class DataSetProductSourceView(GenericAPIView): ), ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, data_set_id, product_source_id): - product_source = ProductSource.objects.get(id=product_source_id) + product_source = ProductSource.objects.get(id=product_source_id, data_set_id=data_set_id) serializer = self.serializer_class(product_source) return Response(serializer.data) @swagger_auto_schema(operation_description="Update a product source", request_body=ProductSourceSerializer) - def patch(self, request, *args, **kwargs): - product_source = ProductSource.objects.get(id=kwargs.get("product_source_id")) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def patch(self, request, data_set_id, product_source_id): + product_source = ProductSource.objects.get(id=product_source_id, data_set_id=data_set_id) serializer = self.serializer_class(product_source, data=request.data, partial=True) serializer.is_valid(raise_exception=True) serializer.save(corrupted=False) @@ -276,8 +300,9 @@ def patch(self, request, *args, **kwargs): ), ], ) - def delete(self, *args, **kwargs): - ProductSource.objects.filter(id=kwargs["product_source_id"]).delete() + @method_decorator(permission_required_with_global_perms(_ds_perms.delete, (DataSet, "pk", "data_set_id"))) + def delete(self, request, data_set_id, product_source_id): + ProductSource.objects.filter(id=product_source_id, data_set_id=data_set_id).delete() return Response({}, status=status.HTTP_204_NO_CONTENT) @@ -295,27 +320,29 @@ def post(self, request, *args, **kwargs): class SyncDataSetProductSourcesView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Sync all product sources in a data set", responses={200: SyncResponseSerializer}, ) - def post(self, request, *args, **kwargs): - task = sync_data_set_product_sources.apply_async(args=[kwargs["data_set_id"]]) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def post(self, request, data_set_id): + task = sync_data_set_product_sources.apply_async(args=[data_set_id]) serializer = SyncResponseSerializer({"task_id": task.id}) return Response(serializer.data, status=status.HTTP_202_ACCEPTED) class SyncDataSetProductSourceView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Sync a product source", responses={200: SyncResponseSerializer}, ) - def post(self, request, *args, **kwargs): - task = sync_product_source.apply_async(args=[kwargs["product_source_id"]]) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def post(self, request, data_set_id, product_source_id): + task = sync_product_source.apply_async(args=[product_source_id]) serializer = SyncResponseSerializer({"task_id": task.id}) return Response(serializer.data, status=status.HTTP_202_ACCEPTED) @@ -333,14 +360,12 @@ class ProductListView(ListAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) def get_queryset(self): - if self.request.user.is_staff: - data_set = DataSet.objects.get(id=self.kwargs["data_set_id"]) - else: - data_set = DataSet.objects.get(id=self.kwargs["data_set_id"], users=self.request.user) + data_set = get_object_or_404(DataSet, pk=self.kwargs["data_set_id"]) return data_set.products.all() @@ -357,20 +382,18 @@ class DocumentListView(ListAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) def get_queryset(self): - if self.request.user.is_staff: - data_set = DataSet.objects.get(id=self.kwargs["data_set_id"]) - else: - data_set = DataSet.objects.get(id=self.kwargs["data_set_id"], users=self.request.user) + data_set = get_object_or_404(DataSet, pk=self.kwargs["data_set_id"]) return data_set.documents.annotate(chunks_count=Count("chunks")).all() class DataSetDocumentSourceListView(ListCreateAPIView): serializer_class = DocumentSourceSerializer - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="List document sources in a data set", @@ -380,6 +403,7 @@ class DataSetDocumentSourceListView(ListCreateAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): return super().get(request, *args, **kwargs) @@ -389,6 +413,7 @@ def get_queryset(self): @swagger_auto_schema( operation_description="Create a new document source in a data set", request_body=DocumentSourceSerializer ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) @@ -410,7 +435,7 @@ def create(self, request, *args, **kwargs): class DataSetDocumentSourceView(GenericAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = DocumentSourceSerializer @swagger_auto_schema( @@ -427,14 +452,16 @@ class DataSetDocumentSourceView(GenericAPIView): ), ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, data_set_id, document_source_id): - document_source = DocumentSource.objects.get(id=document_source_id) + document_source = DocumentSource.objects.get(id=document_source_id, data_set_id=data_set_id) serializer = self.serializer_class(document_source) return Response(serializer.data) @swagger_auto_schema(operation_description="Update a document source", request_body=DocumentSourceSerializer) - def patch(self, request, *args, **kwargs): - document_source = DocumentSource.objects.get(id=kwargs.get("document_source_id")) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def patch(self, request, data_set_id, document_source_id): + document_source = DocumentSource.objects.get(id=document_source_id, data_set_id=data_set_id) serializer = self.serializer_class(document_source, data=request.data, partial=True) serializer.is_valid(raise_exception=True) serializer.save(corrupted=False) @@ -454,8 +481,9 @@ def patch(self, request, *args, **kwargs): ), ], ) - def delete(self, *args, **kwargs): - DocumentSource.objects.filter(id=kwargs["document_source_id"]).delete() + @method_decorator(permission_required_with_global_perms(_ds_perms.delete, (DataSet, "pk", "data_set_id"))) + def delete(self, request, data_set_id, document_source_id): + DocumentSource.objects.filter(id=document_source_id, data_set_id=data_set_id).delete() return Response({}, status=status.HTTP_204_NO_CONTENT) @@ -479,27 +507,29 @@ class SyncDataSetDocumentSourcesView(APIView): operation_description="Sync all document sources in a data set", responses={200: SyncResponseSerializer}, ) - def post(self, request, *args, **kwargs): - task = sync_data_set_document_sources.apply_async(args=[kwargs["data_set_id"]]) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def post(self, request, data_set_id): + task = sync_data_set_document_sources.apply_async(args=[data_set_id]) serializer = SyncResponseSerializer({"task_id": task.id}) return Response(serializer.data, status=status.HTTP_202_ACCEPTED) class SyncDataSetDocumentSourceView(APIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Sync a document source", responses={200: SyncResponseSerializer}, ) - def post(self, request, *args, **kwargs): - task = sync_document_source.apply_async(args=[kwargs["document_source_id"]]) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) + def post(self, request, data_set_id, document_source_id): + task = sync_document_source.apply_async(args=[document_source_id]) serializer = SyncResponseSerializer({"task_id": task.id}) return Response(serializer.data, status=status.HTTP_202_ACCEPTED) class DataSetECommerceIntegrationView(GenericAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] serializer_class = ECommerceIntegrationSerializer @swagger_auto_schema( @@ -510,6 +540,7 @@ class DataSetECommerceIntegrationView(GenericAPIView): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.view, (DataSet, "pk", "data_set_id"))) def get(self, request, *args, **kwargs): data_set_id = kwargs["data_set_id"] try: @@ -528,6 +559,7 @@ def get(self, request, *args, **kwargs): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def post(self, request, *args, **kwargs): data_set_id = kwargs["data_set_id"] @@ -553,6 +585,7 @@ def post(self, request, *args, **kwargs): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def patch(self, request, *args, **kwargs): data_set_id = kwargs["data_set_id"] try: @@ -572,6 +605,7 @@ def patch(self, request, *args, **kwargs): ) ], ) + @method_decorator(permission_required_with_global_perms(_ds_perms.delete, (DataSet, "pk", "data_set_id"))) def delete(self, *args, **kwargs): data_set_id = kwargs["data_set_id"] ECommerceIntegration.objects.filter(data_set_id=data_set_id).delete() @@ -579,12 +613,13 @@ def delete(self, *args, **kwargs): class DataSetECommerceIntegrationSyncView(GenericAPIView): - permission_classes = [IsAdminUser] + permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Sync a document source", responses={200: SyncResponseSerializer}, ) + @method_decorator(permission_required_with_global_perms(_ds_perms.change, (DataSet, "pk", "data_set_id"))) def post(self, request, *args, **kwargs): data_set_id = kwargs["data_set_id"] try: @@ -597,6 +632,7 @@ def post(self, request, *args, **kwargs): return Response(serializer.data, status=status.HTTP_202_ACCEPTED) +# TODO Permissions class ConfigView(GenericAPIView): permission_classes = [IsAdminUser] diff --git a/server/conftest.py b/server/conftest.py index 2d984f2a..84835d78 100644 --- a/server/conftest.py +++ b/server/conftest.py @@ -6,8 +6,10 @@ from rest_framework.test import APIClient from account.models import User +from account.services import UserService from agent.models import Conversation -from catalog.models import DataSet +from catalog.models import DataSet, DocumentSource, ProductSource +from catalog.utils import AdminRole, UserRole @pytest.fixture @@ -17,7 +19,10 @@ def user(): @pytest.fixture def admin_user(): - return baker.make(User, is_staff=True) + """Create admin user with AdminRole permissions.""" + user = baker.make(User, is_staff=True) + UserService.assign_role(user, AdminRole) + return user @pytest.fixture @@ -34,11 +39,108 @@ def admin_api_client(admin_user): return client +@pytest.fixture +def user_with_view_permission(data_set): + user = baker.make(User, is_staff=True) + UserService.assign_role(user, UserRole, data_set) + return user + + +@pytest.fixture +def api_client_with_view_permission(user_with_view_permission): + client = APIClient() + client.force_authenticate(user=user_with_view_permission) + return client + + +@pytest.fixture +def user_with_change_permission(data_set): + user = baker.make(User, is_staff=True) + UserService.assign_role(user, AdminRole, data_set) + return user + + +@pytest.fixture +def api_client_with_change_permission(user_with_change_permission): + client = APIClient() + client.force_authenticate(user=user_with_change_permission) + return client + + +@pytest.fixture +def user_with_manage_permission(data_set): + user = baker.make(User, is_staff=True) + UserService.assign_role(user, AdminRole, data_set) + return user + + +@pytest.fixture +def api_client_with_manage_permission(user_with_manage_permission): + client = APIClient() + client.force_authenticate(user=user_with_manage_permission) + return client + + +@pytest.fixture +def user_with_user_view_permission(): + user = baker.make(User, is_staff=True) + from account.models import User as UserModel + from catalog.utils import get_model_permission + + UserService.assign_permission(user, get_model_permission(UserModel, "view")) + return user + + +@pytest.fixture +def api_client_with_user_view_permission(user_with_user_view_permission): + client = APIClient() + client.force_authenticate(user=user_with_user_view_permission) + return client + + +@pytest.fixture +def user_with_user_add_permission(): + user = baker.make(User, is_staff=True) + UserService.assign_role(user, AdminRole) + return user + + +@pytest.fixture +def api_client_with_user_add_permission(user_with_user_add_permission): + client = APIClient() + client.force_authenticate(user=user_with_user_add_permission) + return client + + +@pytest.fixture +def user_with_user_change_permission(): + user = baker.make(User, is_staff=True) + UserService.assign_role(user, AdminRole) + return user + + +@pytest.fixture +def api_client_with_user_change_permission(user_with_user_change_permission): + client = APIClient() + client.force_authenticate(user=user_with_user_change_permission) + return client + + @pytest.fixture def data_set(): return baker.make(DataSet) +@pytest.fixture +def document_source(data_set): + return baker.make(DocumentSource, data_set=data_set) + + +@pytest.fixture +def product_source(data_set): + return baker.make(ProductSource, data_set=data_set) + + @pytest.fixture def conversation(user, data_set): return baker.make(Conversation, user=user, data_set=data_set) diff --git a/server/pecl/settings.py b/server/pecl/settings.py index a0758100..3734558b 100644 --- a/server/pecl/settings.py +++ b/server/pecl/settings.py @@ -76,6 +76,7 @@ "rest_framework", "rest_framework.authtoken", "corsheaders", + "guardian", "catalog", "agent", "account", @@ -278,7 +279,10 @@ SERVICE_ACCOUNT_DOMAIN = env.str("SERVICE_ACCOUNT_DOMAIN", "enthusiast.internal") -AUTHENTICATION_BACKENDS = ("django.contrib.auth.backends.ModelBackend",) +AUTHENTICATION_BACKENDS = ( + "django.contrib.auth.backends.ModelBackend", + "guardian.backends.ObjectPermissionBackend", +) # Django social auth settings SOCIAL_AUTH_LOGIN_ERROR_URL = ( diff --git a/server/pyproject.dev.toml b/server/pyproject.dev.toml index 344b1734..c4fcd805 100644 --- a/server/pyproject.dev.toml +++ b/server/pyproject.dev.toml @@ -27,7 +27,8 @@ dependencies = [ "pypdf (>=6.0.0, <7.0.0)", "pillow (>=11.3.0,<12.0.0)", "sqlglot (>=28.5.0,<29.0.0)", - "social-auth-app-django (==5.7.0)" + "social-auth-app-django (==5.7.0)", + "django-guardian (>=2.4.0,<3.0.0)" ] diff --git a/server/pyproject.toml b/server/pyproject.toml index 024b4a5c..cc3625bd 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "pypdf (>=6.0.0, <7.0.0)", "pillow (>=11.3.0,<12.0.0)", "sqlglot (>=28.5.0,<29.0.0)", + "social-auth-app-django (==5.7.0)", + "django-guardian (>=2.4.0,<3.0.0)" ] [tool.poetry] diff --git a/server/sample.env b/server/sample.env index 9c65aba9..8e0e6d03 100644 --- a/server/sample.env +++ b/server/sample.env @@ -23,3 +23,8 @@ ECL_ADMIN_PASSWORD=changeme # === API keys === OPENAI_API_KEY= GOOGLE_API_KEY= + + +FRONTEND_BASE_URL=http://localhost:10001 +BASE_URL=http://localhost:10000 +