diff --git a/docker-compose.development.yml b/docker-compose.development.yml index 56a45646..428c477f 100644 --- a/docker-compose.development.yml +++ b/docker-compose.development.yml @@ -19,6 +19,7 @@ services: - PORT=10001 - VITE_API_BASE=http://localhost:10000 - VITE_WS_BASE=ws://localhost:10000 + - VITE_OTP_ENABLED=false ports: - "10001:10001" api: diff --git a/docker-compose.yml b/docker-compose.yml index e774aac9..ce4b0541 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: - PORT=10001 - VITE_API_BASE=http://localhost:10000 - VITE_WS_BASE=ws://localhost:10000 + - VITE_OTP_ENABLED=false ports: - "10001:10001" api: diff --git a/frontend/src/components/login/login-form.tsx b/frontend/src/components/login/login-form.tsx index 1da6dd25..f707fe06 100644 --- a/frontend/src/components/login/login-form.tsx +++ b/frontend/src/components/login/login-form.tsx @@ -1,53 +1,39 @@ -import { cn } from "@/lib/utils" +import { cn } from "@/lib/utils"; import { Input } from "@/components/ui/input.tsx"; import { Button } from "@/components/ui/button.tsx"; import { HTMLAttributes, SyntheticEvent, useState } from "react"; import { Spinner } from "@/components/util/spinner.tsx"; import { authenticationProviderInstance } from "@/lib/authentication-provider.ts"; -import { useNavigate } from "react-router-dom"; import { ApiClient } from "@/lib/api.ts"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx"; const api = new ApiClient(authenticationProviderInstance); -export function LoginForm({ className, ...props }: HTMLAttributes) { +interface LoginFormProps extends HTMLAttributes { + onSuccess: (token: string, api: ApiClient) => Promise; +} + +export function LoginForm({ className, onSuccess, ...props }: LoginFormProps) { const [isLoading, setIsLoading] = useState(false); const [isError, setIsError] = useState(false); - const navigate = useNavigate(); - async function onSubmit(event: SyntheticEvent) { + const onSubmit = async (event: SyntheticEvent) => { event.preventDefault(); setIsLoading(true); - const email = (event.target as HTMLFormElement).email.value; - const password = (event.target as HTMLFormElement).password.value; + const form = event.target as HTMLFormElement; + const email = (form.email as HTMLInputElement).value; + const password = (form.password as HTMLInputElement).value; try { const { token } = await api.login(email, password); - authenticationProviderInstance.login(token); - - const dataSets = await api.dataSets().getDataSets(); - if (dataSets.length === 0) { - const account = await api.getAccount(); - if (account.isStaff) { - navigate('/onboarding'); - } else { - navigate('/no-data-sets'); - } - } else { - const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!) - const agentId = dataSetAgents?.[0]?.id; - const page = agentId - ? `/data-sets/${dataSets[0].id}/chat/new/${agentId}` - : `/data-sets/${dataSets[0].id}/chat/new`; - navigate(page); - } + await onSuccess(token, api); } catch { setIsError(true); } finally { setIsLoading(false); } - } + }; return (
@@ -60,6 +46,7 @@ export function LoginForm({ className, ...props }: HTMLAttributes}
- -
-
- -
-
- - Or - -
-
- - ) + ); } diff --git a/frontend/src/components/login/otp-login-form.tsx b/frontend/src/components/login/otp-login-form.tsx new file mode 100644 index 00000000..bb67aa4a --- /dev/null +++ b/frontend/src/components/login/otp-login-form.tsx @@ -0,0 +1,51 @@ +import { Button } from "@/components/ui/button.tsx"; +import { useEffect, useState } from "react"; +import { authenticationProviderInstance } from "@/lib/authentication-provider.ts"; +import { ApiClient } from "@/lib/api.ts"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx"; +import { useNavigate } from "react-router-dom"; + +const api = new ApiClient(authenticationProviderInstance); + +interface OtpLoginFormProps { + onSuccess: (token: string, api: ApiClient) => Promise; +} + +export function OtpLoginForm({ onSuccess }: OtpLoginFormProps) { + const navigate = useNavigate(); + const [errorMessage, setErrorMessage] = useState(null); + const onClick = () => { + window.location.href = `${import.meta.env.VITE_API_BASE}/api/auth/otp/start` + } + useEffect(() => { + const params = new URLSearchParams(window.location.search); + const token = params.get("token"); + const error = params.get("error") + + if (token) { + onSuccess(token, api); + window.history.replaceState({}, document.title, window.location.pathname); + } + if (error) { + setErrorMessage(error); + window.history.replaceState({}, document.title, window.location.pathname); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [navigate]); + return ( +
+ {errorMessage && ( + + Login failed + {errorMessage} + + )} + +
+ ); +} diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 02a5b5b2..7b45bfb3 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -76,7 +76,7 @@ export class ApiClient { async getAllDocumentSourcePlugins(): Promise { const response = await fetch(`${this.apiBase}/api/plugins/document_source_plugins?page_size=1000`, this._requestConfiguration()); return (await response.json()).choices as SourcePlugin[]; - } + } catalog(): CatalogApiClient { return new CatalogApiClient(this.apiBase, this.authenticationProvider); @@ -110,4 +110,4 @@ export class ApiClient { } } } -} +} \ No newline at end of file diff --git a/frontend/src/pages/login.tsx b/frontend/src/pages/login.tsx index 1f5949df..ebac0691 100644 --- a/frontend/src/pages/login.tsx +++ b/frontend/src/pages/login.tsx @@ -1,8 +1,34 @@ import { LoginForm } from "@/components/login/login-form.tsx"; import logoUrl from '@/assets/logo.png'; import logoSvgUrl from '@/assets/logo.svg'; +import {authenticationProviderInstance} from "@/lib/authentication-provider.ts"; +import {ApiClient} from "@/lib/api.ts"; +import {useNavigate} from "react-router-dom"; +import {OtpLoginForm} from "@/components/login/otp-login-form.tsx"; export function LoginPage() { + const navigate = useNavigate(); + + const onSuccess = async (token:string, api: ApiClient) => { + authenticationProviderInstance.login(token); + + const dataSets = await api.dataSets().getDataSets(); + if (dataSets.length === 0) { + const account = await api.getAccount(); + if (account.isStaff) { + navigate('/onboarding'); + } else { + navigate('/no-data-sets'); + } + } else { + const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!) + const agentId = dataSetAgents?.[0]?.id; + const page = agentId + ? `/data-sets/${dataSets[0].id}/chat/new/${agentId}` + : `/data-sets/${dataSets[0].id}/chat/new`; + navigate(page); + } + } return ( <>
@@ -25,7 +51,22 @@ export function LoginPage() { Enter your email and password to get started

- + + {import.meta.env.VITE_OTP_ENABLED === "true" && ( + <> +
+
+ +
+
+ + Or + +
+
+ + + )} diff --git a/server/account/services.py b/server/account/services.py index c99bfc0e..693e7020 100644 --- a/server/account/services.py +++ b/server/account/services.py @@ -1,4 +1,7 @@ +from abc import ABC, abstractmethod + from django.conf import settings +from rest_framework.authtoken.models import Token from .models import User @@ -16,3 +19,31 @@ def is_service_account_name_available(self, name: str) -> bool: """ email = self.generate_service_account_email(name) return not User.objects.filter(email=email).exists() + + +class OTPLoginService(ABC): + NOT_CONFIGURED_ERROR_MESSAGE = "OTP service not configured." + + def create_user(self, email: str) -> User: + user, _ = User.objects.get_or_create(email=email) + return user + + def create_token(self, user: User) -> Token: + token, _ = Token.objects.get_or_create(user=user) + return token + + @abstractmethod + def get_email_from_token(self, token: str) -> str: + pass + + @abstractmethod + def get_redirect_url(self) -> str: + pass + + @abstractmethod + def get_token(self, code: str) -> str: + pass + + @abstractmethod + def is_enabled(self) -> bool: + pass diff --git a/server/account/urls.py b/server/account/urls.py index 4fb3bbca..ea5df081 100644 --- a/server/account/urls.py +++ b/server/account/urls.py @@ -7,6 +7,8 @@ urlpatterns = [ path("api/auth/login", account.views.login.LoginView.as_view(), name="login"), + path("api/auth/otp/callback", account.views.login.OTPCallbackView.as_view(), name="otp_callback"), + path("api/auth/otp/start", account.views.login.OTPStartView.as_view(), name="otp_start"), path("api/account", account.views.accounts.AccountView.as_view(), name="account"), path("api/users", account.views.users.UserListView.as_view(), name="user_list"), path("api/users/", account.views.users.UserView.as_view(), name="user_details"), diff --git a/server/account/views/login.py b/server/account/views/login.py index 8887638e..94567e4e 100644 --- a/server/account/views/login.py +++ b/server/account/views/login.py @@ -1,12 +1,20 @@ +import logging +import urllib.parse + +from django.conf import settings from django.contrib.auth import authenticate +from django.shortcuts import redirect from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework.authtoken.models import Token from rest_framework.response import Response from rest_framework.views import APIView +from utils.functions import import_from_string from account.serializers import TokenResponseSerializer +logger = logging.getLogger(__name__) + class LoginView(APIView): @swagger_auto_schema( @@ -30,3 +38,73 @@ def post(self, request): serializer = TokenResponseSerializer({"token": token.key}) return Response(serializer.data) return Response({}, status=403) + + +class OTPCallbackView(APIView): + DEFAULT_ERROR_MESSAGE = "Could not sign in. Please try again." + + @swagger_auto_schema( + operation_description="Callback endpoint for authentication provider hosted OTP login. Exchanges code for access token and redirects.", + manual_parameters=[ + openapi.Parameter( + "code", + openapi.IN_QUERY, + description="Authorization code from auth provider", + type=openapi.TYPE_STRING, + required=True, + ), + ], + responses={ + 200: openapi.Response(description="User authenticated successfully"), + 400: "Bad Request", + 401: "Unauthorized", + }, + ) + def get(self, request): + code = request.query_params.get("code") + redirect_uri = f"{settings.FRONTEND_BASE_URL}/login" + try: + otp_service_class = import_from_string(settings.OTP_AUTH_SERVICE) + otp_service = otp_service_class() + if not otp_service.is_enabled(): + return redirect( + f"{settings.FRONTEND_BASE_URL}/login?{urllib.parse.urlencode({'error': otp_service.NOT_CONFIGURED_ERROR_MESSAGE})}" + ) + token = otp_service.get_token(code) + user_email = otp_service.get_email_from_token(token) + user = otp_service.create_user(email=user_email) + token = otp_service.create_token(user=user) + return redirect(f"{redirect_uri}?{urllib.parse.urlencode({'token': token.key})}") + except Exception as e: + logger.error(e, exc_info=True) + return redirect(f"{redirect_uri}?{urllib.parse.urlencode({'error': self.DEFAULT_ERROR_MESSAGE})}") + + +class OTPStartView(APIView): + DEFAULT_ERROR_MESSAGE = "Service unavailable, please try again later." + + @swagger_auto_schema( + operation_description="Start the OTP login flow. Redirects user to the authentication provider hosted login page.", + responses={ + 302: openapi.Response(description="Redirect to authentication provider or frontend login with error"), + 400: "Bad Request", + 503: "Service Unavailable", + }, + ) + def get(self, request): + try: + otp_service_class = import_from_string(settings.OTP_AUTH_SERVICE) + otp_service = otp_service_class() + if not otp_service.is_enabled(): + return redirect( + f"{settings.FRONTEND_BASE_URL}/login?{urllib.parse.urlencode({'error': otp_service.NOT_CONFIGURED_ERROR_MESSAGE})}" + ) + auth_url = otp_service.get_redirect_url() + + return redirect(auth_url) + + except Exception as e: + logger.error(e, exc_info=True) + return redirect( + f"{settings.FRONTEND_BASE_URL}/login?{urllib.parse.urlencode({'error': self.DEFAULT_ERROR_MESSAGE})}" + ) diff --git a/server/pecl/settings.py b/server/pecl/settings.py index d0a9c817..33702e03 100644 --- a/server/pecl/settings.py +++ b/server/pecl/settings.py @@ -278,4 +278,9 @@ SERVICE_ACCOUNT_DOMAIN = env.str("SERVICE_ACCOUNT_DOMAIN", "enthusiast.internal") +OTP_AUTH_SERVICE = "" + +BASE_URL = env.str("BASE_URL", default="http://localhost:10000") +FRONTEND_BASE_URL = env.str("FRONTEND_BASE_URL", default="http://localhost:10001") + from .settings_override import * # noqa diff --git a/server/sample.env b/server/sample.env index 9c65aba9..8e0e6d03 100644 --- a/server/sample.env +++ b/server/sample.env @@ -23,3 +23,8 @@ ECL_ADMIN_PASSWORD=changeme # === API keys === OPENAI_API_KEY= GOOGLE_API_KEY= + + +FRONTEND_BASE_URL=http://localhost:10001 +BASE_URL=http://localhost:10000 +