diff --git a/migrations/1776346034220_invitation_is_trial.ts b/migrations/1776346034220_invitation_is_trial.ts new file mode 100644 index 000000000..906d993b7 --- /dev/null +++ b/migrations/1776346034220_invitation_is_trial.ts @@ -0,0 +1,15 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { type Kysely, sql } from "kysely"; + +export async function up(db: Kysely): Promise { + await db.schema + .alterTable("invitation") + .addColumn("isTrial", "boolean", (col) => + col.notNull().defaultTo(sql`false`), + ) + .execute(); +} + +export async function down(db: Kysely): Promise { + await db.schema.alterTable("invitation").dropColumn("isTrial").execute(); +} diff --git a/migrations/1776346034221_user_trial_ends_at.ts b/migrations/1776346034221_user_trial_ends_at.ts new file mode 100644 index 000000000..dcd8696aa --- /dev/null +++ b/migrations/1776346034221_user_trial_ends_at.ts @@ -0,0 +1,10 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import type { Kysely } from "kysely"; + +export async function up(db: Kysely): Promise { + await db.schema.alterTable("user").addColumn("trialEndsAt", "text").execute(); +} + +export async function down(db: Kysely): Promise { + await db.schema.alterTable("user").dropColumn("trialEndsAt").execute(); +} diff --git a/src/app/(private)/layout.tsx b/src/app/(private)/layout.tsx index 2dc777f4a..351e4fbbc 100644 --- a/src/app/(private)/layout.tsx +++ b/src/app/(private)/layout.tsx @@ -1,6 +1,7 @@ import { getServerSession } from "@/auth"; import { redirectToLogin } from "@/auth/redirectToLogin"; import SentryFeedbackWidget from "@/components/SentryFeedbackWidget"; +import TrialBanner from "@/components/TrialBanner"; import type { Metadata } from "next"; export const metadata: Metadata = { @@ -16,8 +17,11 @@ export default async function PrivateLayout({ if (!serverSession.currentUser) { await redirectToLogin(); } + + const { trialEndsAt } = serverSession.currentUser ?? {}; return ( <> + {trialEndsAt && } {children} diff --git a/src/app/HTMLBody.tsx b/src/app/HTMLBody.tsx new file mode 100644 index 000000000..56daf78fd --- /dev/null +++ b/src/app/HTMLBody.tsx @@ -0,0 +1,29 @@ +import { IBM_Plex_Mono, IBM_Plex_Sans } from "next/font/google"; +import "./global.css"; + +const ibmPlexSans = IBM_Plex_Sans({ + subsets: ["latin"], + weight: ["300", "400", "500", "600", "700"], + variable: "--font-ibm-plex-sans", + display: "swap", + preload: true, +}); + +const ibmPlexMono = IBM_Plex_Mono({ + subsets: ["latin"], + weight: ["400", "500"], + variable: "--font-ibm-plex-mono", + display: "swap", + preload: true, +}); + +export default function HTMLBody({ children }: { children: React.ReactNode }) { + return ( + + {children} + + ); +} diff --git a/src/app/global-error.tsx b/src/app/global-error.tsx index 9388e06e0..1ca45e1dd 100644 --- a/src/app/global-error.tsx +++ b/src/app/global-error.tsx @@ -2,7 +2,51 @@ import * as Sentry from "@sentry/nextjs"; import NextError from "next/error"; +import Link from "next/link"; import { useEffect } from "react"; +import { logout } from "@/auth/logout"; +import { TRIAL_EXPIRED_MESSAGE } from "@/constants"; +import { Button } from "@/shadcn/ui/button"; +import { Card, CardContent, CardTitle } from "@/shadcn/ui/card"; +import HTMLBody from "./HTMLBody"; + +function isTrialExpiredError(error: Error) { + return error.name === "TRPCError" && error.message === TRIAL_EXPIRED_MESSAGE; +} + +function TrialExpired() { + return ( +
+
+
+ + {/* eslint-disable-next-line @next/next/no-img-element */} + Mapped + +
+
+
+ + + Trial Expired +

+ Your trial period has ended. Please contact us to continue using + Mapped. +

+
+ + +
+
+
+
+
+ ); +} export default function GlobalError({ error, @@ -10,18 +54,26 @@ export default function GlobalError({ error: Error & { digest?: string }; }) { useEffect(() => { - Sentry.captureException(error); + if (!isTrialExpiredError(error)) { + Sentry.captureException(error); + } }, [error]); + if (isTrialExpiredError(error)) { + return ( + + + + ); + } + return ( - - - {/* `NextError` is the default Next.js error page component. Its type - definition requires a `statusCode` prop. However, since the App Router - does not expose status codes for errors, we simply pass 0 to render a - generic error message. */} - - - + + {/* `NextError` is the default Next.js error page component. Its type + definition requires a `statusCode` prop. However, since the App Router + does not expose status codes for errors, we simply pass 0 to render a + generic error message. */} + + ); } diff --git a/src/app/layout.tsx b/src/app/layout.tsx index 04bd83801..796258767 100644 --- a/src/app/layout.tsx +++ b/src/app/layout.tsx @@ -1,5 +1,4 @@ import { HydrationBoundary, dehydrate } from "@tanstack/react-query"; -import { IBM_Plex_Mono, IBM_Plex_Sans } from "next/font/google"; import { cookies } from "next/headers"; import { getServerSession } from "@/auth"; import { ORGANISATION_COOKIE_NAME } from "@/constants"; @@ -11,26 +10,10 @@ import { TRPCReactProvider } from "@/services/trpc/react"; import { createCaller, getQueryClient, trpc } from "@/services/trpc/server"; import { Toaster } from "@/shadcn/ui/sonner"; import { getAbsoluteUrl } from "@/utils/appUrl"; +import HTMLBody from "./HTMLBody"; import type { Organisation } from "@/models/Organisation"; import type { Metadata, Viewport } from "next"; import "nprogress/nprogress.css"; -import "./global.css"; - -const ibmPlexSans = IBM_Plex_Sans({ - subsets: ["latin"], - weight: ["300", "400", "500", "600", "700"], - variable: "--font-ibm-plex-sans", - display: "swap", - preload: true, -}); - -const ibmPlexMono = IBM_Plex_Mono({ - subsets: ["latin"], - weight: ["400", "500"], - variable: "--font-ibm-plex-mono", - display: "swap", - preload: true, -}); export const metadata: Metadata = { metadataBase: new URL(getAbsoluteUrl()), @@ -65,31 +48,24 @@ export default async function RootLayout({ const storedOrgId = cookieStore.get(ORGANISATION_COOKIE_NAME)?.value ?? null; return ( - - - - - - - - -
- {children} -
- -
-
-
-
-
-
- - + + + + + + + +
{children}
+ +
+
+
+
+
+
+
); } diff --git a/src/auth/index.ts b/src/auth/index.ts index 15dd0cf5d..22f6810a3 100644 --- a/src/auth/index.ts +++ b/src/auth/index.ts @@ -19,6 +19,7 @@ export const getServerSession = cache(async (): Promise => { name: user.name, avatarUrl: user.avatarUrl, role: user.role, + trialEndsAt: user.trialEndsAt, }, }; } diff --git a/src/auth/logout.ts b/src/auth/logout.ts new file mode 100644 index 000000000..81cc0f885 --- /dev/null +++ b/src/auth/logout.ts @@ -0,0 +1,11 @@ +import { JWT_LIFETIME_SECONDS } from "@/constants"; + +export async function logout() { + try { + await fetch("/api/logout", { method: "POST" }); + } catch { + // Server unavailable so JWT cookie may not be removed - set client side LoggedOut cookie + document.cookie = `LoggedOut=1; path=/; SameSite=lax; max-age=${JWT_LIFETIME_SECONDS}`; + } + window.location.href = "/"; +} diff --git a/src/authTypes.ts b/src/authTypes.ts index 1a021ee21..d8eba6cdb 100644 --- a/src/authTypes.ts +++ b/src/authTypes.ts @@ -6,6 +6,7 @@ export interface CurrentUser { name: string; avatarUrl?: string | null; role?: UserRole | null; + trialEndsAt?: Date | null; } export interface ServerSession { diff --git a/src/components/SidebarUserMenu.tsx b/src/components/SidebarUserMenu.tsx index 688ce1dc1..a7fe86993 100644 --- a/src/components/SidebarUserMenu.tsx +++ b/src/components/SidebarUserMenu.tsx @@ -7,7 +7,7 @@ import { LogOutIcon, SettingsIcon, } from "lucide-react"; -import { JWT_LIFETIME_SECONDS } from "@/constants"; +import { logout } from "@/auth/logout"; import { useCurrentUser } from "@/hooks"; import { useOrganisations } from "@/hooks/useOrganisations"; import { Avatar, AvatarFallback, AvatarImage } from "@/shadcn/ui/avatar"; @@ -40,13 +40,7 @@ export default function SidebarUserMenu() { const onSubmitLogout = async (e: SyntheticEvent) => { e.preventDefault(); - try { - await fetch("/api/logout", { method: "POST" }); - } catch { - // Server unavailable so JWT cookie may not be removed - set client side LoggedOut cookie - document.cookie = `LoggedOut=1; path=/; SameSite=lax; max-age=${JWT_LIFETIME_SECONDS}`; - } - location.href = "/"; + await logout(); }; return ( diff --git a/src/components/TrialBanner.tsx b/src/components/TrialBanner.tsx new file mode 100644 index 000000000..b1357f161 --- /dev/null +++ b/src/components/TrialBanner.tsx @@ -0,0 +1,49 @@ +"use client"; + +import { useState } from "react"; +import { Alert, AlertDescription } from "@/shadcn/ui/alert"; + +const DISMISSED_KEY = "mapped-trial-banner-dismissed"; + +function getDaysRemaining(trialEndsAt: Date) { + const ms = new Date(trialEndsAt).getTime() - Date.now(); + if (ms <= 0) return null; + return Math.ceil(ms / (1000 * 60 * 60 * 24)); +} + +export default function TrialBanner({ trialEndsAt }: { trialEndsAt: Date }) { + const [dismissed, setDismissed] = useState(() => { + if (typeof window === "undefined") return false; + return localStorage.getItem(DISMISSED_KEY) === "true"; + }); + // useState (not useMemo) because Date.now() triggers the react-hooks/purity lint rule + const [daysRemaining] = useState(() => getDaysRemaining(trialEndsAt)); + + if (daysRemaining === null || dismissed) { + return null; + } + + function handleDismiss() { + localStorage.setItem(DISMISSED_KEY, "true"); + setDismissed(true); + } + + return ( + + + + You're on a trial period.{" "} + {daysRemaining === 1 + ? "1 day remaining." + : `${daysRemaining} days remaining.`} + + + + + ); +} diff --git a/src/constants/index.ts b/src/constants/index.ts index c32b67408..902d563f4 100644 --- a/src/constants/index.ts +++ b/src/constants/index.ts @@ -13,6 +13,8 @@ export const DATA_RECORDS_JOB_BATCH_SIZE = 100; export const DEFAULT_AUTH_REDIRECT = "/maps"; +export const DEFAULT_TRIAL_PERIOD_DAYS = 30; + export const DEFAULT_ZOOM = 5; export const DEFAULT_CUSTOM_COLOR = "#3b82f6"; @@ -42,3 +44,5 @@ export const ORGANISATION_COOKIE_NAME = "MappedOrgId"; export const SORT_BY_LOCATION = "__location"; // Special sort column to sort by `dataSource.columnRoles.nameColumns` export const SORT_BY_NAME_COLUMNS = "__name"; + +export const TRIAL_EXPIRED_MESSAGE = "Your trial has expired."; diff --git a/src/models/Invitation.ts b/src/models/Invitation.ts index 53a09b146..04d327d1c 100644 --- a/src/models/Invitation.ts +++ b/src/models/Invitation.ts @@ -10,6 +10,7 @@ export const invitationSchema = z.object({ createdAt: z.date(), updatedAt: z.date(), used: z.boolean(), + isTrial: z.boolean(), }); export type Invitation = z.infer; diff --git a/src/models/User.ts b/src/models/User.ts index 09c98c495..35fdf26d3 100644 --- a/src/models/User.ts +++ b/src/models/User.ts @@ -18,6 +18,7 @@ export const userSchema = z.object({ avatarUrl: z.string().url().trim().nullish(), passwordHash: z.string(), role: z.nativeEnum(UserRole).nullish(), + trialEndsAt: z.date().nullish(), }); export type User = z.infer; diff --git a/src/server/models/Invitation.ts b/src/server/models/Invitation.ts index 59e8cf893..7d2f33f93 100644 --- a/src/server/models/Invitation.ts +++ b/src/server/models/Invitation.ts @@ -10,6 +10,7 @@ import type { export type InvitationTable = Invitation & { id: GeneratedAlways; used: Generated; + isTrial: Generated; createdAt: ColumnType; updatedAt: ColumnType; }; diff --git a/src/server/models/User.ts b/src/server/models/User.ts index 069b75d49..f989b5e06 100644 --- a/src/server/models/User.ts +++ b/src/server/models/User.ts @@ -9,6 +9,7 @@ import type { export type UserTable = User & { id: GeneratedAlways; createdAt: ColumnType; + trialEndsAt: ColumnType; }; export type NewUser = Insertable; export type UserUpdate = Updateable; diff --git a/src/server/repositories/User.ts b/src/server/repositories/User.ts index c6fe88b17..057e15c6b 100644 --- a/src/server/repositories/User.ts +++ b/src/server/repositories/User.ts @@ -83,6 +83,15 @@ export function listUsers() { .execute(); } +export async function updateUserTrialEndsAt(id: string, trialEndsAt: Date) { + return db + .updateTable("user") + .where("id", "=", id) + .set({ trialEndsAt: trialEndsAt.toISOString() }) + .returningAll() + .executeTakeFirstOrThrow(); +} + export async function updateUserRole(id: string, role: UserRole | null) { return db .updateTable("user") diff --git a/src/server/services/database/schema.ts b/src/server/services/database/schema.ts index 34a8c4c68..cc804b4cb 100644 --- a/src/server/services/database/schema.ts +++ b/src/server/services/database/schema.ts @@ -66,6 +66,7 @@ export interface User { name: string; // text, NOT NULL, DEFAULT '' avatarUrl: string | null; // text, NULL role: string | null; // text, NULL — 'Advocate' or 'Superadmin' + trialEndsAt: string | null; // text, NULL — ISO timestamp when trial expires createdAt: string; // text, DEFAULT CURRENT_TIMESTAMP, NOT NULL } @@ -107,6 +108,7 @@ export interface Invitation { organisationId: string; // uuid, NOT NULL userId: string | null; // uuid, NULL used: boolean; // boolean, NOT NULL, DEFAULT false + isTrial: boolean; // boolean, NOT NULL, DEFAULT false createdAt: string; // text, DEFAULT CURRENT_TIMESTAMP, NOT NULL updatedAt: string; // text, DEFAULT CURRENT_TIMESTAMP, NOT NULL diff --git a/src/server/trpc/index.ts b/src/server/trpc/index.ts index eb3e51412..3d8d204d9 100644 --- a/src/server/trpc/index.ts +++ b/src/server/trpc/index.ts @@ -2,6 +2,7 @@ import { TRPCError, initTRPC } from "@trpc/server"; import superjson from "superjson"; import z, { ZodError } from "zod"; import { getServerSession } from "@/auth"; +import { TRIAL_EXPIRED_MESSAGE } from "@/constants"; import { UserRole } from "@/models/User"; import { getClientIp } from "@/server/services/ratelimit"; import { canReadDataSource } from "@/server/utils/auth"; @@ -70,6 +71,12 @@ const enforceUserIsAuthed = t.middleware(({ ctx, next }) => { code: "UNAUTHORIZED", message: "You must be logged in to perform this action.", }); + if (ctx.user.trialEndsAt && new Date(ctx.user.trialEndsAt) < new Date()) { + throw new TRPCError({ + code: "FORBIDDEN", + message: TRIAL_EXPIRED_MESSAGE, + }); + } return next({ ctx: { user: ctx.user } }); }); diff --git a/src/server/trpc/routers/auth.ts b/src/server/trpc/routers/auth.ts index 63cf01d70..137a7c619 100644 --- a/src/server/trpc/routers/auth.ts +++ b/src/server/trpc/routers/auth.ts @@ -4,6 +4,7 @@ import { JWTExpired } from "jose/errors"; import { NoResultError } from "kysely"; import z from "zod"; import { setJWT } from "@/auth/jwt"; +import { DEFAULT_TRIAL_PERIOD_DAYS } from "@/constants"; import { passwordSchema } from "@/models/User"; import ForgotPassword from "@/server/emails/ForgotPassword"; import { @@ -15,6 +16,7 @@ import { findUserByEmail, findUserByToken, updateUser, + updateUserTrialEndsAt, upsertUser, } from "@/server/repositories/User"; import logger from "@/server/services/logger"; @@ -40,12 +42,20 @@ export const authRouter = router({ const invitation = await findAndUseInvitation(payload.invitationId); // Create user with provided password - const user = await upsertUser({ + let user = await upsertUser({ email: invitation.email, name: invitation.name, password, }); + // Set trial end date for trial invitations + if (invitation.isTrial && !user.trialEndsAt) { + const trialEndsAt = new Date( + Date.now() + DEFAULT_TRIAL_PERIOD_DAYS * 24 * 60 * 60 * 1000, + ); + user = await updateUserTrialEndsAt(user.id, trialEndsAt); + } + // Link user to organisation await upsertOrganisationUser({ organisationId: invitation.organisationId, diff --git a/src/server/trpc/routers/invitation.ts b/src/server/trpc/routers/invitation.ts index 9796498fc..05c9ac5de 100644 --- a/src/server/trpc/routers/invitation.ts +++ b/src/server/trpc/routers/invitation.ts @@ -1,6 +1,7 @@ import { TRPCError } from "@trpc/server"; import { SignJWT } from "jose"; import z from "zod"; +import { UserRole } from "@/models/User"; import copyMapsToOrganisation from "@/server/commands/copyMapsToOrganisation"; import ensureOrganisationMap from "@/server/commands/ensureOrganisationMap"; import Invite from "@/server/emails/Invite"; @@ -80,6 +81,7 @@ export const invitationRouter = router({ name: input.name, organisationId: org.id, senderOrganisationId: senderOrg.id, + isTrial: ctx.user.role !== UserRole.Superadmin, }); const secret = new TextEncoder().encode(process.env.JWT_SECRET || ""); diff --git a/src/services/trpc/react.tsx b/src/services/trpc/react.tsx index fe95805b6..9ab229f6d 100644 --- a/src/services/trpc/react.tsx +++ b/src/services/trpc/react.tsx @@ -11,6 +11,7 @@ import { observable } from "@trpc/server/observable"; import { createTRPCContext } from "@trpc/tanstack-react-query"; import { useState } from "react"; import superjson from "superjson"; +import { DEFAULT_AUTH_REDIRECT } from "@/constants"; import { clientDataSourceSerializer, hasPasswordHashSerializer, @@ -59,16 +60,18 @@ const errorLink: TRPCLink = () => { observer.next(value); }, error(err) { - if ( - err instanceof TRPCClientError && - err.data?.code === "UNAUTHORIZED" && - typeof window !== "undefined" - ) { - const redirectTo = encodeURIComponent( - window.location.pathname + window.location.search, - ); - window.location.href = `/login?redirectTo=${redirectTo}`; - return; + if (err instanceof TRPCClientError && typeof window !== "undefined") { + if (err.data?.code === "UNAUTHORIZED") { + const redirectTo = encodeURIComponent( + window.location.pathname + window.location.search, + ); + window.location.href = `/login?redirectTo=${redirectTo}`; + return; + } + if (err.data?.code === "FORBIDDEN") { + window.location.href = DEFAULT_AUTH_REDIRECT; + return; + } } observer.error(err); }, diff --git a/tests/unit/app/api/login/route.test.ts b/tests/unit/app/api/login/route.test.ts index bd890a794..164501f3d 100644 --- a/tests/unit/app/api/login/route.test.ts +++ b/tests/unit/app/api/login/route.test.ts @@ -118,6 +118,7 @@ describe("POST /api/login", () => { passwordHash: "", avatarUrl: undefined, role: undefined, + trialEndsAt: null, }); const request = makeRequest( { email: "user@example.com", password: "correctpassword" }, diff --git a/tests/unit/server/trpc/middleware.test.ts b/tests/unit/server/trpc/middleware.test.ts new file mode 100644 index 000000000..3abe29cf9 --- /dev/null +++ b/tests/unit/server/trpc/middleware.test.ts @@ -0,0 +1,80 @@ +import { v4 as uuidv4 } from "uuid"; +import { afterAll, describe, expect, test } from "vitest"; +import { TRIAL_EXPIRED_MESSAGE } from "@/constants"; +import { + deleteUser, + updateUserRole, + updateUserTrialEndsAt, + upsertUser, +} from "@/server/repositories/User"; +import { invitationRouter } from "@/server/trpc/routers/invitation"; +import type { UserRole } from "@/models/User"; + +const userIds: string[] = []; + +async function createTestUser(role?: UserRole | null) { + const user = await upsertUser({ + email: `test-${uuidv4()}@example.com`, + password: "test-password-123", + name: "Test User", + avatarUrl: null, + }); + userIds.push(user.id); + if (role) { + await updateUserRole(user.id, role); + } + return user; +} + +// listForUser is a simple protectedProcedure — good for testing the middleware +function makeCaller(user: Awaited> | null) { + return invitationRouter.createCaller({ user, ip: "127.0.0.1" }); +} + +describe("enforceUserIsAuthed trial expiry", () => { + test("allows user with no trialEndsAt", async () => { + const user = await createTestUser(); + const caller = makeCaller(user); + const result = await caller.listForUser(); + expect(Array.isArray(result)).toBe(true); + }); + + test("allows user with trialEndsAt in the future", async () => { + const user = await createTestUser(); + const futureDate = new Date(Date.now() + 7 * 24 * 60 * 60 * 1000); + await updateUserTrialEndsAt(user.id, futureDate); + const updatedUser = { ...user, trialEndsAt: futureDate }; + const caller = makeCaller(updatedUser); + const result = await caller.listForUser(); + expect(Array.isArray(result)).toBe(true); + }); + + test("blocks user with trialEndsAt in the past with FORBIDDEN", async () => { + const user = await createTestUser(); + const pastDate = new Date(Date.now() - 24 * 60 * 60 * 1000); + await updateUserTrialEndsAt(user.id, pastDate); + const updatedUser = { ...user, trialEndsAt: pastDate }; + const caller = makeCaller(updatedUser); + await expect(caller.listForUser()).rejects.toMatchObject({ + code: "FORBIDDEN", + message: TRIAL_EXPIRED_MESSAGE, + }); + }); + + test("blocks unauthenticated user with UNAUTHORIZED", async () => { + const caller = makeCaller(null); + await expect(caller.listForUser()).rejects.toMatchObject({ + code: "UNAUTHORIZED", + }); + }); +}); + +afterAll(async () => { + for (const id of userIds) { + try { + await deleteUser(id); + } catch { + // already deleted + } + } +}); diff --git a/tests/unit/server/trpc/routers/confirmInvite.test.ts b/tests/unit/server/trpc/routers/confirmInvite.test.ts new file mode 100644 index 000000000..396790e01 --- /dev/null +++ b/tests/unit/server/trpc/routers/confirmInvite.test.ts @@ -0,0 +1,147 @@ +import { SignJWT } from "jose"; +import { v4 as uuidv4 } from "uuid"; +import { afterAll, describe, expect, test, vi } from "vitest"; +import { DEFAULT_TRIAL_PERIOD_DAYS } from "@/constants"; +import { createInvitation } from "@/server/repositories/Invitation"; +import { upsertOrganisation } from "@/server/repositories/Organisation"; +import { deleteUser, findUserByEmail } from "@/server/repositories/User"; +import { authRouter } from "@/server/trpc/routers/auth"; + +vi.mock("@/auth/jwt", () => ({ + setJWT: vi.fn(), +})); +vi.mock("@/server/services/logger", () => ({ + default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); +vi.mock("@/server/services/mailer", () => ({ + sendEmail: vi.fn(), +})); + +const userEmails: string[] = []; + +function makeCaller() { + return authRouter.createCaller({ user: null, ip: "127.0.0.1" }); +} + +async function createInviteToken(invitationId: string) { + const secret = new TextEncoder().encode(process.env.JWT_SECRET || ""); + return new SignJWT({ invitationId }) + .setProtectedHeader({ alg: "HS256" }) + .setExpirationTime("7d") + .sign(secret); +} + +describe("auth.confirmInvite", () => { + test("trial invitation sets trialEndsAt on the user", async () => { + const org = await upsertOrganisation({ name: `Org ${uuidv4()}` }); + const email = `trial-${uuidv4()}@example.com`; + userEmails.push(email); + + const invitation = await createInvitation({ + email, + name: "Trial User", + organisationId: org.id, + senderOrganisationId: org.id, + isTrial: true, + }); + + const token = await createInviteToken(invitation.id); + const caller = makeCaller(); + const result = await caller.confirmInvite({ + token, + password: "test-password-123", + }); + + expect(result.trialEndsAt).toBeTruthy(); + if (!result.trialEndsAt) return; + const trialEndsAt = new Date(result.trialEndsAt); + const expectedMin = new Date( + Date.now() + (DEFAULT_TRIAL_PERIOD_DAYS - 1) * 24 * 60 * 60 * 1000, + ); + const expectedMax = new Date( + Date.now() + (DEFAULT_TRIAL_PERIOD_DAYS + 1) * 24 * 60 * 60 * 1000, + ); + expect(trialEndsAt.getTime()).toBeGreaterThan(expectedMin.getTime()); + expect(trialEndsAt.getTime()).toBeLessThan(expectedMax.getTime()); + }); + + test("non-trial invitation does not set trialEndsAt", async () => { + const org = await upsertOrganisation({ name: `Org ${uuidv4()}` }); + const email = `nontrial-${uuidv4()}@example.com`; + userEmails.push(email); + + const invitation = await createInvitation({ + email, + name: "Regular User", + organisationId: org.id, + senderOrganisationId: org.id, + isTrial: false, + }); + + const token = await createInviteToken(invitation.id); + const caller = makeCaller(); + const result = await caller.confirmInvite({ + token, + password: "test-password-123", + }); + + expect(result.trialEndsAt).toBeNull(); + }); + + test("trial invitation does not overwrite existing trialEndsAt", async () => { + const org = await upsertOrganisation({ name: `Org ${uuidv4()}` }); + const email = `existing-trial-${uuidv4()}@example.com`; + userEmails.push(email); + + // First invitation sets trialEndsAt + const invitation1 = await createInvitation({ + email, + name: "Existing Trial User", + organisationId: org.id, + senderOrganisationId: org.id, + isTrial: true, + }); + + const token1 = await createInviteToken(invitation1.id); + const caller = makeCaller(); + const firstResult = await caller.confirmInvite({ + token: token1, + password: "test-password-123", + }); + const originalTrialEndsAt = firstResult.trialEndsAt; + expect(originalTrialEndsAt).toBeTruthy(); + + // Second trial invitation should not overwrite + const invitation2 = await createInvitation({ + email, + name: "Existing Trial User", + organisationId: org.id, + senderOrganisationId: org.id, + isTrial: true, + }); + + const token2 = await createInviteToken(invitation2.id); + const secondResult = await caller.confirmInvite({ + token: token2, + password: "test-password-123", + }); + + expect(secondResult.trialEndsAt).toBeTruthy(); + expect(originalTrialEndsAt).toBeTruthy(); + if (!secondResult.trialEndsAt || !originalTrialEndsAt) return; + expect(new Date(secondResult.trialEndsAt).getTime()).toBe( + new Date(originalTrialEndsAt).getTime(), + ); + }); +}); + +afterAll(async () => { + for (const email of userEmails) { + try { + const user = await findUserByEmail(email); + if (user) await deleteUser(user.id); + } catch { + // already deleted + } + } +}); diff --git a/tests/unit/server/trpc/routers/invitation.test.ts b/tests/unit/server/trpc/routers/invitation.test.ts index 0e3dc789c..ce01e81a5 100644 --- a/tests/unit/server/trpc/routers/invitation.test.ts +++ b/tests/unit/server/trpc/routers/invitation.test.ts @@ -27,6 +27,7 @@ import { updateUserRole, upsertUser, } from "@/server/repositories/User"; +import { db } from "@/server/services/database"; import { invitationRouter } from "@/server/trpc/routers/invitation"; const userIds: string[] = []; @@ -296,6 +297,50 @@ describe("invitation.create", () => { }); }); +describe("invitation.create isTrial", () => { + test("advocate invitation is marked as trial", async () => { + const senderOrg = await createSenderOrg(); + const advocate = await createTestUser(UserRole.Advocate, senderOrg.id); + const caller = makeCaller(advocate); + + const email = `invitee-${uuidv4()}@example.com`; + await caller.create({ + name: "Invitee", + email, + senderOrganisationId: senderOrg.id, + organisationName: `New Org ${uuidv4()}`, + }); + + const invitation = await db + .selectFrom("invitation") + .where("email", "=", email) + .selectAll() + .executeTakeFirstOrThrow(); + expect(invitation.isTrial).toBe(true); + }); + + test("superadmin invitation is not marked as trial", async () => { + const senderOrg = await createSenderOrg(); + const superadmin = await createTestUser(UserRole.Superadmin, senderOrg.id); + const caller = makeCaller(superadmin); + + const email = `invitee-${uuidv4()}@example.com`; + await caller.create({ + name: "Invitee", + email, + senderOrganisationId: senderOrg.id, + organisationName: `New Org ${uuidv4()}`, + }); + + const invitation = await db + .selectFrom("invitation") + .where("email", "=", email) + .selectAll() + .executeTakeFirstOrThrow(); + expect(invitation.isTrial).toBe(false); + }); +}); + afterAll(async () => { for (const id of mapIds) { try {