diff --git a/packages/console/app/src/routes/stripe/webhook.ts b/packages/console/app/src/routes/stripe/webhook.ts index edae0cf53..032680728 100644 --- a/packages/console/app/src/routes/stripe/webhook.ts +++ b/packages/console/app/src/routes/stripe/webhook.ts @@ -17,7 +17,7 @@ export async function POST(input: APIEvent) { input.request.headers.get("stripe-signature")!, Resource.STRIPE_WEBHOOK_SECRET.value, ) - console.log(body.type, JSON.stringify(body, null, 2)) + console.log("stripe webhook:", body.type, body.id) return (async () => { if (body.type === "customer.updated") { @@ -285,7 +285,6 @@ export async function POST(input: APIEvent) { if (!invoiceID) throw new Error("Invoice ID not found") const paymentIntent = await Billing.stripe().paymentIntents.retrieve(invoiceID) - console.log(JSON.stringify(paymentIntent)) const errorMessage = typeof paymentIntent === "object" && paymentIntent !== null ? paymentIntent.last_payment_error?.message diff --git a/packages/enterprise/src/routes/share/[shareID].tsx b/packages/enterprise/src/routes/share/[shareID].tsx index 0c0f40787..1c58e24db 100644 --- a/packages/enterprise/src/routes/share/[shareID].tsx +++ b/packages/enterprise/src/routes/share/[shareID].tsx @@ -126,14 +126,10 @@ export default function () { return } console.error(error) - const details = error instanceof Error ? (error.stack ?? error.message) : String(error) return (

Unable to render this share.

-

Check the console for more details.

-
-              {details}
-            
+

An unexpected error occurred. Please try again later.

) }} diff --git a/packages/function/src/api.ts b/packages/function/src/api.ts index 58c74fe32..d2d91fc4e 100644 --- a/packages/function/src/api.ts +++ b/packages/function/src/api.ts @@ -18,6 +18,10 @@ export class SyncServer extends DurableObject { super(ctx, env) } async fetch() { + const secret = await this.getSecret() + if (!secret) { + return new Response("Not found", { status: 404 }) + } console.log("SyncServer subscribe") const webSocketPair = new WebSocketPair() @@ -77,6 +81,8 @@ export class SyncServer extends DurableObject { } public async getData() { + const secret = await this.getSecret() + if (!secret) return [] const data = (await this.ctx.storage.list()) as Map return Array.from(data.entries()) .filter(([key, _]) => key.startsWith("session/")) @@ -116,6 +122,10 @@ export class SyncServer extends DurableObject { export default new Hono<{ Bindings: Env }>() .get("/", (c) => c.text("Hello, world!")) .post("/share_create", async (c) => { + const authHeader = c.req.header("authorization") + if (!authHeader || authHeader !== `Bearer ${Resource.ADMIN_SECRET.value}`) { + return c.text("Unauthorized", 401) + } const body = await c.req.json<{ sessionID: string }>() const sessionID = body.sessionID const short = SyncServer.shortName(sessionID) @@ -202,7 +212,9 @@ export default new Hono<{ Bindings: Env }>() return c.json({ info, messages }) }) .post("/feishu", async (c) => { - const body = (await c.req.json()) as { + const rawBody = await c.req.text() + + let body: { challenge?: string event?: { message?: { @@ -214,9 +226,41 @@ export default new Hono<{ Bindings: Env }>() } } } - console.log(JSON.stringify(body, null, 2)) - const challenge = body.challenge - if (challenge) return c.json({ challenge }) + try { + body = JSON.parse(rawBody) + } catch { + return c.text("Invalid JSON body", 400) + } + + // Challenge requests during setup don't require signature verification + if (body.challenge) return c.json({ challenge: body.challenge }) + + // All non-challenge requests must have a valid signature + const signature = c.req.header("x-lark-signature") + const timestamp = c.req.header("x-lark-request-timestamp") + const nonce = c.req.header("x-lark-request-nonce") + if (!signature || !timestamp || !nonce) { + return c.text("Missing signature headers", 403) + } + + // Reject stale timestamps (±5 min window) + const ts = parseInt(timestamp, 10) + const now = Math.floor(Date.now() / 1000) + if (isNaN(ts) || Math.abs(now - ts) > 300) { + return c.text("Timestamp expired", 403) + } + + const encryptKey = Resource.FEISHU_APP_SECRET.value + const payload = timestamp + nonce + encryptKey + rawBody + const hash = await crypto.subtle.digest("SHA-256", new TextEncoder().encode(payload)) + const expected = Array.from(new Uint8Array(hash)) + .map((b) => b.toString(16).padStart(2, "0")) + .join("") + if (expected !== signature) { + return c.text("Invalid signature", 403) + } + + console.log("feishu webhook:", body.event?.message?.message_id ?? "unknown") const content = body.event?.message?.content const parsed = diff --git a/packages/opencode/src/mcp/index.ts b/packages/opencode/src/mcp/index.ts index ef7c571a4..7e92b9d01 100644 --- a/packages/opencode/src/mcp/index.ts +++ b/packages/opencode/src/mcp/index.ts @@ -18,6 +18,7 @@ import { Installation } from "../installation" import { InstallationVersion } from "../installation/version" import { withTimeout } from "@/util/timeout" import { AppFileSystem } from "@mimo-ai/shared/filesystem" +import { assertSafeUrl } from "@/util/ssrf" import { McpOAuthProvider } from "./oauth-provider" import { McpOAuthCallback } from "./oauth-callback" import { McpAuth } from "./auth" @@ -287,6 +288,11 @@ export const layer = Layer.effect( key: string, mcp: ConfigMCP.Info & { type: "remote" }, ) { + yield* Effect.tryPromise({ + try: () => assertSafeUrl(mcp.url), + catch: (e) => new Error(e instanceof Error ? e.message : String(e)), + }).pipe(Effect.orDie) + const oauthDisabled = mcp.oauth === false const oauthConfig = typeof mcp.oauth === "object" ? mcp.oauth : undefined let authProvider: McpOAuthProvider | undefined @@ -745,6 +751,11 @@ export const layer = Layer.effect( if (mcpConfig.type !== "remote") throw new Error(`MCP server ${mcpName} is not a remote server`) if (mcpConfig.oauth === false) throw new Error(`MCP server ${mcpName} has OAuth explicitly disabled`) + yield* Effect.tryPromise({ + try: () => assertSafeUrl(mcpConfig.url), + catch: (e) => new Error(e instanceof Error ? e.message : String(e)), + }).pipe(Effect.orDie) + // OAuth config is optional - if not provided, we'll use auto-discovery const oauthConfig = typeof mcpConfig.oauth === "object" ? mcpConfig.oauth : undefined diff --git a/packages/opencode/src/mcp/oauth-callback.ts b/packages/opencode/src/mcp/oauth-callback.ts index fbb43d392..0c9798408 100644 --- a/packages/opencode/src/mcp/oauth-callback.ts +++ b/packages/opencode/src/mcp/oauth-callback.ts @@ -9,6 +9,10 @@ const log = Log.create({ service: "mcp.oauth-callback" }) let currentPort = OAUTH_CALLBACK_PORT let currentPath = OAUTH_CALLBACK_PATH +function escapeHtml(s: string): string { + return s.replace(/&/g, "&").replace(//g, ">").replace(/"/g, """).replace(/'/g, "'") +} + const HTML_SUCCESS = ` @@ -45,7 +49,7 @@ const HTML_ERROR = (error: string) => `

Authorization Failed

An error occurred during authorization.

-
${error}
+
${escapeHtml(error)}
` diff --git a/packages/opencode/src/plugin/codex.ts b/packages/opencode/src/plugin/codex.ts index a48e94c16..b78ddada9 100644 --- a/packages/opencode/src/plugin/codex.ts +++ b/packages/opencode/src/plugin/codex.ts @@ -186,6 +186,10 @@ const HTML_SUCCESS = ` ` +function escapeHtml(s: string): string { + return s.replace(/&/g, "&").replace(//g, ">").replace(/"/g, """).replace(/'/g, "'") +} + const HTML_ERROR = (error: string) => ` @@ -229,7 +233,7 @@ const HTML_ERROR = (error: string) => `

Authorization Failed

An error occurred during authorization.

-
${error}
+
${escapeHtml(error)}
` diff --git a/packages/opencode/src/server/middleware.ts b/packages/opencode/src/server/middleware.ts index 92bb3acbe..3962f3f40 100644 --- a/packages/opencode/src/server/middleware.ts +++ b/packages/opencode/src/server/middleware.ts @@ -31,7 +31,7 @@ export const ErrorMiddleware: ErrorHandler = (err, c) => { return c.json(new NamedError.Unknown({ message: err.message }).toObject(), { status: 409 }) } if (err instanceof HTTPException) return err.getResponse() - const message = err instanceof Error && err.stack ? err.stack : err.toString() + const message = err instanceof Error ? err.message : "Internal Server Error" return c.json(new NamedError.Unknown({ message }).toObject(), { status: 500, }) @@ -48,8 +48,6 @@ export const AuthMiddleware: MiddlewareHandler = (c, next) => { const username = Flag.MIMOCODE_SERVER_USERNAME ?? "mimocode" - if (c.req.query("auth_token")) c.req.raw.headers.set("authorization", `Basic ${c.req.query("auth_token")}`) - return basicAuth({ username, password })(c, next) } diff --git a/packages/opencode/src/server/rate-limit.ts b/packages/opencode/src/server/rate-limit.ts new file mode 100644 index 000000000..5aacc2cfe --- /dev/null +++ b/packages/opencode/src/server/rate-limit.ts @@ -0,0 +1,38 @@ +import type { MiddlewareHandler } from "hono" + +const windows = new Map() + +let lastSweep = Date.now() +const SWEEP_INTERVAL = 60_000 + +function sweep() { + const now = Date.now() + if (now - lastSweep < SWEEP_INTERVAL) return + lastSweep = now + for (const [key, entry] of windows) { + if (now >= entry.resetAt) windows.delete(key) + } +} + +export function RateLimitMiddleware(opts: { + windowMs: number + max: number + keyPrefix?: string +}): MiddlewareHandler { + return async (c, next) => { + sweep() + const key = (opts.keyPrefix ?? c.req.path) + ":" + (c.req.header("x-forwarded-for") ?? "local") + const now = Date.now() + let entry = windows.get(key) + if (!entry || now >= entry.resetAt) { + entry = { count: 0, resetAt: now + opts.windowMs } + windows.set(key, entry) + } + entry.count++ + if (entry.count > opts.max) { + c.header("Retry-After", String(Math.ceil((entry.resetAt - now) / 1000))) + return c.json({ error: "Too many requests" }, 429) + } + return next() + } +} diff --git a/packages/opencode/src/server/routes/instance/session.ts b/packages/opencode/src/server/routes/instance/session.ts index 6449e43d8..743049e2e 100644 --- a/packages/opencode/src/server/routes/instance/session.ts +++ b/packages/opencode/src/server/routes/instance/session.ts @@ -30,6 +30,7 @@ import { lazy } from "@/util/lazy" import { Bus } from "@/bus" import { NamedError } from "@mimo-ai/shared/util/error" import { jsonRequest, runRequest } from "./trace" +import { RateLimitMiddleware } from "../../rate-limit" const log = Log.create({ service: "server" }) @@ -698,8 +699,9 @@ export const SessionRoutes = lazy(() => .number() .int() .min(0) + .max(1000) .optional() - .meta({ description: "Maximum number of messages to return" }), + .meta({ description: "Maximum number of messages to return (max 1000)" }), before: z .string() .optional() @@ -744,7 +746,7 @@ export const SessionRoutes = lazy(() => Effect.gen(function* () { const session = yield* Session.Service yield* session.get(sessionID) - return yield* session.messages({ sessionID, agentID }) + return yield* session.messages({ sessionID, agentID, limit: 1000 }) }), ) return c.json(messages) @@ -1008,6 +1010,7 @@ export const SessionRoutes = lazy(() => ) .post( "/:sessionID/prompt_async", + RateLimitMiddleware({ windowMs: 60_000, max: 20, keyPrefix: "prompt_async" }), describeRoute({ summary: "Send async message", description: @@ -1122,6 +1125,7 @@ export const SessionRoutes = lazy(() => ) .post( "/:sessionID/shell", + RateLimitMiddleware({ windowMs: 60_000, max: 20, keyPrefix: "shell" }), describeRoute({ summary: "Run shell command", description: "Execute a shell command within the session context and return the AI's response.", diff --git a/packages/opencode/src/server/server.ts b/packages/opencode/src/server/server.ts index d4a366457..e47fd1a69 100644 --- a/packages/opencode/src/server/server.ts +++ b/packages/opencode/src/server/server.ts @@ -97,7 +97,17 @@ export async function listen(opts: { mdns?: boolean mdnsDomain?: string cors?: string[] + noAuth?: boolean }): Promise { + const isLoopback = + opts.hostname === "127.0.0.1" || opts.hostname === "localhost" || opts.hostname === "::1" + if (!isLoopback && !Flag.MIMOCODE_SERVER_PASSWORD && !opts.noAuth) { + throw new Error( + "Refusing to bind to non-loopback address without MIMOCODE_SERVER_PASSWORD. " + + "Set the environment variable or pass noAuth to explicitly allow unauthenticated access.", + ) + } + const built = create(opts) const server = await built.runtime.listen(opts) diff --git a/packages/opencode/src/tool/webfetch.ts b/packages/opencode/src/tool/webfetch.ts index d24c660e5..78c3eb011 100644 --- a/packages/opencode/src/tool/webfetch.ts +++ b/packages/opencode/src/tool/webfetch.ts @@ -5,6 +5,7 @@ import * as Tool from "./tool" import TurndownService from "turndown" import DESCRIPTION from "./webfetch.txt" import { isImageAttachment } from "@/util/media" +import { assertSafeUrl } from "@/util/ssrf" const MAX_RESPONSE_SIZE = 5 * 1024 * 1024 // 5MB const DEFAULT_TIMEOUT = 30 * 1000 // 30 seconds @@ -34,6 +35,8 @@ export const WebFetchTool = Tool.define( throw new Error("URL must start with http:// or https://") } + yield* Effect.promise(() => assertSafeUrl(params.url)) + yield* ctx.ask({ permission: "webfetch", patterns: [params.url], @@ -90,6 +93,12 @@ export const WebFetchTool = Tool.define( Effect.timeoutOrElse({ duration: timeout, orElse: () => Effect.die(new Error("Request timed out")) }), ) + // Block SSRF via redirect: if the response was redirected, validate final URL + const source = (response as any).source as Response | undefined + if (source?.url && source.url !== params.url) { + yield* Effect.promise(() => assertSafeUrl(source.url)) + } + // Check content length const contentLength = response.headers["content-length"] if (contentLength && parseInt(contentLength) > MAX_RESPONSE_SIZE) { diff --git a/packages/opencode/src/util/ssrf.ts b/packages/opencode/src/util/ssrf.ts new file mode 100644 index 000000000..155cbfc4e --- /dev/null +++ b/packages/opencode/src/util/ssrf.ts @@ -0,0 +1,116 @@ +import { lookup } from "dns/promises" + +const BLOCKED_HOSTNAMES = new Set([ + "metadata.google.internal", + "metadata.goog", + "kubernetes.default.svc", +]) + +const BLOCKED_IPV4_PREFIXES = [ + "10.", // private class A + "0.", // current network +] + +const BLOCKED_IPV4_RANGES: Array<{ start: number; end: number }> = [ + { start: ip4ToInt("172.16.0.0"), end: ip4ToInt("172.31.255.255") }, // private class B + { start: ip4ToInt("192.168.0.0"), end: ip4ToInt("192.168.255.255") }, // private class C + { start: ip4ToInt("169.254.0.0"), end: ip4ToInt("169.254.255.255") }, // link-local + { start: ip4ToInt("100.64.0.0"), end: ip4ToInt("100.127.255.255") }, // shared address (CGN) + { start: ip4ToInt("100.100.100.200"), end: ip4ToInt("100.100.100.200") }, // Alibaba Cloud metadata +] + +function ip4ToInt(ip: string): number { + const parts = ip.split(".") + return ((+parts[0]! << 24) | (+parts[1]! << 16) | (+parts[2]! << 8) | +parts[3]!) >>> 0 +} + +function isBlockedIPv4(ip: string): boolean { + for (const prefix of BLOCKED_IPV4_PREFIXES) { + if (ip.startsWith(prefix)) return true + } + const n = ip4ToInt(ip) + for (const range of BLOCKED_IPV4_RANGES) { + if (n >= range.start && n <= range.end) return true + } + return false +} + +function isBlockedIPv6(ip: string): boolean { + const normalized = ip.toLowerCase() + if (normalized.startsWith("fe80:")) return true // link-local + if (normalized.startsWith("fc") || normalized.startsWith("fd")) return true // ULA + // IPv4-mapped IPv6 in dotted-decimal form (::ffff:a.b.c.d) + const mapped = normalized.match(/^::ffff:(\d+\.\d+\.\d+\.\d+)$/) + if (mapped) return isBlockedIPv4(mapped[1]!) + // IPv4-mapped IPv6 in hex form (::ffff:HHHH:HHHH) — URL parsers normalize to this + const hexMapped = normalized.match(/^::ffff:([0-9a-f]{1,4}):([0-9a-f]{1,4})$/) + if (hexMapped) { + const hi = parseInt(hexMapped[1]!, 16) + const lo = parseInt(hexMapped[2]!, 16) + const ipv4 = `${(hi >> 8) & 0xff}.${hi & 0xff}.${(lo >> 8) & 0xff}.${lo & 0xff}` + return isBlockedIPv4(ipv4) + } + return false +} + +const MAX_REDIRECTS = 5 + +export async function safeFetch( + url: string, + init?: RequestInit, + fetchImpl: typeof fetch = fetch, +): Promise { + await assertSafeUrl(url) + let currentUrl = url + for (let i = 0; i < MAX_REDIRECTS; i++) { + const response = await fetchImpl(currentUrl, { ...init, redirect: "manual" }) + if (response.status >= 300 && response.status < 400) { + const location = response.headers.get("location") + if (!location) return response + currentUrl = new URL(location, currentUrl).toString() + await assertSafeUrl(currentUrl) + continue + } + return response + } + throw new Error("SSRF protection: too many redirects") +} + +export async function assertSafeUrl(url: string): Promise { + const parsed = new URL(url) + const hostname = parsed.hostname.replace(/^\[|\]$/g, "") + + if (BLOCKED_HOSTNAMES.has(hostname)) { + throw new Error(`SSRF protection: blocked hostname "${hostname}"`) + } + + // Numeric IPv4 check (before DNS) + if (/^\d+\.\d+\.\d+\.\d+$/.test(hostname)) { + if (isBlockedIPv4(hostname)) { + throw new Error(`SSRF protection: blocked private/internal IP "${hostname}"`) + } + return + } + + // Numeric IPv6 check (before DNS) + if (hostname.includes(":")) { + if (isBlockedIPv6(hostname)) { + throw new Error(`SSRF protection: blocked private/internal IPv6 "${hostname}"`) + } + return + } + + // DNS resolution check to prevent DNS rebinding + try { + const { address, family } = await lookup(hostname) + if (family === 4 && isBlockedIPv4(address)) { + throw new Error(`SSRF protection: hostname "${hostname}" resolves to blocked IP "${address}"`) + } + if (family === 6 && isBlockedIPv6(address)) { + throw new Error(`SSRF protection: hostname "${hostname}" resolves to blocked IPv6 "${address}"`) + } + } catch (e: any) { + if (e.message?.startsWith("SSRF protection:")) throw e + throw new Error(`SSRF protection: DNS resolution failed for "${hostname}"`) + } +} diff --git a/packages/opencode/test/server/rate-limit.test.ts b/packages/opencode/test/server/rate-limit.test.ts new file mode 100644 index 000000000..47bdf643e --- /dev/null +++ b/packages/opencode/test/server/rate-limit.test.ts @@ -0,0 +1,66 @@ +import { describe, expect, test } from "bun:test" +import { RateLimitMiddleware } from "../../src/server/rate-limit" + +const PASSED = new Response(null, { status: 200 }) + +function makeContext() { + const headers = new Map() + return { + req: { path: "/test", header: () => undefined }, + header: (k: string, v: string) => headers.set(k, v), + json: (body: any, status?: number) => ({ body, status }), + _headers: headers, + } +} + +describe("RateLimitMiddleware", () => { + test("allows requests within limit", async () => { + const mw = RateLimitMiddleware({ windowMs: 60_000, max: 3, keyPrefix: "test-allow" }) + const next = () => PASSED + + for (let i = 0; i < 3; i++) { + const c = makeContext() + const result = await mw(c as any, next as any) + expect(result).toBe(PASSED) + } + }) + + test("blocks requests exceeding limit", async () => { + const mw = RateLimitMiddleware({ windowMs: 60_000, max: 2, keyPrefix: "test-block" }) + const next = () => PASSED + + const c1 = makeContext() + expect(await mw(c1 as any, next as any)).toBe(PASSED) + + const c2 = makeContext() + expect(await mw(c2 as any, next as any)).toBe(PASSED) + + const c3 = makeContext() + const result = (await mw(c3 as any, next as any)) as any + expect(result.status).toBe(429) + expect(result.body.error).toBe("Too many requests") + }) + + test("sets Retry-After header on 429", async () => { + const mw = RateLimitMiddleware({ windowMs: 60_000, max: 1, keyPrefix: "test-header" }) + const next = () => PASSED + + await mw(makeContext() as any, next as any) + + const c = makeContext() + await mw(c as any, next as any) + expect(c._headers.has("Retry-After")).toBe(true) + }) + + test("resets after window expires", async () => { + const mw = RateLimitMiddleware({ windowMs: 1, max: 1, keyPrefix: "test-reset" }) + const next = () => PASSED + + await mw(makeContext() as any, next as any) + await Bun.sleep(5) + + const c = makeContext() + const result = await mw(c as any, next as any) + expect(result).toBe(PASSED) + }) +}) diff --git a/packages/opencode/test/util/ssrf.test.ts b/packages/opencode/test/util/ssrf.test.ts new file mode 100644 index 000000000..258095a2f --- /dev/null +++ b/packages/opencode/test/util/ssrf.test.ts @@ -0,0 +1,115 @@ +import { describe, expect, test } from "bun:test" +import { assertSafeUrl, safeFetch } from "../../src/util/ssrf" + +describe("assertSafeUrl", () => { + describe("blocks private IPv4", () => { + test.each(["http://10.0.0.1/", "http://10.255.255.255/"])("blocks 10.x: %s", async (url) => { + await expect(assertSafeUrl(url)).rejects.toThrow("SSRF protection") + }) + + test.each(["http://172.16.0.1/", "http://172.31.255.255/"])("blocks 172.16-31.x: %s", async (url) => { + await expect(assertSafeUrl(url)).rejects.toThrow("SSRF protection") + }) + + test.each(["http://192.168.0.1/", "http://192.168.255.255/"])("blocks 192.168.x: %s", async (url) => { + await expect(assertSafeUrl(url)).rejects.toThrow("SSRF protection") + }) + + test.each(["http://169.254.0.1/", "http://169.254.169.254/"])("blocks link-local: %s", async (url) => { + await expect(assertSafeUrl(url)).rejects.toThrow("SSRF protection") + }) + + test("blocks CGN range", async () => { + await expect(assertSafeUrl("http://100.64.0.1/")).rejects.toThrow("SSRF protection") + await expect(assertSafeUrl("http://100.100.100.200/")).rejects.toThrow("SSRF protection") + }) + }) + + describe("blocks metadata hostnames", () => { + test.each(["http://metadata.google.internal/", "http://metadata.goog/", "http://kubernetes.default.svc/"])( + "blocks %s", + async (url) => { + await expect(assertSafeUrl(url)).rejects.toThrow("SSRF protection") + }, + ) + }) + + describe("blocks IPv6", () => { + test("blocks link-local", async () => { + await expect(assertSafeUrl("http://[fe80::1]/")).rejects.toThrow("SSRF protection") + }) + + test("blocks ULA", async () => { + await expect(assertSafeUrl("http://[fd00::1]/")).rejects.toThrow("SSRF protection") + await expect(assertSafeUrl("http://[fc00::1]/")).rejects.toThrow("SSRF protection") + }) + + test("blocks IPv4-mapped private IPs (hex form)", async () => { + // ::ffff:c0a8:101 = 192.168.1.1 + await expect(assertSafeUrl("http://[::ffff:c0a8:101]/")).rejects.toThrow("SSRF protection") + // ::ffff:a9fe:a9fe = 169.254.169.254 + await expect(assertSafeUrl("http://[::ffff:a9fe:a9fe]/")).rejects.toThrow("SSRF protection") + }) + }) + + describe("allows loopback (CLI tool use case)", () => { + test("allows 127.0.0.1", async () => { + await expect(assertSafeUrl("http://127.0.0.1:3000/")).resolves.toBeUndefined() + }) + + test("allows localhost", async () => { + await expect(assertSafeUrl("http://localhost:8080/")).resolves.toBeUndefined() + }) + }) + + describe("allows public IPs", () => { + test("allows non-private IPv4", async () => { + await expect(assertSafeUrl("http://8.8.8.8/")).resolves.toBeUndefined() + await expect(assertSafeUrl("http://172.32.0.1/")).resolves.toBeUndefined() + await expect(assertSafeUrl("http://93.184.216.34/")).resolves.toBeUndefined() + }) + }) + + describe("DNS fail-closed", () => { + test("rejects unresolvable hostnames", async () => { + await expect(assertSafeUrl("http://this-domain-definitely-does-not-exist-xyz123.invalid/")).rejects.toThrow( + "SSRF protection: DNS resolution failed", + ) + }) + }) +}) + +describe("safeFetch", () => { + test("blocks redirect to private IP", async () => { + const mockFetch = async () => new Response(null, { + status: 302, + headers: { Location: "http://169.254.169.254/latest/meta-data/" }, + }) + await expect(safeFetch("http://127.0.0.1:8080/", undefined, mockFetch as any)).rejects.toThrow("SSRF protection") + }) + + test("follows safe redirects", async () => { + let callCount = 0 + const mockFetch = async () => { + callCount++ + if (callCount === 1) { + return new Response(null, { + status: 302, + headers: { Location: "http://127.0.0.1:9090/final" }, + }) + } + return new Response("ok", { status: 200 }) + } + const res = await safeFetch("http://127.0.0.1:8080/redirect", undefined, mockFetch as any) + expect(res.status).toBe(200) + expect(await res.text()).toBe("ok") + }) + + test("rejects too many redirects", async () => { + const mockFetch = async () => new Response(null, { + status: 302, + headers: { Location: "http://127.0.0.1:8080/loop" }, + }) + await expect(safeFetch("http://127.0.0.1:8080/loop", undefined, mockFetch as any)).rejects.toThrow("too many redirects") + }) +}) diff --git a/packages/web/src/components/share/content-markdown.tsx b/packages/web/src/components/share/content-markdown.tsx index 10a06bf5e..81d4f0a84 100644 --- a/packages/web/src/components/share/content-markdown.tsx +++ b/packages/web/src/components/share/content-markdown.tsx @@ -1,17 +1,23 @@ import { marked } from "marked" import { codeToHtml } from "shiki" import markedShiki from "marked-shiki" +import DOMPurify from "dompurify" import { createOverflow, useShareMessages } from "./common" import { CopyButton } from "./copy-button" import { createResource, createSignal } from "solid-js" import style from "./content-markdown.module.css" +function escapeAttr(s: string): string { + return s.replace(/&/g, "&").replace(/"/g, """).replace(//g, ">") +} + const markedWithShiki = marked.use( { renderer: { link({ href, title, text }) { - const titleAttr = title ? ` title="${title}"` : "" - return `${text}` + const safeHref = escapeAttr(href) + const titleAttr = title ? ` title="${escapeAttr(title)}"` : "" + return `${text}` }, }, }, @@ -37,7 +43,11 @@ export function ContentMarkdown(props: Props) { const [html] = createResource( () => strip(props.text), async (markdown) => { - return markedWithShiki.parse(markdown) + const raw = await markedWithShiki.parse(markdown) + return DOMPurify.sanitize(raw, { + FORBID_TAGS: ["style"], + FORBID_ATTR: ["onerror", "onload", "onclick", "onmouseover"], + }) }, ) const [expanded, setExpanded] = createSignal(false)