diff --git a/demos/pds/wrangler.jsonc b/demos/pds/wrangler.jsonc index 6c9b9587..abac91c4 100644 --- a/demos/pds/wrangler.jsonc +++ b/demos/pds/wrangler.jsonc @@ -41,10 +41,13 @@ "DID": "did:web:pds.mk.gg", // Account handle (e.g., "alice.example.com") "HANDLE": "pds.mk.gg" - } + }, // Secrets (set via `pds init` or `pds secret `): // - AUTH_TOKEN: Bearer token for API write operations // - SIGNING_KEY: Private signing key (secp256k1 JWK) // - JWT_SECRET: Secret for signing session JWTs // - PASSWORD_HASH: Bcrypt hash of account password (for Bluesky app login) + "observability": { + "enabled": true + } } \ No newline at end of file diff --git a/packages/create-pds/README.md b/packages/create-pds/README.md index b23b8ab8..9b2dc0f0 100644 --- a/packages/create-pds/README.md +++ b/packages/create-pds/README.md @@ -45,4 +45,4 @@ npm run dev Your PDS will be running at http://localhost:5173 -See the [@ascorbic/pds documentation](https://github.com/ascorbic/atproto-worker/tree/main/packages/pds) for configuration and deployment instructions. \ No newline at end of file +See the [@ascorbic/pds documentation](https://github.com/ascorbic/atproto-worker/tree/main/packages/pds) for configuration and deployment instructions. diff --git a/packages/pds/package.json b/packages/pds/package.json index c85c6cf9..69c3ca06 100644 --- a/packages/pds/package.json +++ b/packages/pds/package.json @@ -22,6 +22,7 @@ "dependencies": { "@atproto/common-web": "^0.4.7", "@atproto/crypto": "^0.4.5", + "@atproto/identity": "^0.4.10", "@atproto/lex-cbor": "^0.0.3", "@atproto/lex-data": "^0.0.3", "@atproto/lexicon": "^0.6.0", diff --git a/packages/pds/src/did-cache.ts b/packages/pds/src/did-cache.ts new file mode 100644 index 00000000..0c370553 --- /dev/null +++ b/packages/pds/src/did-cache.ts @@ -0,0 +1,95 @@ +/** + * DID cache using Cloudflare Workers Cache API + */ + +import type { DidCache, CacheResult, DidDocument } from "@atproto/identity"; +import { check, didDocument } from "@atproto/common-web"; +import { waitUntil } from "cloudflare:workers"; + +const STALE_TTL = 60 * 60 * 1000; // 1 hour - serve from cache but refresh in background +const MAX_TTL = 24 * 60 * 60 * 1000; // 24 hours - must refresh + +export class WorkersDidCache implements DidCache { + private cache: Cache; + + constructor() { + this.cache = caches.default; + } + + private getCacheKey(did: string): string { + // Use a stable URL format for cache keys + return `https://did-cache.internal/${encodeURIComponent(did)}`; + } + + async cacheDid( + did: string, + doc: DidDocument, + _prevResult?: CacheResult, + ): Promise { + const cacheKey = this.getCacheKey(did); + const response = new Response(JSON.stringify(doc), { + headers: { + "Content-Type": "application/json", + "Cache-Control": "max-age=86400", // 24 hours + "X-Cached-At": Date.now().toString(), + }, + }); + + await this.cache.put(cacheKey, response); + } + + async checkCache(did: string): Promise { + const cacheKey = this.getCacheKey(did); + const response = await this.cache.match(cacheKey); + + if (!response) { + return null; + } + + const cachedAt = parseInt(response.headers.get("X-Cached-At") || "0", 10); + const now = Date.now(); + const age = now - cachedAt; + + const doc = await response.json(); + + // Validate cached document schema + if (!check.is(doc, didDocument) || doc.id !== did) { + await this.clearEntry(did); + return null; + } + + return { + did, + doc, + updatedAt: cachedAt, + stale: age > STALE_TTL, + expired: age > MAX_TTL, + }; + } + + async refreshCache( + did: string, + getDoc: () => Promise, + _prevResult?: CacheResult, + ): Promise { + // Background refresh using waitUntil to ensure it completes after response + waitUntil( + getDoc().then((doc) => { + if (doc) { + return this.cacheDid(did, doc); + } + }), + ); + } + + async clearEntry(did: string): Promise { + const cacheKey = this.getCacheKey(did); + await this.cache.delete(cacheKey); + } + + async clear(): Promise { + // Cache API doesn't have a clear-all method + // Would need to track keys separately if needed + // For now, entries will expire naturally + } +} diff --git a/packages/pds/src/did-resolver.ts b/packages/pds/src/did-resolver.ts new file mode 100644 index 00000000..7a743531 --- /dev/null +++ b/packages/pds/src/did-resolver.ts @@ -0,0 +1,154 @@ +/** + * DID resolution for Cloudflare Workers + * + * We can't use @atproto/identity directly because it uses `redirect: "error"` + * which Cloudflare Workers doesn't support. This is a simple implementation + * that's compatible with Workers. + */ + +import { check, didDocument, type DidDocument } from "@atproto/common-web"; +import type { DidCache } from "@atproto/identity"; + +const PLC_DIRECTORY = "https://plc.directory"; +const TIMEOUT_MS = 3000; + +export interface DidResolverOpts { + plcUrl?: string; + timeout?: number; + didCache?: DidCache; +} + +export class DidResolver { + private plcUrl: string; + private timeout: number; + private cache?: DidCache; + + constructor(opts: DidResolverOpts = {}) { + this.plcUrl = opts.plcUrl ?? PLC_DIRECTORY; + this.timeout = opts.timeout ?? TIMEOUT_MS; + this.cache = opts.didCache; + } + + async resolve(did: string): Promise { + // Check cache first + if (this.cache) { + const cached = await this.cache.checkCache(did); + if (cached && !cached.expired) { + // Trigger background refresh if stale + if (cached.stale) { + this.cache.refreshCache(did, () => this.resolveNoCache(did), cached); + } + return cached.doc; + } + } + + const doc = await this.resolveNoCache(did); + + // Update cache + if (doc && this.cache) { + await this.cache.cacheDid(did, doc); + } else if (!doc && this.cache) { + await this.cache.clearEntry(did); + } + + return doc; + } + + private async resolveNoCache(did: string): Promise { + if (did.startsWith("did:web:")) { + return this.resolveDidWeb(did); + } + if (did.startsWith("did:plc:")) { + return this.resolveDidPlc(did); + } + throw new Error(`Unsupported DID method: ${did}`); + } + + private async resolveDidWeb(did: string): Promise { + const parts = did.split(":").slice(2); + if (parts.length === 0) { + throw new Error(`Invalid did:web format: ${did}`); + } + + // Only support simple did:web without paths (like @atproto/identity) + if (parts.length > 1) { + throw new Error(`Unsupported did:web with path: ${did}`); + } + + const domain = decodeURIComponent(parts[0]!); + const url = new URL(`https://${domain}/.well-known/did.json`); + + // Use http for localhost + if (url.hostname === "localhost") { + url.protocol = "http:"; + } + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), this.timeout); + + try { + const res = await fetch(url.toString(), { + signal: controller.signal, + redirect: "manual", // Workers doesn't support "error" + headers: { accept: "application/did+ld+json,application/json" }, + }); + + // Check for redirect (we don't follow them for security) + if (res.status >= 300 && res.status < 400) { + return null; + } + + if (!res.ok) { + return null; + } + + const doc = await res.json(); + return this.validateDidDoc(did, doc); + } finally { + clearTimeout(timeoutId); + } + } + + private async resolveDidPlc(did: string): Promise { + const url = new URL(`/${encodeURIComponent(did)}`, this.plcUrl); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), this.timeout); + + try { + const res = await fetch(url.toString(), { + signal: controller.signal, + redirect: "manual", // Workers doesn't support "error" + headers: { accept: "application/did+ld+json,application/json" }, + }); + + // Check for redirect (we don't follow them for security) + if (res.status >= 300 && res.status < 400) { + return null; + } + + if (res.status === 404) { + return null; + } + + if (!res.ok) { + throw new Error(`PLC directory error: ${res.status} ${res.statusText}`); + } + + const doc = (await res.json()) as DidDocument; + return this.validateDidDoc(did, doc); + } finally { + clearTimeout(timeoutId); + } + } + + private validateDidDoc(did: string, doc: unknown): DidDocument | null { + if (!check.is(doc, didDocument)) { + return null; + } + if (doc.id !== did) { + return null; + } + return doc; + } +} diff --git a/packages/pds/src/index.ts b/packages/pds/src/index.ts index 627eeef6..c77b75ef 100644 --- a/packages/pds/src/index.ts +++ b/packages/pds/src/index.ts @@ -8,8 +8,9 @@ import { env as _env } from "cloudflare:workers"; import { Secp256k1Keypair } from "@atproto/crypto"; import { ensureValidDid, ensureValidHandle } from "@atproto/syntax"; import { requireAuth } from "./middleware/auth"; -import { createServiceJwt } from "./service-auth"; -import { verifyAccessToken } from "./session"; +import { DidResolver } from "./did-resolver"; +import { WorkersDidCache } from "./did-cache"; +import { handleXrpcProxy } from "./xrpc-proxy"; import * as sync from "./xrpc/sync"; import * as repo from "./xrpc/repo"; import * as server from "./xrpc/server"; @@ -48,9 +49,11 @@ try { ); } -// Bluesky service DIDs for service auth -const APPVIEW_DID = "did:web:api.bsky.app"; -const CHAT_DID = "did:web:api.bsky.chat"; +const didResolver = new DidResolver({ + didCache: new WorkersDidCache(), + timeout: 3000, // 3 second timeout for DID resolution + plcUrl: "https://plc.directory", +}); // Lazy-loaded keypair for service auth let keypairPromise: Promise | null = null; @@ -252,77 +255,8 @@ app.post("/admin/emit-identity", requireAuth, async (c) => { return c.json(result); }); -// Proxy unhandled XRPC requests to Bluesky services -app.all("/xrpc/*", async (c) => { - const url = new URL(c.req.url); - url.protocol = "https:"; - - // Extract XRPC method name from path (e.g., "app.bsky.feed.getTimeline") - const lxm = url.pathname.replace("/xrpc/", ""); - - // Route to appropriate service based on lexicon namespace - const isChat = lxm.startsWith("chat.bsky."); - url.host = isChat ? "api.bsky.chat" : "api.bsky.app"; - const audienceDid = isChat ? CHAT_DID : APPVIEW_DID; - - // Check for authorization header - const auth = c.req.header("Authorization"); - let headers: Record = {}; - - if (auth?.startsWith("Bearer ")) { - const token = auth.slice(7); - const serviceDid = `did:web:${c.env.PDS_HOSTNAME}`; - - // Try to verify the token - if valid, create a service JWT - try { - // Check static token first - let userDid: string; - if (token === c.env.AUTH_TOKEN) { - userDid = c.env.DID; - } else { - // Verify JWT - const payload = await verifyAccessToken( - token, - c.env.JWT_SECRET, - serviceDid, - ); - userDid = payload.sub; - } - - // Create service JWT for target service - const keypair = await getKeypair(); - const serviceJwt = await createServiceJwt({ - iss: userDid, - aud: audienceDid, - lxm, - keypair, - }); - headers["Authorization"] = `Bearer ${serviceJwt}`; - } catch { - // Token verification failed - forward without auth - // Target service will return appropriate error - } - } - - // Forward request with potentially replaced auth header - // Remove original authorization header to prevent conflicts - const originalHeaders = Object.fromEntries(c.req.raw.headers); - delete originalHeaders["authorization"]; - - const reqInit: RequestInit = { - method: c.req.method, - headers: { - ...originalHeaders, - ...headers, - }, - }; - - // Include body for non-GET requests - if (c.req.method !== "GET" && c.req.method !== "HEAD") { - reqInit.body = c.req.raw.body; - } - - return fetch(url.toString(), reqInit); -}); +// Proxy unhandled XRPC requests to services specified via atproto-proxy header +// or fall back to Bluesky services for backward compatibility +app.all("/xrpc/*", (c) => handleXrpcProxy(c, didResolver, getKeypair)); export default app; diff --git a/packages/pds/src/service-auth.ts b/packages/pds/src/service-auth.ts index a92ec4b8..be94382f 100644 --- a/packages/pds/src/service-auth.ts +++ b/packages/pds/src/service-auth.ts @@ -108,7 +108,9 @@ export async function verifyServiceJwt( throw new Error("Invalid JWT format"); } - const [headerB64, payloadB64, signatureB64] = parts; + const headerB64 = parts[0]!; + const payloadB64 = parts[1]!; + const signatureB64 = parts[2]!; // Decode header const header = JSON.parse(Buffer.from(headerB64, "base64url").toString()); diff --git a/packages/pds/src/xrpc-proxy.ts b/packages/pds/src/xrpc-proxy.ts new file mode 100644 index 00000000..6968f03b --- /dev/null +++ b/packages/pds/src/xrpc-proxy.ts @@ -0,0 +1,219 @@ +/** + * XRPC service proxying with atproto-proxy header support + * See: https://atproto.com/specs/xrpc#service-proxying + */ + +import type { Context } from "hono"; +import { DidResolver } from "./did-resolver"; +import { getServiceEndpoint } from "@atproto/common-web"; +import { createServiceJwt } from "./service-auth"; +import { verifyAccessToken } from "./session"; +import type { PDSEnv } from "./types"; +import type { Secp256k1Keypair } from "@atproto/crypto"; + +/** + * Parse atproto-proxy header value + * Format: "did:web:example.com#service_id" + * Returns: { did: "did:web:example.com", serviceId: "service_id" } + */ +export function parseProxyHeader( + header: string, +): { did: string; serviceId: string } | null { + const parts = header.split("#"); + if (parts.length !== 2) { + return null; + } + + const [did, serviceId] = parts; + if (!did?.startsWith("did:") || !serviceId) { + return null; + } + + return { did, serviceId }; +} + +/** + * Handle XRPC proxy requests + * Routes requests to external services based on atproto-proxy header or lexicon namespace + */ +export async function handleXrpcProxy( + c: Context<{ Bindings: PDSEnv }>, + didResolver: DidResolver, + getKeypair: () => Promise, +): Promise { + // Extract XRPC method name from path (e.g., "app.bsky.feed.getTimeline") + const url = new URL(c.req.url); + const lxm = url.pathname.replace("/xrpc/", ""); + + // Validate XRPC path to prevent path traversal + if (lxm.includes("..") || lxm.includes("//")) { + return c.json( + { + error: "InvalidRequest", + message: "Invalid XRPC method path", + }, + 400, + ); + } + + // Check for atproto-proxy header for explicit service routing + const proxyHeader = c.req.header("atproto-proxy"); + let audienceDid: string; + let targetUrl: URL; + + if (proxyHeader) { + // Parse proxy header: "did:web:example.com#service_id" + const parsed = parseProxyHeader(proxyHeader); + if (!parsed) { + return c.json( + { + error: "InvalidRequest", + message: `Invalid atproto-proxy header format: ${proxyHeader}`, + }, + 400, + ); + } + + try { + // Resolve DID document to get service endpoint (with caching) + const didDoc = await didResolver.resolve(parsed.did); + if (!didDoc) { + return c.json( + { + error: "InvalidRequest", + message: `DID not found: ${parsed.did}`, + }, + 400, + ); + } + + // getServiceEndpoint expects the ID to start with # + const serviceId = parsed.serviceId.startsWith("#") + ? parsed.serviceId + : `#${parsed.serviceId}`; + const endpoint = getServiceEndpoint(didDoc, { id: serviceId }); + + if (!endpoint) { + return c.json( + { + error: "InvalidRequest", + message: `Service not found in DID document: ${parsed.serviceId}`, + }, + 400, + ); + } + + // Use the resolved service endpoint + audienceDid = parsed.did; + targetUrl = new URL(endpoint); + if (targetUrl.protocol !== "https:") { + return c.json( + { + error: "InvalidRequest", + message: "Proxy target must use HTTPS", + }, + 400, + ); + } + targetUrl.pathname = url.pathname; + targetUrl.search = url.search; + } catch (err) { + return c.json( + { + error: "InvalidRequest", + message: `Failed to resolve service: ${err instanceof Error ? err.message : String(err)}`, + }, + 400, + ); + } + } else { + // Fallback: Route to Bluesky services based on lexicon namespace + // These are well-known endpoints that don't require DID resolution + const isChat = lxm.startsWith("chat.bsky."); + audienceDid = isChat ? "did:web:api.bsky.chat" : "did:web:api.bsky.app"; + const endpoint = isChat ? "https://api.bsky.chat" : "https://api.bsky.app"; + + // Construct URL safely using URL constructor + targetUrl = new URL(`/xrpc/${lxm}${url.search}`, endpoint); + } + + // Check for authorization header + const auth = c.req.header("Authorization"); + let headers: Record = {}; + + if (auth?.startsWith("Bearer ")) { + const token = auth.slice(7); + const serviceDid = `did:web:${c.env.PDS_HOSTNAME}`; + + // Try to verify the token - if valid, create a service JWT + try { + // Check static token first + let userDid: string; + if (token === c.env.AUTH_TOKEN) { + userDid = c.env.DID; + } else { + // Verify JWT + const payload = await verifyAccessToken( + token, + c.env.JWT_SECRET, + serviceDid, + ); + if (!payload.sub) { + throw new Error("Missing sub claim in token"); + } + userDid = payload.sub; + } + + // Create service JWT for target service + const keypair = await getKeypair(); + const serviceJwt = await createServiceJwt({ + iss: userDid, + aud: audienceDid, + lxm, + keypair, + }); + headers["Authorization"] = `Bearer ${serviceJwt}`; + } catch { + // Token verification failed - forward without auth + // Target service will return appropriate error + } + } + + // Forward request with potentially replaced auth header + // Use Headers object for case-insensitive handling + const forwardHeaders = new Headers(c.req.raw.headers); + + // Remove headers that shouldn't be forwarded (security/privacy) + const headersToRemove = [ + "authorization", // Replaced with service JWT + "atproto-proxy", // Internal routing header + "host", // Will be set by fetch + "connection", // Connection-specific + "cookie", // Privacy - don't leak cookies + "x-forwarded-for", // Don't leak client IP + "x-real-ip", // Don't leak client IP + "x-forwarded-proto", // Internal + "x-forwarded-host", // Internal + ]; + + for (const header of headersToRemove) { + forwardHeaders.delete(header); + } + + // Add service auth if we have it + if (headers["Authorization"]) { + forwardHeaders.set("Authorization", headers["Authorization"]); + } + + const reqInit: RequestInit = { + method: c.req.method, + headers: forwardHeaders, + }; + + // Include body for non-GET requests + if (c.req.method !== "GET" && c.req.method !== "HEAD") { + reqInit.body = c.req.raw.body; + } + + return fetch(targetUrl.toString(), reqInit); +} diff --git a/packages/pds/test/did-resolver.test.ts b/packages/pds/test/did-resolver.test.ts new file mode 100644 index 00000000..6da57e26 --- /dev/null +++ b/packages/pds/test/did-resolver.test.ts @@ -0,0 +1,134 @@ +import { describe, it, expect } from "vitest"; +import { parseProxyHeader } from "../src/xrpc-proxy"; +import { getServiceEndpoint, type DidDocument } from "@atproto/common-web"; + +describe("DID Resolver", () => { + describe("parseProxyHeader", () => { + it("should parse valid proxy header", () => { + const result = parseProxyHeader("did:web:example.com#atproto_labeler"); + expect(result).toEqual({ + did: "did:web:example.com", + serviceId: "atproto_labeler", + }); + }); + + it("should parse did:plc header", () => { + const result = parseProxyHeader("did:plc:abc123xyz#atproto_labeler"); + expect(result).toEqual({ + did: "did:plc:abc123xyz", + serviceId: "atproto_labeler", + }); + }); + + it("should return null for invalid format (no hash)", () => { + const result = parseProxyHeader("did:web:example.com"); + expect(result).toBeNull(); + }); + + it("should return null for invalid format (not a DID)", () => { + const result = parseProxyHeader("https://example.com#service"); + expect(result).toBeNull(); + }); + + it("should return null for multiple hashes", () => { + const result = parseProxyHeader("did:web:example.com#service#extra"); + expect(result).toBeNull(); + }); + }); + + describe("getServiceEndpoint", () => { + it("should extract endpoint with fragment-only ID", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should extract endpoint with full ID", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "did:web:example.com#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should extract endpoint when serviceId includes hash", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should return undefined for non-existent service", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#nonexistent" }); + expect(endpoint).toBeUndefined(); + }); + + it("should return undefined when no services exist", () => { + const doc: DidDocument = { + id: "did:web:example.com", + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBeUndefined(); + }); + + it("should handle multiple services", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_pds", + type: "AtprotoPersonalDataServer", + serviceEndpoint: "https://pds.example.com", + }, + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + }); +}); diff --git a/packages/pds/test/proxy.test.ts b/packages/pds/test/proxy.test.ts new file mode 100644 index 00000000..3ea526e8 --- /dev/null +++ b/packages/pds/test/proxy.test.ts @@ -0,0 +1,372 @@ +import { describe, it, expect, beforeAll, vi, afterEach } from "vitest"; +import { env, worker } from "./helpers"; + +// Mock DID documents for testing +const mockDidDocuments: Record = { + "did:web:labeler.example.com": { + id: "did:web:labeler.example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }, + "did:web:api.bsky.app": { + id: "did:web:api.bsky.app", + service: [ + { + id: "#atproto_appview", + type: "AtprotoAppView", + serviceEndpoint: "https://api.bsky.app", + }, + ], + }, +}; + +describe("XRPC Service Proxying", () => { + let authToken: string; + let originalFetch: typeof fetch; + + beforeAll(async () => { + // Get auth token for tests that need authentication + authToken = env.AUTH_TOKEN; + + // Save original fetch + originalFetch = globalThis.fetch; + }); + + afterEach(() => { + // Restore original fetch after each test + globalThis.fetch = originalFetch; + vi.unstubAllGlobals(); + }); + + describe("atproto-proxy header", () => { + it("should reject invalid proxy header format", async () => { + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test.bsky.social", + { + headers: { + "atproto-proxy": "invalid-format", + }, + }, + ), + env, + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data).toMatchObject({ + error: "InvalidRequest", + message: expect.stringContaining("Invalid atproto-proxy header"), + }); + }); + + it("should reject proxy header without service ID", async () => { + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test.bsky.social", + { + headers: { + "atproto-proxy": "did:web:example.com", + }, + }, + ), + env, + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data).toMatchObject({ + error: "InvalidRequest", + message: expect.stringContaining("Invalid atproto-proxy header"), + }); + }); + + it("should handle DID resolution failure gracefully", async () => { + // Mock fetch to simulate DID resolution failure + vi.stubGlobal( + "fetch", + vi.fn((url: string) => { + if ( + url === + "https://nonexistent-domain-12345.invalid/.well-known/did.json" + ) { + return Promise.reject(new Error("DNS lookup failed")); + } + return originalFetch(url); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test.bsky.social", + { + headers: { + "atproto-proxy": + "did:web:nonexistent-domain-12345.invalid#atproto_labeler", + }, + }, + ), + env, + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data).toMatchObject({ + error: "InvalidRequest", + message: expect.stringContaining("Failed to resolve service"), + }); + }); + + it("should reject when service not found in DID document", async () => { + // Mock fetch to return DID document without the requested service + vi.stubGlobal( + "fetch", + vi.fn((url: string) => { + if (url === "https://api.bsky.app/.well-known/did.json") { + return Promise.resolve( + new Response( + JSON.stringify(mockDidDocuments["did:web:api.bsky.app"]), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + } + return originalFetch(url); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test.bsky.social", + { + headers: { + "atproto-proxy": "did:web:api.bsky.app#nonexistent_service", + }, + }, + ), + env, + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data).toMatchObject({ + error: "InvalidRequest", + message: expect.stringContaining("Service not found in DID document"), + }); + }); + + it("should reject non-HTTPS service endpoints", async () => { + // Mock DID document with HTTP endpoint + vi.stubGlobal( + "fetch", + vi.fn((url: string) => { + if (url === "https://insecure.example.com/.well-known/did.json") { + return Promise.resolve( + new Response( + JSON.stringify({ + id: "did:web:insecure.example.com", + service: [ + { + id: "#atproto_pds", + type: "AtprotoPersonalDataServer", + serviceEndpoint: "http://insecure.example.com", // HTTP, not HTTPS + }, + ], + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + } + return originalFetch(url); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test", + { + headers: { + "atproto-proxy": "did:web:insecure.example.com#atproto_pds", + }, + }, + ), + env, + ); + + expect(response.status).toBe(400); + const data = await response.json(); + expect(data).toMatchObject({ + error: "InvalidRequest", + message: "Proxy target must use HTTPS", + }); + }); + + it("should successfully proxy with valid atproto-proxy header", async () => { + // Mock fetch for both DID resolution and the proxied request + vi.stubGlobal( + "fetch", + vi.fn((url: string | URL, init?: RequestInit) => { + const urlStr = url.toString(); + if (urlStr === "https://labeler.example.com/.well-known/did.json") { + return Promise.resolve( + new Response( + JSON.stringify(mockDidDocuments["did:web:labeler.example.com"]), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + } + if (urlStr.startsWith("https://labeler.example.com/xrpc/")) { + // Verify the service JWT was added + const headers = new Headers(init?.headers); + const authHeader = headers.get("Authorization"); + expect(authHeader).toMatch(/^Bearer /); + + return Promise.resolve( + new Response(JSON.stringify({ success: true }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + } + return originalFetch(url, init); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.feed.getAuthorFeed?actor=test.bsky.social", + { + headers: { + "atproto-proxy": "did:web:labeler.example.com#atproto_labeler", + Authorization: `Bearer ${authToken}`, + }, + }, + ), + env, + ); + + expect(response.status).toBe(200); + const data = await response.json(); + expect(data).toEqual({ success: true }); + }); + }); + + describe("Fallback behavior", () => { + it("should proxy to Bluesky AppView when no proxy header present", async () => { + // Mock fetch to verify request goes to api.bsky.app + vi.stubGlobal( + "fetch", + vi.fn((url: string) => { + if (url.includes("api.bsky.app")) { + return Promise.resolve( + new Response(JSON.stringify({ proxied: true }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + } + return originalFetch(url); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.actor.getProfile?actor=test.bsky.social", + ), + env, + ); + + expect(response.status).toBe(200); + const data = await response.json(); + expect(data).toEqual({ proxied: true }); + }); + + it("should proxy chat methods to api.bsky.chat", async () => { + // Mock fetch to verify request goes to api.bsky.chat + vi.stubGlobal( + "fetch", + vi.fn((url: string) => { + if (url.includes("api.bsky.chat")) { + return Promise.resolve( + new Response(JSON.stringify({ chat: true }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + } + return originalFetch(url); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/chat.bsky.convo.getConvo?convoId=123", + { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }, + ), + env, + ); + + expect(response.status).toBe(200); + const data = await response.json(); + expect(data).toEqual({ chat: true }); + }); + + it("should forward Authorization header as service JWT", async () => { + let capturedAuthHeader: string | null = null; + + // Mock fetch to capture the Authorization header + vi.stubGlobal( + "fetch", + vi.fn((url: string, init?: RequestInit) => { + if (url.includes("api.bsky.app")) { + // Headers can be a Headers object, array, or plain object + const headers = new Headers(init?.headers); + capturedAuthHeader = headers.get("Authorization"); + return Promise.resolve( + new Response(JSON.stringify({ ok: true }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + } + return originalFetch(url, init); + }), + ); + + const response = await worker.fetch( + new Request( + "http://pds.test/xrpc/app.bsky.actor.getProfile?actor=test.bsky.social", + { + headers: { + Authorization: `Bearer ${authToken}`, + }, + }, + ), + env, + ); + + expect(response.status).toBe(200); + // Verify service JWT was created and forwarded + expect(capturedAuthHeader).toMatch(/^Bearer /); + // The forwarded token should be different from the original (it's a service JWT) + expect(capturedAuthHeader).not.toBe(`Bearer ${authToken}`); + }); + }); +}); diff --git a/packages/pds/test/security.test.ts b/packages/pds/test/security.test.ts new file mode 100644 index 00000000..aadba06f --- /dev/null +++ b/packages/pds/test/security.test.ts @@ -0,0 +1,153 @@ +import { describe, it, expect } from "vitest"; +import { parseProxyHeader } from "../src/xrpc-proxy"; +import { getServiceEndpoint, type DidDocument } from "@atproto/common-web"; + +describe("DID Resolver URL Validation", () => { + describe("Protocol validation", () => { + it("should reject non-HTTP(S) URLs", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "ftp://example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBeUndefined(); + }); + + it("should reject invalid URLs", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "not-a-url", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBeUndefined(); + }); + + it("should allow HTTP URLs", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "http://example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("http://example.com"); + }); + + it("should allow HTTPS URLs", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should allow URLs with ports", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com:8443", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com:8443"); + }); + + it("should allow URLs with paths", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://example.com/labeler", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://example.com/labeler"); + }); + }); + + describe("Service ID matching", () => { + it("should match service ID with hash prefix", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should match service ID without hash prefix", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + + it("should match full service ID", () => { + const doc: DidDocument = { + id: "did:web:example.com", + service: [ + { + id: "did:web:example.com#atproto_labeler", + type: "AtprotoLabeler", + serviceEndpoint: "https://labeler.example.com", + }, + ], + }; + + const endpoint = getServiceEndpoint(doc, { id: "#atproto_labeler" }); + expect(endpoint).toBe("https://labeler.example.com"); + }); + }); +}); diff --git a/packages/pds/test/xrpc.test.ts b/packages/pds/test/xrpc.test.ts index 6f7b64cb..fbe782d1 100644 --- a/packages/pds/test/xrpc.test.ts +++ b/packages/pds/test/xrpc.test.ts @@ -1031,14 +1031,11 @@ describe("XRPC Endpoints", () => { it("should require aud parameter", async () => { const response = await worker.fetch( - new Request( - "http://pds.test/xrpc/com.atproto.server.getServiceAuth", - { - headers: { - Authorization: `Bearer ${env.AUTH_TOKEN}`, - }, + new Request("http://pds.test/xrpc/com.atproto.server.getServiceAuth", { + headers: { + Authorization: `Bearer ${env.AUTH_TOKEN}`, }, - ), + }), env, ); expect(response.status).toBe(400); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2dc5d9da..ecd3d35d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -75,6 +75,9 @@ importers: '@atproto/crypto': specifier: ^0.4.5 version: 0.4.5 + '@atproto/identity': + specifier: ^0.4.10 + version: 0.4.10 '@atproto/lex-cbor': specifier: ^0.0.3 version: 0.0.3 @@ -168,6 +171,10 @@ packages: resolution: {integrity: sha512-n40aKkMoCatP0u9Yvhrdk6fXyOHFDDbkdm4h4HCyWW+KlKl8iXfD5iV+ECq+w5BM+QH25aIpt3/j6EUNerhLxw==} engines: {node: '>=18.7.0'} + '@atproto/identity@0.4.10': + resolution: {integrity: sha512-nQbzDLXOhM8p/wo0cTh5DfMSOSHzj6jizpodX37LJ4S1TZzumSxAjHEZa5Rev3JaoD5uSWMVE0MmKEGWkPPvfQ==} + engines: {node: '>=18.7.0'} + '@atproto/lex-cbor@0.0.3': resolution: {integrity: sha512-N8lCV3kK5ZcjSOWxKLWqzlnaSpK4isjXRZ0EqApl/5y9KB64s78hQ/U3KIE5qnPRlBbW5kSH3YACoU27u9nTOA==} @@ -2628,6 +2635,11 @@ snapshots: '@noble/hashes': 1.8.0 uint8arrays: 3.0.0 + '@atproto/identity@0.4.10': + dependencies: + '@atproto/common-web': 0.4.7 + '@atproto/crypto': 0.4.5 + '@atproto/lex-cbor@0.0.3': dependencies: '@atproto/lex-data': 0.0.3