diff --git a/.dev.vars b/.dev.vars new file mode 100644 index 0000000..9410ca0 --- /dev/null +++ b/.dev.vars @@ -0,0 +1 @@ +IS_LOCAL_MODE=1 diff --git a/package.json b/package.json index bb6457b..53d1166 100644 --- a/package.json +++ b/package.json @@ -32,6 +32,7 @@ "@sveltejs/kit": "^2.4.3", "@sveltejs/vite-plugin-svelte": "^3.0.1", "@tailwindcss/typography": "^0.5.10", + "@types/cookie": "^0.6.0", "@types/debug": "^4.1.12", "@types/sql.js": "^1.4.9", "@typescript-eslint/eslint-plugin": "^6.19.1", @@ -79,5 +80,9 @@ "bugs": { "url": "https://github.com/JacobLinCool/d1-manager/issues" }, - "packageManager": "pnpm@8.14.3" + "packageManager": "pnpm@8.14.3", + "dependencies": { + "@cloudflare/pages-plugin-cloudflare-access": "^1.0.4", + "cookie": "^0.6.0" + } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index be58da6..3834626 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -4,6 +4,14 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +dependencies: + "@cloudflare/pages-plugin-cloudflare-access": + specifier: ^1.0.4 + version: 1.0.4 + cookie: + specifier: ^0.6.0 + version: 0.6.0 + devDependencies: "@ai-d/aid": specifier: ^0.1.5 @@ -38,6 +46,9 @@ devDependencies: "@tailwindcss/typography": specifier: ^0.5.10 version: 0.5.10(tailwindcss@3.4.1) + "@types/cookie": + specifier: ^0.6.0 + version: 0.6.0 "@types/debug": specifier: ^4.1.12 version: 4.1.12 @@ -502,6 +513,13 @@ packages: mime: 3.0.0 dev: true + /@cloudflare/pages-plugin-cloudflare-access@1.0.4: + resolution: + { + integrity: sha512-9R80Y4a+TSneX0v8zkwAc6scpYTMxxfyWI9BB3HJLkTtEviASTjZnV4tPjMk90bOtL9yoHT+Xpo90pSaplugUg==, + } + dev: false + /@cloudflare/workerd-darwin-64@1.20231218.0: resolution: { @@ -2992,7 +3010,6 @@ packages: integrity: sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==, } engines: { node: ">= 0.6" } - dev: true /cross-spawn@5.1.0: resolution: diff --git a/src/access.ts b/src/access.ts new file mode 100644 index 0000000..02b3cfa --- /dev/null +++ b/src/access.ts @@ -0,0 +1,137 @@ +// Copied from https://github.com/cloudflare/pages-plugins/blob/434fad8db20e483cc532d9c678d46a73a4ae7115/packages/cloudflare-access/functions/_middleware.ts +import type { PluginArgs } from "@cloudflare/pages-plugin-cloudflare-access"; +import { generateLoginURL, getIdentity } from "@cloudflare/pages-plugin-cloudflare-access/api"; +import { parse as parseCookies } from "cookie"; + +type CloudflareAccessPagesPluginFunction< + Env = unknown, + Params extends string = any, + Data extends Record = Record, +> = PagesPluginFunction; + +const extractJWTFromRequest = (request: Request) => + request.headers.get("Cf-Access-Jwt-Assertion") || + // I had to add this as some requests didn't have the header, just the cookie.. + parseCookies(request.headers.get("Cookie") || "")["CF_Authorization"]; + +// Adapted slightly from https://github.com/cloudflare/workers-access-external-auth-example +const base64URLDecode = (s: string) => { + s = s.replace(/-/g, "+").replace(/_/g, "/").replace(/\s/g, ""); + return new Uint8Array(Array.from(atob(s)).map((c: string) => c.charCodeAt(0))); +}; + +const asciiToUint8Array = (s: string) => { + const chars = []; + for (let i = 0; i < s.length; ++i) { + chars.push(s.charCodeAt(i)); + } + return new Uint8Array(chars); +}; + +const generateValidator = + ({ domain, aud }: { domain: string; aud: string }) => + async ( + request: Request, + ): Promise<{ + jwt: string; + payload: object; + }> => { + const jwt = extractJWTFromRequest(request); + const parts = jwt.split("."); + if (parts.length !== 3) { + throw new Error("JWT does not have three parts."); + } + const [header, payload, signature] = parts; + + const textDecoder = new TextDecoder("utf-8"); + const { kid, alg } = JSON.parse(textDecoder.decode(base64URLDecode(header))); + if (alg !== "RS256") { + throw new Error("Unknown JWT type or algorithm."); + } + + const certsURL = new URL("/cdn-cgi/access/certs", domain); + const certsResponse = await fetch(certsURL.toString()); + const { keys } = (await certsResponse.json()) as { + keys: ({ + kid: string; + } & JsonWebKey)[]; + public_cert: { kid: string; cert: string }; + public_certs: { kid: string; cert: string }[]; + }; + if (!keys) { + throw new Error("Could not fetch signing keys."); + } + const jwk = keys.find((key) => key.kid === kid); + if (!jwk) { + throw new Error("Could not find matching signing key."); + } + if (jwk.kty !== "RSA" || jwk.alg !== "RS256") { + throw new Error("Unknown key type of algorithm."); + } + + const key = await crypto.subtle.importKey( + "jwk", + jwk, + { name: "RSASSA-PKCS1-v1_5", hash: "SHA-256" }, + false, + ["verify"], + ); + + const unroundedSecondsSinceEpoch = Date.now() / 1000; + + const payloadObj = JSON.parse(textDecoder.decode(base64URLDecode(payload))); + + if (payloadObj.iss && payloadObj.iss !== certsURL.origin) { + throw new Error("JWT issuer is incorrect."); + } + if (payloadObj.aud && !payloadObj.aud.includes(aud)) { + throw new Error("JWT audience is incorrect."); + } + if (payloadObj.exp && Math.floor(unroundedSecondsSinceEpoch) >= payloadObj.exp) { + throw new Error("JWT has expired."); + } + if (payloadObj.nbf && Math.ceil(unroundedSecondsSinceEpoch) < payloadObj.nbf) { + throw new Error("JWT is not yet valid."); + } + + const verified = await crypto.subtle.verify( + "RSASSA-PKCS1-v1_5", + key, + base64URLDecode(signature), + asciiToUint8Array(`${header}.${payload}`), + ); + if (!verified) { + throw new Error("Could not verify JWT."); + } + + return { jwt, payload: payloadObj }; + }; + +export const onRequest: CloudflareAccessPagesPluginFunction = async ({ + request, + pluginArgs: { domain, aud }, + data, + next, +}) => { + try { + const validator = generateValidator({ domain, aud }); + + const { jwt, payload } = await validator(request); + + data.cloudflareAccess = { + JWT: { + payload, + getIdentity: () => getIdentity({ jwt, domain }), + }, + }; + + return next(); + } catch {} + + return new Response(null, { + status: 302, + headers: { + Location: generateLoginURL({ redirectURL: request.url, domain, aud }), + }, + }); +}; diff --git a/src/app.d.ts b/src/app.d.ts index 016abb0..ba9ac93 100644 --- a/src/app.d.ts +++ b/src/app.d.ts @@ -12,6 +12,8 @@ declare global { SHOW_INTERNAL_TABLES?: string; OPENAI_API_KEY?: string; AI?: unknown; + ACCESS_DOMAIN?: string; + ACCESS_AUD?: string; } & Record; } } diff --git a/src/hooks.server.ts b/src/hooks.server.ts index a5ef7a2..2556e55 100644 --- a/src/hooks.server.ts +++ b/src/hooks.server.ts @@ -2,8 +2,9 @@ import { extend } from "$lib/log"; import { DBMS } from "$lib/server/db/dbms"; import type { Handle, HandleServerError } from "@sveltejs/kit"; import { locale, waitLocale } from "svelte-i18n"; +import { onRequest } from "./access"; -export const handle: Handle = async ({ event, resolve }) => { +const handler: Handle = async ({ event, resolve }) => { const lang = event.request.headers.get("accept-language")?.split(",")[0] || "en"; locale.set(lang); await waitLocale(lang); @@ -14,6 +15,26 @@ export const handle: Handle = async ({ event, resolve }) => { return result; }; +export const handle: Handle = async ({ event, resolve }) => { + console.log(event.request.url); + // check request is authenticated + if (event.platform?.env.IS_LOCAL_MODE === "1") { + return await handler({ event, resolve }); + } else { + return await onRequest({ + request: event.request, + pluginArgs: { + domain: event.platform?.env.ACCESS_DOMAIN, + aud: event.platform?.env.ACCESS_AUD, + }, + data: {}, + next: async () => { + return await handler({ event, resolve }); + }, + }); + } +}; + const elog = extend("server-error"); elog.enabled = true;