Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 18 additions & 45 deletions frontend/src/components/login/login-form.tsx
Original file line number Diff line number Diff line change
@@ -1,53 +1,39 @@
import { cn } from "@/lib/utils"
import { cn } from "@/lib/utils";
import { Input } from "@/components/ui/input.tsx";
import { Button } from "@/components/ui/button.tsx";
import { HTMLAttributes, SyntheticEvent, useState } from "react";
import { Spinner } from "@/components/util/spinner.tsx";
import { authenticationProviderInstance } from "@/lib/authentication-provider.ts";
import { useNavigate } from "react-router-dom";
import { ApiClient } from "@/lib/api.ts";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx";

const api = new ApiClient(authenticationProviderInstance);

export function LoginForm({ className, ...props }: HTMLAttributes<HTMLDivElement>) {
interface LoginFormProps extends HTMLAttributes<HTMLDivElement> {
onSuccess: (api: ApiClient) => Promise<void>;
}

export function LoginForm({ className, onSuccess, ...props }: LoginFormProps) {
const [isLoading, setIsLoading] = useState(false);
const [isError, setIsError] = useState(false);
const navigate = useNavigate();

async function onSubmit(event: SyntheticEvent) {
const onSubmit = async (event: SyntheticEvent) => {
event.preventDefault();
setIsLoading(true);

const email = (event.target as HTMLFormElement).email.value;
const password = (event.target as HTMLFormElement).password.value;
const form = event.target as HTMLFormElement;
const email = (form.email as HTMLInputElement).value;
const password = (form.password as HTMLInputElement).value;

try {
await api.login(email, password);
authenticationProviderInstance.login();

const dataSets = await api.dataSets().getDataSets();
if (dataSets.length === 0) {
const account = await api.getAccount();
if (account.isStaff) {
navigate('/onboarding');
} else {
navigate('/no-data-sets');
}
} else {
const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!)
const agentId = dataSetAgents?.[0]?.id;
const page = agentId
? `/data-sets/${dataSets[0].id}/chat/new/${agentId}`
: `/data-sets/${dataSets[0].id}/chat/new`;
navigate(page);
}
await onSuccess(api);
} catch {
setIsError(true);
} finally {
setIsLoading(false);
}
}
};

return (
<div className={cn("grid gap-6", className)} {...props}>
Expand All @@ -60,6 +46,7 @@ export function LoginForm({ className, ...props }: HTMLAttributes<HTMLDivElement
</Alert>}
<Input
id="email"
name="email"
placeholder="user@example.com"
type="email"
autoCapitalize="none"
Expand All @@ -69,6 +56,7 @@ export function LoginForm({ className, ...props }: HTMLAttributes<HTMLDivElement
/>
<Input
id="password"
name="password"
placeholder="password"
type="password"
autoCapitalize="none"
Expand All @@ -77,27 +65,12 @@ export function LoginForm({ className, ...props }: HTMLAttributes<HTMLDivElement
disabled={isLoading}
/>
</div>
<Button disabled={isLoading}>
{isLoading && (
<Spinner />
)}
<Button type="submit" disabled={isLoading}>
{isLoading && <Spinner />}
Continue
</Button>
</div>
</form>
<div className="relative">
<div className="absolute inset-0 flex items-center">
<span className="w-full border-t"/>
</div>
<div className="relative flex justify-center text-xs uppercase">
<span className="bg-background px-2 text-muted-foreground">
Or
</span>
</div>
</div>
<Button disabled variant="outline">
Log in with SSO
</Button>
</div>
)
}
);
}
37 changes: 37 additions & 0 deletions frontend/src/components/login/sso-login-form.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { Button } from "@/components/ui/button.tsx";
import { useEffect, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert.tsx";

export function SSOLoginForm() {
const [errorMessage, setErrorMessage] = useState<string | null>(null);

useEffect(() => {
const params = new URLSearchParams(window.location.search);
const error = params.get("error");
if (error) {
setErrorMessage(error);
window.history.replaceState({}, document.title, window.location.pathname);
}
}, []);

const onClick = () => {
window.location.href = `${import.meta.env.VITE_API_BASE}/login/sso/`;
};

return (
<div className="grid gap-1">
{errorMessage && (
<Alert variant="destructive">
<AlertTitle>Login failed</AlertTitle>
<AlertDescription>{errorMessage}</AlertDescription>
</Alert>
)}
<Button
variant="outline"
onClick={onClick}
>
Log in with SSO
</Button>
</div>
);
}
61 changes: 57 additions & 4 deletions frontend/src/pages/login.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,48 @@
import { useEffect } from "react";
import { LoginForm } from "@/components/login/login-form.tsx";
import logoUrl from '@/assets/logo.png';
import logoSvgUrl from '@/assets/logo.svg';
import logoUrl from "@/assets/logo.png";
import logoSvgUrl from "@/assets/logo.svg";
import { authenticationProviderInstance } from "@/lib/authentication-provider.ts";
import { ApiClient } from "@/lib/api.ts";
import { useNavigate } from "react-router-dom";
import { SSOLoginForm } from "@/components/login/sso-login-form.tsx";

const api = new ApiClient(authenticationProviderInstance);

async function continueAfterLogin(navigate: (path: string) => void) {
authenticationProviderInstance.login();
const dataSets = await api.dataSets().getDataSets();
if (dataSets.length === 0) {
const account = await api.getAccount();
if (account.isStaff) {
navigate("/onboarding");
} else {
navigate("/no-data-sets");
}
} else {
const dataSetAgents = await api.agents().getDatasetAvailableAgents(dataSets[0].id!);
const agentId = dataSetAgents?.[0]?.id;
const page = agentId
? `/data-sets/${dataSets[0].id}/chat/new/${agentId}`
: `/data-sets/${dataSets[0].id}/chat/new`;
navigate(page);
}
}

export function LoginPage() {
const navigate = useNavigate();

useEffect(() => {
api
.getAccount()
.then(() => continueAfterLogin(navigate))
.catch(() => {});
}, [navigate]);

const onSuccess = async () => {
await continueAfterLogin(navigate);
};

return (
<>
<div className="container relative hidden h-full flex-col items-center justify-center md:grid lg:max-w-none lg:grid-cols-2 lg:px-0">
Expand All @@ -25,10 +65,23 @@ export function LoginPage() {
Enter your email and password to get started
</p>
</div>
<LoginForm />
<LoginForm onSuccess={onSuccess}/>
<>
<div className="relative">
<div className="absolute inset-0 flex items-center">
<span className="w-full border-t" />
</div>
<div className="relative flex justify-center text-xs uppercase">
<span className="bg-background px-2 text-muted-foreground">
Or
</span>
</div>
</div>
<SSOLoginForm />
</>
</div>
</div>
</div>
</>
)
}
}
22 changes: 22 additions & 0 deletions server/account/pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import logging

from django.conf import settings
from social_core.exceptions import AuthForbidden
from utils.functions import import_from_string

logger = logging.getLogger(__name__)


def update_user(strategy, backend, user=None, *args, **kwargs):
if not user:
return
if kwargs.get("is_new") is False:
return
try:
sso_provider_class = import_from_string(settings.SSO_PROVIDER_SERVICE)
sso_provider_class.update_user(user)
except NotImplementedError:
return
except Exception as e:
logger.error(e, exc_info=True)
raise AuthForbidden(backend)
4 changes: 4 additions & 0 deletions server/account/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .service_accounts import ServiceAccountNameService
from .sso_provider import SSOProviderService

__all__ = ["ServiceAccountNameService", "SSOProviderService"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.conf import settings

from .models import User
from account.models import User


class ServiceAccountNameService:
Expand Down
23 changes: 23 additions & 0 deletions server/account/services/sso_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from django.conf import settings
from django.contrib.auth import get_user_model
from django.shortcuts import redirect
from utils.functions import import_from_string

User = get_user_model()


class SSOProviderService:
NOT_CONFIGURED_ERROR_MESSAGE = "SSO provider service not configured."

@staticmethod
def login():
backend = import_from_string(settings.DEFAULT_SSO_PROVIDER_BACKEND)
return redirect(f"/login/{backend.name}/")

@staticmethod
def is_enabled() -> bool:
return bool(getattr(settings, "DEFAULT_SSO_PROVIDER_BACKEND", ""))

@staticmethod
def update_user(user: User):
raise NotImplementedError
Empty file.
52 changes: 52 additions & 0 deletions server/account/tests/test_pipelines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from unittest.mock import MagicMock, patch

import pytest
from social_core.exceptions import AuthForbidden

from account.models import User
from account.pipelines import update_user


@pytest.mark.django_db
class TestUpdateUserPipeline:
def test_returns_early_when_no_user(self):
strategy = MagicMock()
backend = MagicMock()
result = update_user(strategy, backend, user=None)
assert result is None

def test_returns_early_when_user_not_new(self):
from model_bakery import baker

user = baker.make(User, email="existing@example.com")
strategy = MagicMock()
backend = MagicMock()
result = update_user(strategy, backend, user=user, is_new=False)
assert result is None

def test_returns_early_on_not_implemented_error(self):
from model_bakery import baker

user = baker.make(User, email="new@example.com")
strategy = MagicMock()
backend = MagicMock()
mock_provider_class = MagicMock()
mock_provider_class.update_user.side_effect = NotImplementedError
with patch("account.pipelines.import_from_string", return_value=mock_provider_class):
result = update_user(strategy, backend, user=user, is_new=True)
assert result is None
mock_provider_class.update_user.assert_called_once_with(user)

def test_raises_auth_forbidden_on_other_exception(self):
from model_bakery import baker

user = baker.make(User, email="new@example.com")
strategy = MagicMock()
backend = MagicMock()
mock_provider_class = MagicMock()
mock_provider_class.update_user.side_effect = ValueError("Update failed")
with patch("account.pipelines.import_from_string", return_value=mock_provider_class):
with pytest.raises(AuthForbidden) as exc_info:
update_user(strategy, backend, user=user, is_new=True)
assert exc_info.value.backend is backend
mock_provider_class.update_user.assert_called_once_with(user)
Empty file.
38 changes: 38 additions & 0 deletions server/account/tests/test_services/test_sso_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from account.models import User
from account.services.sso_provider import SSOProviderService


class TestSSOProviderServiceIsEnabled:
def test_returns_false_when_backend_not_configured(self, settings):
settings.DEFAULT_SSO_PROVIDER_BACKEND = ""
assert SSOProviderService.is_enabled() is False

def test_returns_false_when_backend_setting_missing(self):
assert SSOProviderService.is_enabled() is False

def test_returns_true_when_backend_configured(self, settings):
settings.DEFAULT_SSO_PROVIDER_BACKEND = "social_core.backends.google.GoogleOAuth2"
assert SSOProviderService.is_enabled() is True


class TestSSOProviderServiceLogin:
def test_redirects_to_backend_login_url(self, settings):
settings.DEFAULT_SSO_PROVIDER_BACKEND = "account.tests.test_services.test_sso_provider.FakeBackend"
response = SSOProviderService.login()
assert response.url == "/login/fake-backend/"


class TestSSOProviderServiceUpdateUser:
@pytest.mark.django_db
def test_raises_not_implemented_error(self):
from model_bakery import baker

user = baker.make(User, email="test@example.com")
with pytest.raises(NotImplementedError):
SSOProviderService.update_user(user)


class FakeBackend:
name = "fake-backend"
Empty file.
Loading