diff --git a/docs/EmulatorPolicyChecklist.md b/docs/EmulatorPolicyChecklist.md index 71d1fd94..db72ea53 100644 --- a/docs/EmulatorPolicyChecklist.md +++ b/docs/EmulatorPolicyChecklist.md @@ -58,7 +58,7 @@ Track progress of emulator policy handler implementation. Each policy needs a ha | ⬜ | LlmSemanticCacheStore | LlmSemanticCacheStoreHandler.cs | No-op + callbacks | `emulator/llm-semantic-cache-store` | | ⬜ | Quota | QuotaHandler.cs | No-op + callbacks | `emulator/quota` | | ✅ | RateLimit | RateLimitHandler.cs | No-op + callbacks | `emulator/rate-limit` | -| ⬜ | RateLimitByKey | RateLimitByKeyHandler.cs | No-op + callbacks | `emulator/rate-limit-by-key` | +| ✅ | RateLimitByKey | RateLimitByKeyHandler.cs | No-op + callbacks | `emulator/rate-limit-by-key` | | ✅ | RewriteUri | RewriteUriHandler.cs | Context mutation | `emulator/rewrite-uri` | | ⬜ | SendRequest | SendRequestHandler.cs | External service mock | `emulator/send-request` | | ⬜ | SetBackendService | SetBackendServiceHandler.cs | Context mutation | `emulator/set-backend-service` | diff --git a/src/Testing/Emulator/Policies/RateLimitByKeyHandler.cs b/src/Testing/Emulator/Policies/RateLimitByKeyHandler.cs index 1a17179b..17dd10e2 100644 --- a/src/Testing/Emulator/Policies/RateLimitByKeyHandler.cs +++ b/src/Testing/Emulator/Policies/RateLimitByKeyHandler.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using Microsoft.Azure.ApiManagement.PolicyToolkit.Authoring; +using Microsoft.Azure.ApiManagement.PolicyToolkit.Testing.Expressions; namespace Microsoft.Azure.ApiManagement.PolicyToolkit.Testing.Emulator.Policies; @@ -12,6 +13,67 @@ internal class RateLimitByKeyHandler : PolicyHandler protected override void Handle(GatewayContext context, RateLimitByKeyConfig config) { - throw new NotImplementedException(); + var key = config.CounterKey; + var currentCount = context.RateLimitStore.GetCount(key); + var remaining = config.Calls - currentCount - 1; + + if (currentCount >= config.Calls) + { + DenyRequest(context, config); + } + + // Increment counter only if condition is met (default: true) + if (config.IncrementCondition != false) + { + var increment = config.IncrementCount ?? 1; + context.RateLimitStore.Increment(key, increment); + } + + // Set headers/variables on success path only + if (config.RemainingCallsHeaderName is not null) + { + context.Response.Headers[config.RemainingCallsHeaderName] = [Math.Max(0, remaining).ToString()]; + } + + if (config.TotalCallsHeaderName is not null) + { + context.Response.Headers[config.TotalCallsHeaderName] = [config.Calls.ToString()]; + } + + if (config.RemainingCallsVariableName is not null) + { + context.Variables[config.RemainingCallsVariableName] = Math.Max(0, remaining); + } + + if (context.Response.StatusCode == 429) + { + context.Response = new MockResponse { Headers = context.Response.Headers }; + } + } + + private static void DenyRequest(GatewayContext context, RateLimitByKeyConfig config) + { + if (config.RetryAfterVariableName is not null) + { + context.Variables[config.RetryAfterVariableName] = config.RenewalPeriod; + } + + context.Response = new MockResponse + { + StatusCode = 429, + StatusReason = "Too Many Requests", + }; + + if (config.RetryAfterHeaderName is not null) + { + context.Response.Headers[config.RetryAfterHeaderName] = [config.RenewalPeriod.ToString()]; + } + + if (config.TotalCallsHeaderName is not null) + { + context.Response.Headers[config.TotalCallsHeaderName] = [config.Calls.ToString()]; + } + + throw new FinishSectionProcessingException(); } } \ No newline at end of file diff --git a/test/Test.Testing/Emulator/Policies/RateLimitByKeyTests.cs b/test/Test.Testing/Emulator/Policies/RateLimitByKeyTests.cs new file mode 100644 index 00000000..0279df5b --- /dev/null +++ b/test/Test.Testing/Emulator/Policies/RateLimitByKeyTests.cs @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Azure.ApiManagement.PolicyToolkit.Authoring; +using Microsoft.Azure.ApiManagement.PolicyToolkit.Testing; +using Microsoft.Azure.ApiManagement.PolicyToolkit.Testing.Document; + +namespace Test.Emulator.Emulator.Policies; + +[TestClass] +public class RateLimitByKeyTests +{ + class SimpleRateLimitByKey : IDocument + { + public void Inbound(IInboundContext context) + { + context.RateLimitByKey(new RateLimitByKeyConfig + { + Calls = 3, + RenewalPeriod = 60, + CounterKey = "client-ip" + }); + } + } + + class RateLimitByKeyWithHeaders : IDocument + { + public void Inbound(IInboundContext context) + { + context.RateLimitByKey(new RateLimitByKeyConfig + { + Calls = 5, + RenewalPeriod = 30, + CounterKey = "user-id", + RemainingCallsHeaderName = "X-RateLimit-Remaining", + TotalCallsHeaderName = "X-RateLimit-Limit", + RetryAfterHeaderName = "Retry-After", + RetryAfterVariableName = "retryAfter", + RemainingCallsVariableName = "remainingCalls" + }); + } + } + + class RateLimitByKeyWithIncrementCount : IDocument + { + public void Inbound(IInboundContext context) + { + context.RateLimitByKey(new RateLimitByKeyConfig + { + Calls = 10, + RenewalPeriod = 60, + CounterKey = "heavy-op", + IncrementCount = 5 + }); + } + } + + class RateLimitByKeyWithIncrementConditionFalse : IDocument + { + public void Inbound(IInboundContext context) + { + context.RateLimitByKey(new RateLimitByKeyConfig + { + Calls = 3, + RenewalPeriod = 60, + CounterKey = "conditional", + IncrementCondition = false + }); + } + } + + class RateLimitByKeyThenSetHeader : IDocument + { + public void Inbound(IInboundContext context) + { + context.RateLimitByKey(new RateLimitByKeyConfig + { + Calls = 1, + RenewalPeriod = 60, + CounterKey = "block-me" + }); + context.SetHeader("X-After-RateLimit", "executed"); + } + } + + [TestMethod] + public void RateLimitByKey_UnderLimit() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + + test.RunInbound(); + + test.Context.Response.StatusCode.Should().NotBe(429); + } + + [TestMethod] + public void RateLimitByKey_ExceedsLimit() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + test.SetupRateLimitStore().SetCount("client-ip", 3); + + test.RunInbound(); + + test.Context.Response.StatusCode.Should().Be(429); + test.Context.Response.StatusReason.Should().Be("Too Many Requests"); + } + + [TestMethod] + public void RateLimitByKey_ResetAndRetry() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + test.SetupRateLimitStore().SetCount("client-ip", 3); + + test.RunInbound(); + test.Context.Response.StatusCode.Should().Be(429); + + test.SetupRateLimitStore().Reset(); + + test.RunInbound(); + test.Context.Response.StatusCode.Should().NotBe(429); + } + + [TestMethod] + public void RateLimitByKey_DifferentKeysAreIndependent() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + test.SetupRateLimitStore().SetCount("other-key", 100); + + test.RunInbound(); + + test.Context.Response.StatusCode.Should().NotBe(429); + } + + [TestMethod] + public void RateLimitByKey_SetsHeaders() + { + var test = new RateLimitByKeyWithHeaders().AsTestDocument(); + + test.RunInbound(); + + test.Context.Response.Headers.Should().ContainKey("X-RateLimit-Remaining"); + test.Context.Response.Headers.Should().ContainKey("X-RateLimit-Limit"); + test.Context.Variables.Should().ContainKey("remainingCalls"); + } + + [TestMethod] + public void RateLimitByKey_SetsRetryAfterOnExceeded() + { + var test = new RateLimitByKeyWithHeaders().AsTestDocument(); + test.SetupRateLimitStore().SetCount("user-id", 5); + + test.RunInbound(); + + test.Context.Response.StatusCode.Should().Be(429); + test.Context.Response.Headers.Should().ContainKey("Retry-After"); + test.Context.Variables.Should().ContainKey("retryAfter"); + } + + [TestMethod] + public void RateLimitByKey_IncrementCount() + { + var test = new RateLimitByKeyWithIncrementCount().AsTestDocument(); + + test.RunInbound(); + + test.SetupRateLimitStore().GetCount("heavy-op").Should().Be(5); + } + + [TestMethod] + public void RateLimitByKey_IncrementConditionFalse_DoesNotIncrement() + { + var test = new RateLimitByKeyWithIncrementConditionFalse().AsTestDocument(); + + test.RunInbound(); + + test.SetupRateLimitStore().GetCount("conditional").Should().Be(0); + } + + [TestMethod] + public void RateLimitByKey_TerminatesSectionOnExceeded() + { + var test = new RateLimitByKeyThenSetHeader().AsTestDocument(); + test.SetupRateLimitStore().SetCount("block-me", 1); + var headerExecuted = false; + test.SetupInbound().SetHeader().WithCallback((_, _, _) => headerExecuted = true); + + test.RunInbound(); + + headerExecuted.Should().BeFalse(); + test.Context.Response.StatusCode.Should().Be(429); + } + + [TestMethod] + public void RateLimitByKey_CounterNotIncrementedOnExceeded() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + test.SetupRateLimitStore().SetCount("client-ip", 3); + + test.RunInbound(); + + test.Context.Response.StatusCode.Should().Be(429); + test.SetupRateLimitStore().GetCount("client-ip").Should().Be(3); + } + + [TestMethod] + public void RateLimitByKey_Callback() + { + var test = new SimpleRateLimitByKey().AsTestDocument(); + var callbackExecuted = false; + test.SetupInbound().RateLimitByKey().WithCallback((_, config) => + { + callbackExecuted = true; + config.CounterKey.Should().Be("client-ip"); + }); + + test.RunInbound(); + + callbackExecuted.Should().BeTrue(); + } +}