From 08bc14df1051aad3ec13bc8c43e13da349e1c629 Mon Sep 17 00:00:00 2001 From: Aamod007 Date: Sun, 21 Jun 2026 13:00:39 +0530 Subject: [PATCH] fix(rate-limit): add Retry-After header, fix TOCTOU race, fix remaining() key prefix - Add Retry-After header to getRateLimitHeaders() for RFC-compliant 429 responses - Fix middleware.ts to use getRateLimitHeaders() instead of manual headers - Fix TOCTOU race in RateLimiter.checkWithResult() Redis path (replaced GET+TTL+INCR with single atomic INCR+EXPIRE pipeline) - Fix remaining() to use correct key prefix (ratelimit:) and query Redis when available Related to #5857 --- lib/rate-limit.ts | 76 +++++++++++++++++++++++------------------------ middleware.ts | 13 ++++---- 2 files changed, 43 insertions(+), 46 deletions(-) diff --git a/lib/rate-limit.ts b/lib/rate-limit.ts index 1eceea21c..5c44bdb7b 100644 --- a/lib/rate-limit.ts +++ b/lib/rate-limit.ts @@ -69,54 +69,27 @@ export class RateLimiter { if (url && token) { try { - const getRes = await fetch(`${url}/pipeline`, { + const incrRes = await fetch(`${url}/pipeline`, { method: 'POST', headers: { Authorization: `Bearer ${token}`, 'Content-Type': 'application/json', }, body: JSON.stringify([ - ['GET', `ratelimit_class:${ip}`], - ['TTL', `ratelimit_class:${ip}`], + ['INCR', `ratelimit:${ip}`], + ['EXPIRE', `ratelimit:${ip}`, Math.floor(this.windowMs / 1000), 'NX'], ]), }); - if (getRes.ok) { - const getData = await getRes.json(); - const currentCount = parseInt(getData[0].result ?? '0', 10); - const ttl = getData[1].result as number; - - if (currentCount >= this.limit) { - return { - success: false, - limit: this.limit, - remaining: 0, - reset: ttl > 0 ? now + ttl * 1000 : now + this.windowMs, - }; - } - - const incrRes = await fetch(`${url}/pipeline`, { - method: 'POST', - headers: { - Authorization: `Bearer ${token}`, - 'Content-Type': 'application/json', - }, - body: JSON.stringify([ - ['INCR', `ratelimit_class:${ip}`], - ['EXPIRE', `ratelimit_class:${ip}`, Math.floor(this.windowMs / 1000), 'NX'], - ]), - }); - - if (incrRes.ok) { - const incrData = await incrRes.json(); - const count = incrData[0].result as number; - return { - success: count <= this.limit, - limit: this.limit, - remaining: Math.max(0, this.limit - count), - reset: now + this.windowMs, - }; - } + if (incrRes.ok) { + const incrData = await incrRes.json(); + const count = incrData[0].result as number; + return { + success: count <= this.limit, + limit: this.limit, + remaining: Math.max(0, this.limit - count), + reset: now + this.windowMs, + }; } } catch (error) { console.error('RateLimiter KV error, falling back to memory:', error); @@ -193,6 +166,29 @@ export class RateLimiter { * console.log(`You have ${left} requests left.`); */ async remaining(ip: string): Promise { + const url = process.env.KV_REST_API_URL; + const token = process.env.KV_REST_API_TOKEN; + if (url && token) { + try { + const res = await fetch(`${url}/pipeline`, { + method: 'POST', + headers: { + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify([ + ['GET', `ratelimit:${ip}`], + ]), + }); + if (res.ok) { + const data = await res.json(); + const count = parseInt(data[0].result ?? '0', 10); + return Math.max(0, this.limit - count); + } + } catch { + // fall through to local cache + } + } const count = ((await this.cache.get(`ratelimit:${ip}`)) as unknown as number) ?? 0; return Math.max(0, this.limit - count); } @@ -303,7 +299,9 @@ export async function rateLimit( } export function getRateLimitHeaders(result: RateLimitResult) { + const retryAfter = Math.ceil(Math.max(0, result.reset - Date.now()) / 1000); return { + 'Retry-After': retryAfter.toString(), 'X-RateLimit-Limit': result.limit.toString(), 'X-RateLimit-Remaining': result.remaining.toString(), 'X-RateLimit-Reset': result.reset.toString(), diff --git a/middleware.ts b/middleware.ts index dd443ed80..f8170e0eb 100644 --- a/middleware.ts +++ b/middleware.ts @@ -1,6 +1,6 @@ import { NextResponse } from 'next/server'; import type { NextRequest } from 'next/server'; -import { rateLimit } from './lib/rate-limit'; +import { rateLimit, getRateLimitHeaders } from './lib/rate-limit'; import { getClientIp } from './utils/getClientIp'; /** @@ -32,18 +32,17 @@ export async function middleware(request: NextRequest) { status: 429, headers: { 'Content-Type': 'application/json', - 'X-RateLimit-Limit': result.limit.toString(), - 'X-RateLimit-Remaining': result.remaining.toString(), - 'X-RateLimit-Reset': result.reset.toString(), + ...getRateLimitHeaders(result), }, } ); } const response = NextResponse.next(); - response.headers.set('X-RateLimit-Limit', result.limit.toString()); - response.headers.set('X-RateLimit-Remaining', result.remaining.toString()); - response.headers.set('X-RateLimit-Reset', result.reset.toString()); + const headers = getRateLimitHeaders(result); + for (const [key, value] of Object.entries(headers)) { + response.headers.set(key, value); + } return response; }