diff --git a/frontend/src/components/login/login-form.tsx b/frontend/src/components/login/login-form.tsx index 71a6df98..13e6561f 100644 --- a/frontend/src/components/login/login-form.tsx +++ b/frontend/src/components/login/login-form.tsx @@ -1,53 +1,39 @@ -import { cn } from "@/lib/utils" +import { cn } from "@/lib/utils"; import { Input } from "@/components/ui/input.tsx"; import { Button } from "@/components/ui/button.tsx"; import { HTMLAttributes, SyntheticEvent, useState } from "react"; import { Spinner } from "@/components/util/spinner.tsx"; import { authenticationProviderInstance } from "@/lib/authentication-provider.ts"; -import { useNavigate } from "react-router-dom"; import { ApiClient } from "@/lib/api.ts"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx"; const api = new ApiClient(authenticationProviderInstance); -export function LoginForm({ className, ...props }: HTMLAttributes) { +interface LoginFormProps extends HTMLAttributes { + onSuccess: (api: ApiClient) => Promise; +} + +export function LoginForm({ className, onSuccess, ...props }: LoginFormProps) { const [isLoading, setIsLoading] = useState(false); const [isError, setIsError] = useState(false); - const navigate = useNavigate(); - async function onSubmit(event: SyntheticEvent) { + const onSubmit = async (event: SyntheticEvent) => { event.preventDefault(); setIsLoading(true); - const email = (event.target as HTMLFormElement).email.value; - const password = (event.target as HTMLFormElement).password.value; + const form = event.target as HTMLFormElement; + const email = (form.email as HTMLInputElement).value; + const password = (form.password as HTMLInputElement).value; try { await api.login(email, password); - authenticationProviderInstance.login(); - - const dataSets = await api.dataSets().getDataSets(); - if (dataSets.length === 0) { - const account = await api.getAccount(); - if (account.isStaff) { - navigate('/onboarding'); - } else { - navigate('/no-data-sets'); - } - } else { - const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!) - const agentId = dataSetAgents?.[0]?.id; - const page = agentId - ? `/data-sets/${dataSets[0].id}/chat/new/${agentId}` - : `/data-sets/${dataSets[0].id}/chat/new`; - navigate(page); - } + await onSuccess(api); } catch { setIsError(true); } finally { setIsLoading(false); } - } + }; return (
@@ -60,6 +46,7 @@ export function LoginForm({ className, ...props }: HTMLAttributes}
- -
-
- -
-
- - Or - -
-
- - ) -} + ); +} \ No newline at end of file diff --git a/frontend/src/components/login/sso-login-form.tsx b/frontend/src/components/login/sso-login-form.tsx new file mode 100644 index 00000000..e6be46be --- /dev/null +++ b/frontend/src/components/login/sso-login-form.tsx @@ -0,0 +1,37 @@ +import { Button } from "@/components/ui/button.tsx"; +import { useEffect, useState } from "react"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx"; + +export function SSOLoginForm() { + const [errorMessage, setErrorMessage] = useState(null); + + useEffect(() => { + const params = new URLSearchParams(window.location.search); + const error = params.get("error"); + if (error) { + setErrorMessage(error); + window.history.replaceState({}, document.title, window.location.pathname); + } + }, []); + + const onClick = () => { + window.location.href = `${import.meta.env.VITE_API_BASE}/login/sso/`; + }; + + return ( +
+ {errorMessage && ( + + Login failed + {errorMessage} + + )} + +
+ ); +} \ No newline at end of file diff --git a/frontend/src/pages/login.tsx b/frontend/src/pages/login.tsx index 1f5949df..436e47db 100644 --- a/frontend/src/pages/login.tsx +++ b/frontend/src/pages/login.tsx @@ -1,8 +1,48 @@ +import { useEffect } from "react"; import { LoginForm } from "@/components/login/login-form.tsx"; -import logoUrl from '@/assets/logo.png'; -import logoSvgUrl from '@/assets/logo.svg'; +import logoUrl from "@/assets/logo.png"; +import logoSvgUrl from "@/assets/logo.svg"; +import { authenticationProviderInstance } from "@/lib/authentication-provider.ts"; +import { ApiClient } from "@/lib/api.ts"; +import { useNavigate } from "react-router-dom"; +import { SSOLoginForm } from "@/components/login/sso-login-form.tsx"; + +const api = new ApiClient(authenticationProviderInstance); + +async function continueAfterLogin(navigate: (path: string) => void) { + authenticationProviderInstance.login(); + const dataSets = await api.dataSets().getDataSets(); + if (dataSets.length === 0) { + const account = await api.getAccount(); + if (account.isStaff) { + navigate("/onboarding"); + } else { + navigate("/no-data-sets"); + } + } else { + const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!); + const agentId = dataSetAgents?.[0]?.id; + const page = agentId + ? `/data-sets/${dataSets[0].id}/chat/new/${agentId}` + : `/data-sets/${dataSets[0].id}/chat/new`; + navigate(page); + } +} export function LoginPage() { + const navigate = useNavigate(); + + useEffect(() => { + api + .getAccount() + .then(() => continueAfterLogin(navigate)) + .catch(() => {}); + }, [navigate]); + + const onSuccess = async () => { + await continueAfterLogin(navigate); + }; + return ( <>
@@ -25,10 +65,23 @@ export function LoginPage() { Enter your email and password to get started

- + + <> +
+
+ +
+
+ + Or + +
+
+ + ) -} +} \ No newline at end of file diff --git a/server/account/pipelines.py b/server/account/pipelines.py new file mode 100644 index 00000000..cd955482 --- /dev/null +++ b/server/account/pipelines.py @@ -0,0 +1,22 @@ +import logging + +from django.conf import settings +from social_core.exceptions import AuthForbidden +from utils.functions import import_from_string + +logger = logging.getLogger(__name__) + + +def update_user(strategy, backend, user=None, *args, **kwargs): + if not user: + return + if kwargs.get("is_new") is False: + return + try: + sso_provider_class = import_from_string(settings.SSO_PROVIDER_SERVICE) + sso_provider_class.update_user(user) + except NotImplementedError: + return + except Exception as e: + logger.error(e, exc_info=True) + raise AuthForbidden(backend) diff --git a/server/account/services/__init__.py b/server/account/services/__init__.py new file mode 100644 index 00000000..275703f0 --- /dev/null +++ b/server/account/services/__init__.py @@ -0,0 +1,4 @@ +from .service_accounts import ServiceAccountNameService +from .sso_provider import SSOProviderService + +__all__ = ["ServiceAccountNameService", "SSOProviderService"] diff --git a/server/account/services.py b/server/account/services/service_accounts.py similarity index 94% rename from server/account/services.py rename to server/account/services/service_accounts.py index c99bfc0e..efd3d005 100644 --- a/server/account/services.py +++ b/server/account/services/service_accounts.py @@ -1,6 +1,6 @@ from django.conf import settings -from .models import User +from account.models import User class ServiceAccountNameService: diff --git a/server/account/services/sso_provider.py b/server/account/services/sso_provider.py new file mode 100644 index 00000000..e8bdb639 --- /dev/null +++ b/server/account/services/sso_provider.py @@ -0,0 +1,23 @@ +from django.conf import settings +from django.contrib.auth import get_user_model +from django.shortcuts import redirect +from utils.functions import import_from_string + +User = get_user_model() + + +class SSOProviderService: + NOT_CONFIGURED_ERROR_MESSAGE = "SSO provider service not configured." + + @staticmethod + def login(): + backend = import_from_string(settings.DEFAULT_SSO_PROVIDER_BACKEND) + return redirect(f"/login/{backend.name}/") + + @staticmethod + def is_enabled() -> bool: + return bool(getattr(settings, "DEFAULT_SSO_PROVIDER_BACKEND", "")) + + @staticmethod + def update_user(user: User): + raise NotImplementedError diff --git a/server/account/tests/__init__.py b/server/account/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/account/tests/test_pipelines.py b/server/account/tests/test_pipelines.py new file mode 100644 index 00000000..88b34125 --- /dev/null +++ b/server/account/tests/test_pipelines.py @@ -0,0 +1,52 @@ +from unittest.mock import MagicMock, patch + +import pytest +from social_core.exceptions import AuthForbidden + +from account.models import User +from account.pipelines import update_user + + +@pytest.mark.django_db +class TestUpdateUserPipeline: + def test_returns_early_when_no_user(self): + strategy = MagicMock() + backend = MagicMock() + result = update_user(strategy, backend, user=None) + assert result is None + + def test_returns_early_when_user_not_new(self): + from model_bakery import baker + + user = baker.make(User, email="existing@example.com") + strategy = MagicMock() + backend = MagicMock() + result = update_user(strategy, backend, user=user, is_new=False) + assert result is None + + def test_returns_early_on_not_implemented_error(self): + from model_bakery import baker + + user = baker.make(User, email="new@example.com") + strategy = MagicMock() + backend = MagicMock() + mock_provider_class = MagicMock() + mock_provider_class.update_user.side_effect = NotImplementedError + with patch("account.pipelines.import_from_string", return_value=mock_provider_class): + result = update_user(strategy, backend, user=user, is_new=True) + assert result is None + mock_provider_class.update_user.assert_called_once_with(user) + + def test_raises_auth_forbidden_on_other_exception(self): + from model_bakery import baker + + user = baker.make(User, email="new@example.com") + strategy = MagicMock() + backend = MagicMock() + mock_provider_class = MagicMock() + mock_provider_class.update_user.side_effect = ValueError("Update failed") + with patch("account.pipelines.import_from_string", return_value=mock_provider_class): + with pytest.raises(AuthForbidden) as exc_info: + update_user(strategy, backend, user=user, is_new=True) + assert exc_info.value.backend is backend + mock_provider_class.update_user.assert_called_once_with(user) diff --git a/server/account/tests/test_services/__init__.py b/server/account/tests/test_services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/account/tests/test_services/test_sso_provider.py b/server/account/tests/test_services/test_sso_provider.py new file mode 100644 index 00000000..68563ad0 --- /dev/null +++ b/server/account/tests/test_services/test_sso_provider.py @@ -0,0 +1,38 @@ +import pytest + +from account.models import User +from account.services.sso_provider import SSOProviderService + + +class TestSSOProviderServiceIsEnabled: + def test_returns_false_when_backend_not_configured(self, settings): + settings.DEFAULT_SSO_PROVIDER_BACKEND = "" + assert SSOProviderService.is_enabled() is False + + def test_returns_false_when_backend_setting_missing(self): + assert SSOProviderService.is_enabled() is False + + def test_returns_true_when_backend_configured(self, settings): + settings.DEFAULT_SSO_PROVIDER_BACKEND = "social_core.backends.google.GoogleOAuth2" + assert SSOProviderService.is_enabled() is True + + +class TestSSOProviderServiceLogin: + def test_redirects_to_backend_login_url(self, settings): + settings.DEFAULT_SSO_PROVIDER_BACKEND = "account.tests.test_services.test_sso_provider.FakeBackend" + response = SSOProviderService.login() + assert response.url == "/login/fake-backend/" + + +class TestSSOProviderServiceUpdateUser: + @pytest.mark.django_db + def test_raises_not_implemented_error(self): + from model_bakery import baker + + user = baker.make(User, email="test@example.com") + with pytest.raises(NotImplementedError): + SSOProviderService.update_user(user) + + +class FakeBackend: + name = "fake-backend" diff --git a/server/account/tests/test_views/__init__.py b/server/account/tests/test_views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/account/tests/test_views/test_auth.py b/server/account/tests/test_views/test_auth.py new file mode 100644 index 00000000..2f98f4da --- /dev/null +++ b/server/account/tests/test_views/test_auth.py @@ -0,0 +1,36 @@ +from unittest.mock import patch + +import pytest +from django.test import Client +from django.urls import reverse + + +@pytest.mark.django_db +class TestSSOProviderLoginView: + def test_redirects_to_frontend_with_error_when_sso_disabled(self, settings): + settings.DEFAULT_SSO_PROVIDER_BACKEND = "" + settings.FRONTEND_BASE_URL = "https://app.example.com" + client = Client() + response = client.get(reverse("sso-login")) + assert response.status_code == 302 + assert response["Location"].startswith("https://app.example.com/login?error=") + + def test_redirects_to_backend_login_when_sso_enabled(self, settings): + settings.DEFAULT_SSO_PROVIDER_BACKEND = "account.tests.test_services.test_sso_provider.FakeBackend" + settings.FRONTEND_BASE_URL = "https://app.example.com" + client = Client() + response = client.get(reverse("sso-login")) + assert response.status_code == 302 + assert response["Location"] == "/login/fake-backend/" + + def test_redirects_to_frontend_with_generic_error_on_exception(self, settings): + settings.FRONTEND_BASE_URL = "https://app.example.com" + client = Client() + with patch( + "account.views.auth.import_from_string", + side_effect=Exception("Service unavailable"), + ): + response = client.get(reverse("sso-login")) + assert response.status_code == 302 + assert response["Location"].startswith("https://app.example.com/login?error=") + assert "error=Service+unavailable" in response["Location"] diff --git a/server/account/urls.py b/server/account/urls.py index cfef675f..0a0cc583 100644 --- a/server/account/urls.py +++ b/server/account/urls.py @@ -1,14 +1,14 @@ -from django.urls import path +from django.urls import include, path import account.views.accounts -import account.views.login +import account.views.auth import account.views.service_accounts import account.views.users urlpatterns = [ - path("api/auth/login", account.views.login.LoginView.as_view(), name="login"), - path("api/auth/logout", account.views.login.SessionLogoutView.as_view(), name="logout"), - path("api/auth/csrf", account.views.login.CSRFView.as_view(), name="csrf"), + path("api/auth/login", account.views.auth.LoginView.as_view(), name="login"), + path("api/auth/logout", account.views.auth.LogoutView.as_view(), name="logout"), + path("api/auth/csrf", account.views.auth.CSRFView.as_view(), name="csrf"), path("api/account", account.views.accounts.AccountView.as_view(), name="account"), path("api/users", account.views.users.UserListView.as_view(), name="user_list"), path("api/users/", account.views.users.UserView.as_view(), name="user_details"), @@ -29,4 +29,6 @@ account.views.service_accounts.CheckServiceNameView.as_view(), name="check_service_name", ), + path("login/sso/", account.views.auth.SSOProviderLoginView.as_view(), name="sso-login"), + path("", include("social_django.urls", namespace="social")), ] diff --git a/server/account/views/login.py b/server/account/views/auth.py similarity index 55% rename from server/account/views/login.py rename to server/account/views/auth.py index d2ded526..837bee8a 100644 --- a/server/account/views/login.py +++ b/server/account/views/auth.py @@ -1,11 +1,22 @@ +import logging +import urllib +from typing import Type + +from django.conf import settings from django.contrib.auth import authenticate, login, logout -from django.http import JsonResponse +from django.http import HttpRequest, HttpResponse, JsonResponse from django.middleware.csrf import get_token +from django.shortcuts import redirect from django.views import View from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework.response import Response from rest_framework.views import APIView +from utils.functions import import_from_string + +from account.services.sso_provider import SSOProviderService + +logger = logging.getLogger(__name__) class LoginView(APIView): @@ -32,7 +43,7 @@ def post(self, request): return Response({}, status=403) -class SessionLogoutView(APIView): +class LogoutView(APIView): @swagger_auto_schema(operation_description="Log out (clear session)", responses={200: "OK"}) def post(self, request): logout(request) @@ -43,3 +54,21 @@ class CSRFView(View): def get(self, request): get_token(request) return JsonResponse({"csrfToken": get_token(request)}) + + +class SSOProviderLoginView(View): + DEFAULT_ERROR_MESSAGE = "Service unavailable, please try again later." + + def get(self, request: HttpRequest) -> HttpResponse: + try: + sso_provider_service: Type[SSOProviderService] = import_from_string(settings.SSO_PROVIDER_SERVICE) + if not sso_provider_service.is_enabled(): + return redirect( + f"{settings.FRONTEND_BASE_URL}/login?{urllib.parse.urlencode({'error': sso_provider_service.NOT_CONFIGURED_ERROR_MESSAGE})}" + ) + return SSOProviderService.login() + except Exception as e: + logger.error(e, exc_info=True) + return redirect( + f"{settings.FRONTEND_BASE_URL}/login?{urllib.parse.urlencode({'error': self.DEFAULT_ERROR_MESSAGE})}" + ) diff --git a/server/account/views/service_accounts.py b/server/account/views/service_accounts.py index 1f081f6d..c3914801 100644 --- a/server/account/views/service_accounts.py +++ b/server/account/views/service_accounts.py @@ -14,7 +14,7 @@ ServiceAccountSerializer, TokenResponseSerializer, ) -from account.services import ServiceAccountNameService +from account.services.service_accounts import ServiceAccountNameService class CheckServiceNameView(APIView): diff --git a/server/pecl/settings.py b/server/pecl/settings.py index 0335ef41..a0758100 100644 --- a/server/pecl/settings.py +++ b/server/pecl/settings.py @@ -13,10 +13,12 @@ import json import os import sys +import urllib from datetime import timedelta from pathlib import Path from environ import Env +from social_core.pipeline import DEFAULT_AUTH_PIPELINE # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent @@ -80,6 +82,7 @@ "sync", "drf_yasg", "django_filters", + "social_django", ] MIDDLEWARE = [ @@ -92,6 +95,7 @@ "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", + "social_django.middleware.SocialAuthExceptionMiddleware", ] ROOT_URLCONF = "pecl.urls" @@ -274,4 +278,23 @@ SERVICE_ACCOUNT_DOMAIN = env.str("SERVICE_ACCOUNT_DOMAIN", "enthusiast.internal") +AUTHENTICATION_BACKENDS = ("django.contrib.auth.backends.ModelBackend",) + +# Django social auth settings +SOCIAL_AUTH_LOGIN_ERROR_URL = ( + f"http://localhost:10001/login?{urllib.parse.urlencode({'error': 'Could not sign in. Please try again.'})}" +) +SOCIAL_AUTH_RAISE_EXCEPTIONS = False +SOCIAL_AUTH_JSONFIELD_ENABLED = True +SOCIAL_AUTH_PIPELINE = DEFAULT_AUTH_PIPELINE + ("account.pipelines.update_user",) +LOGIN_URL = "/login/" +LOGIN_REDIRECT_URL = "http://localhost:10001/" + +# SSO providers +SSO_PROVIDER_SERVICE = "account.services.sso_provider.SSOProviderService" +DEFAULT_SSO_PROVIDER_BACKEND = "" + +FRONTEND_BASE_URL = env.str("FRONTEND_BASE_URL", "http://localhost:10001") + + from .settings_override import * # noqa diff --git a/server/pyproject.dev.toml b/server/pyproject.dev.toml index 1d44126c..344b1734 100644 --- a/server/pyproject.dev.toml +++ b/server/pyproject.dev.toml @@ -27,6 +27,7 @@ dependencies = [ "pypdf (>=6.0.0, <7.0.0)", "pillow (>=11.3.0,<12.0.0)", "sqlglot (>=28.5.0,<29.0.0)", + "social-auth-app-django (==5.7.0)" ]