From ca38f711152c6b5b88693bd8f1a47629d823fd6e Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Sat, 14 Feb 2026 01:31:01 -0600 Subject: [PATCH 1/2] Update DNS redirector --- tavern/internal/redirectors/dns/dns.go | 182 +++++-- tavern/internal/redirectors/dns/dns_test.go | 543 +++++++++++++++++++- 2 files changed, 678 insertions(+), 47 deletions(-) diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 5b704c683..b55b1f40a 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -26,12 +26,10 @@ import ( ) const ( - convTimeout = 15 * time.Minute defaultUDPPort = "53" // DNS protocol constants dnsHeaderSize = 12 - maxLabelLength = 63 txtRecordType = 16 aRecordType = 1 aaaaRecordType = 28 @@ -54,10 +52,13 @@ const ( MaxActiveConversations = 200000 NormalConversationTimeout = 15 * time.Minute ReducedConversationTimeout = 5 * time.Minute + ServedConversationTimeout = 2 * time.Minute CapacityRecoveryThreshold = 0.5 // 50% MaxAckRangesInResponse = 20 MaxNacksInResponse = 50 MaxDataSize = 50 * 1024 * 1024 // 50MB max data size + + MaxConcurrentHandlers = 256 // Max concurrent handler goroutines ) func init() { @@ -70,6 +71,7 @@ type Redirector struct { baseDomains []string conversationCount int32 conversationTimeout time.Duration + activeHandlers int32 } // Conversation tracks state for a request-response exchange @@ -86,6 +88,8 @@ type Conversation struct { ResponseChunks [][]byte // Split response for multi-fetch ResponseCRC uint32 Completed bool // Set to true when all chunks received + Failed bool // Set when processing fails + ResponseServed bool // Set after first successful FETCH } func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn, _ *tls.Config) error { @@ -116,6 +120,8 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr go r.cleanupConversations(ctx) + sem := make(chan struct{}, MaxConcurrentHandlers) + buf := make([]byte, 4096) for { select { @@ -133,10 +139,21 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr continue } - // Process query synchronously + // Copy query data before passing to goroutine queryCopy := make([]byte, n) copy(queryCopy, buf[:n]) - r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream) + + // Acquire sem slot + sem <- struct{}{} + + go func(query []byte, remoteAddr *net.UDPAddr) { + defer func() { + atomic.AddInt32(&r.activeHandlers, -1) + <-sem + }() + atomic.AddInt32(&r.activeHandlers, 1) + r.handleDNSQuery(ctx, conn, remoteAddr, query, upstream) + }(queryCopy, addr) } } } @@ -197,16 +214,44 @@ func (r *Redirector) cleanupConversations(ctx context.Context) { r.conversationTimeout = NormalConversationTimeout } + var cleaned int32 r.conversations.Range(func(key, value interface{}) bool { conv := value.(*Conversation) conv.mu.Lock() - if now.Sub(conv.LastActivity) > r.conversationTimeout { + + // Use shorter timeout for conversations that already served their response or failed. 
+ timeout := r.conversationTimeout + if conv.ResponseServed || conv.Failed { + timeout = ServedConversationTimeout + } + + if now.Sub(conv.LastActivity) > timeout { r.conversations.Delete(key) atomic.AddInt32(&r.conversationCount, -1) + cleaned++ } conv.mu.Unlock() return true }) + + activeConvs := atomic.LoadInt32(&r.conversationCount) + activeHdlrs := atomic.LoadInt32(&r.activeHandlers) + + // Warn if handler pool is near capacity + if activeHdlrs >= int32(MaxConcurrentHandlers*4/5) { + slog.Warn("handler pool near capacity", + "active_handlers", activeHdlrs, + "max_handlers", MaxConcurrentHandlers, + "active_conversations", activeConvs) + } + + // Log stats to have better visibility to current load on the redirector + slog.Info("dns redirector stats", + "active_conversations", activeConvs, + "active_handlers", activeHdlrs, + "max_handlers", MaxConcurrentHandlers, + "cleaned", cleaned, + "timeout", r.conversationTimeout) } } } @@ -227,7 +272,6 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr domain = strings.ToLower(domain) - slog.Info("dns redirector: request", "source", addr.String(), "destination", domain) slog.Debug("dns redirector: query details", "domain", domain, "query_type", queryType, "source", addr.String()) // Extract subdomain @@ -301,7 +345,12 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr } if err != nil { - slog.Error("dns redirector: upstream request failed", "type", packet.Type, "conv_id", packet.ConversationId, "source", addr.String(), "error", err) + if strings.Contains(err.Error(), "conversation not found") { + slog.Debug("packet for unknown conversation", + "type", packet.Type, "conv_id", packet.ConversationId) + } else { + slog.Error("dns redirector: upstream request failed", "type", packet.Type, "conv_id", packet.ConversationId, "source", addr.String(), "error", err) + } r.sendErrorResponse(conn, addr, transactionID) return } @@ -389,9 +438,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { slog.Warn("INIT packet missing file_size field", "conv_id", packet.ConversationId, "total_chunks", initPayload.TotalChunks) } - slog.Debug("creating conversation", "conv_id", packet.ConversationId, "method", initPayload.MethodCode, - "total_chunks", initPayload.TotalChunks, "file_size", initPayload.FileSize, "crc32", initPayload.DataCrc32) - conv := &Conversation{ ID: packet.ConversationId, MethodPath: initPayload.MethodCode, @@ -403,7 +449,34 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { Completed: false, } - r.conversations.Store(packet.ConversationId, conv) + // Use LoadOrStore to handle duplicate INITs from DNS resolvers idempotently. + // DNS recursive resolvers may forward the same query from multiple nodes, + // causing duplicate INIT packets. Thanks AWS. 
+ actual, loaded := r.conversations.LoadOrStore(packet.ConversationId, conv) + if loaded { + // Conversation already exists + atomic.AddInt32(&r.conversationCount, -1) + + existingConv := actual.(*Conversation) + existingConv.mu.Lock() + defer existingConv.mu.Unlock() + existingConv.LastActivity = time.Now() + + slog.Debug("duplicate INIT for existing conversation", "conv_id", packet.ConversationId) + + acks, nacks := r.computeAcksNacks(existingConv) + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: acks, + Nacks: nacks, + } + statusData, err := proto.Marshal(statusPacket) + if err != nil { + return nil, fmt.Errorf("failed to marshal duplicate init status: %w", err) + } + return statusData, nil + } slog.Debug("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize) @@ -427,8 +500,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) if !ok { - slog.Debug("DATA packet for unknown conversation (INIT may be lost/delayed)", - "conv_id", packet.ConversationId, "seq", packet.Sequence) return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } @@ -436,6 +507,22 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client conv.mu.Lock() defer conv.mu.Unlock() + // Once the conversation has been forwarded to upstream, return full ack range immediately. + if conv.Completed || conv.Failed { + conv.LastActivity = time.Now() + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: []*dnspb.AckRange{{StartSeq: 1, EndSeq: conv.TotalChunks}}, + Nacks: []uint32{}, + } + statusData, err := proto.Marshal(statusPacket) + if err != nil { + return nil, fmt.Errorf("failed to marshal status packet: %w", err) + } + return statusData, nil + } + if packet.Sequence < 1 || packet.Sequence > conv.TotalChunks { return nil, fmt.Errorf("sequence out of bounds: %d (expected 1-%d)", packet.Sequence, conv.TotalChunks) } @@ -445,7 +532,7 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence, "size", len(packet.Data), "total", len(conv.Chunks)) - if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed { + if uint32(len(conv.Chunks)) == conv.TotalChunks { conv.Completed = true slog.Debug("C2 request complete, forwarding to upstream", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", conv.ExpectedDataSize) @@ -481,6 +568,9 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream // Reassemble data var fullData []byte + if conv.ExpectedDataSize > 0 { + fullData = make([]byte, 0, conv.ExpectedDataSize) + } for i := uint32(1); i <= conv.TotalChunks; i++ { chunk, ok := conv.Chunks[i] if !ok { @@ -489,25 +579,29 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream fullData = append(fullData, chunk...) 
} + for seq := range conv.Chunks { + conv.Chunks[seq] = nil + } + actualCRC := crc32.ChecksumIEEE(fullData) if actualCRC != conv.ExpectedCRC { - r.conversations.Delete(conv.ID) - atomic.AddInt32(&r.conversationCount, -1) + conv.Failed = true + conv.ResponseData = []byte{} return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) } slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize { - r.conversations.Delete(conv.ID) - atomic.AddInt32(&r.conversationCount, -1) + conv.Failed = true + conv.ResponseData = []byte{} return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData)) } responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) if err != nil { - r.conversations.Delete(conv.ID) - atomic.AddInt32(&r.conversationCount, -1) + conv.Failed = true + conv.ResponseData = []byte{} return fmt.Errorf("failed to forward to upstream: %w", err) } @@ -616,8 +710,19 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) conv.mu.Lock() defer conv.mu.Unlock() + // If conversation failed, return empty response + if conv.Failed { + slog.Debug("returning empty response for failed conversation", "conv_id", conv.ID) + conv.ResponseServed = true + conv.LastActivity = time.Now() + return []byte{}, nil + } + if conv.ResponseData == nil { - return nil, fmt.Errorf("no response data available") + // Response not ready yet, the upstream gRPC call is still in progress. + slog.Debug("response not ready yet - upstream call in progress", "conv_id", conv.ID) + conv.LastActivity = time.Now() + return []byte{}, nil } conv.LastActivity = time.Now() @@ -654,32 +759,21 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) slog.Debug("returning response chunk", "conv_id", conv.ID, "chunk", fetchPayload.ChunkIndex, "size", len(conv.ResponseChunks[chunkIndex]), "total_chunks", len(conv.ResponseChunks)) + conv.ResponseServed = true return conv.ResponseChunks[chunkIndex], nil } slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) + conv.ResponseServed = true return conv.ResponseData, nil } // handleCompletePacket processes COMPLETE packet and cleans up conversation func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, error) { - val, ok := r.conversations.Load(packet.ConversationId) - if !ok { - return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) - } + val, loaded := r.conversations.LoadAndDelete(packet.ConversationId) - conv := val.(*Conversation) - conv.mu.Lock() - defer conv.mu.Unlock() - - slog.Debug("C2 conversation completed and confirmed by client", "conv_id", conv.ID, "method", conv.MethodPath) - - // Delete conversation and decrement counter - r.conversations.Delete(packet.ConversationId) - atomic.AddInt32(&r.conversationCount, -1) - - // Return empty success status + // Build success status statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, ConversationId: packet.ConversationId, @@ -690,6 +784,22 @@ func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, erro if err != nil { return nil, fmt.Errorf("failed to marshal complete status: %w", err) } + + if !loaded { + // Conversation already cleaned up by a prior COMPLETE - return success + slog.Debug("duplicate COMPLETE for already-completed conversation", 
"conv_id", packet.ConversationId) + return statusData, nil + } + + conv := val.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + slog.Debug("C2 conversation completed and confirmed by client", "conv_id", conv.ID, "method", conv.MethodPath) + + // Decrement counter + atomic.AddInt32(&r.conversationCount, -1) + return statusData, nil } diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index 3bec7bc85..ced1b31ca 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -7,6 +7,7 @@ import ( "net" "sort" "sync" + "sync/atomic" "testing" "time" @@ -414,6 +415,97 @@ func TestHandleInitPacket(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "max active conversations") }) + + t.Run("duplicate INIT returns status without counter leak", func(t *testing.T) { + r := &Redirector{} + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 3, + DataCrc32: 0xDEADBEEF, + FileSize: 512, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "dupinit1234", + Data: payloadBytes, + } + + // First INIT creates conversation + resp1, err := r.handleInitPacket(packet) + require.NoError(t, err) + require.NotNil(t, resp1) + assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) + + // Verify first response is STATUS + var status1 dnspb.DNSPacket + err = proto.Unmarshal(resp1, &status1) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status1.Type) + + // Simulate duplicate INIT from DNS resolver + resp2, err := r.handleInitPacket(packet) + require.NoError(t, err) + require.NotNil(t, resp2) + + // Counter should NOT increment (no leak) + assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount), "duplicate INIT should not increment counter") + + // Verify duplicate response is also STATUS + var status2 dnspb.DNSPacket + err = proto.Unmarshal(resp2, &status2) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status2.Type) + assert.Equal(t, "dupinit1234", status2.ConversationId) + + // Conversation should still exist and be unchanged + val, ok := r.conversations.Load("dupinit1234") + require.True(t, ok) + conv := val.(*Conversation) + assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) + assert.Equal(t, uint32(3), conv.TotalChunks) + }) + + t.Run("concurrent duplicate INITs from resolvers", func(t *testing.T) { + r := &Redirector{} + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 5, + DataCrc32: 0x12345678, + FileSize: 1024, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "concurrent-init", + Data: payloadBytes, + } + + // Simulate 10 concurrent INITs from different resolver nodes + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := r.handleInitPacket(packet) + assert.NoError(t, err) + }() + } + wg.Wait() + + // Counter should be exactly 1 (no leaks) + assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount), "concurrent INITs should not cause counter leak") + + // Conversation should exist + _, ok := r.conversations.Load("concurrent-init") + assert.True(t, ok) + }) } // TestHandleFetchPacket tests FETCH packet processing @@ -508,12 
+600,62 @@ func TestHandleFetchPacket(t *testing.T) { assert.Contains(t, err.Error(), "conversation not found") }) - t.Run("fetch with no response ready", func(t *testing.T) { + t.Run("fetch on failed conversation returns empty response", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "failconv", + ResponseData: []byte{}, + Failed: true, + LastActivity: time.Now(), + } + r.conversations.Store("failconv", conv) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "failconv", + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + assert.Equal(t, []byte{}, data) + }) + + t.Run("fetch on failed conversation does not spam errors", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "failconv2", + ResponseData: []byte{}, + Failed: true, + LastActivity: time.Now(), + } + r.conversations.Store("failconv2", conv) + + // Multiple FETCH requests should all succeed (no error) instead of + // returning "conversation not found" after deletion + for i := 0; i < 10; i++ { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "failconv2", + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err, "FETCH attempt %d should not error", i) + assert.Equal(t, []byte{}, data) + } + + // Conversation should still exist in the map (not deleted) + _, ok := r.conversations.Load("failconv2") + assert.True(t, ok, "failed conversation should remain in map for cleanup") + }) + + t.Run("fetch with no response ready returns empty (upstream in progress)", func(t *testing.T) { r := &Redirector{} conv := &Conversation{ ID: "conv1234", - ResponseData: nil, // No response yet + ResponseData: nil, // upstream call still in progress LastActivity: time.Now(), } r.conversations.Store("conv1234", conv) @@ -523,9 +665,10 @@ func TestHandleFetchPacket(t *testing.T) { ConversationId: "conv1234", } - _, err := r.handleFetchPacket(packet) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no response data") + // Should return empty response (not error) to avoid NXDOMAIN + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + assert.Equal(t, []byte{}, data) }) t.Run("fetch chunk out of bounds", func(t *testing.T) { @@ -555,6 +698,146 @@ func TestHandleFetchPacket(t *testing.T) { }) } +// TestHandleCompletePacket tests COMPLETE packet processing +func TestHandleCompletePacket(t *testing.T) { + t.Run("successful complete cleans up conversation", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "complete1234", + MethodPath: "/c2.C2/ClaimTasks", + ResponseData: []byte("response"), + LastActivity: time.Now(), + } + r.conversations.Store("complete1234", conv) + atomic.StoreInt32(&r.conversationCount, 1) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, + ConversationId: "complete1234", + } + + responseData, err := r.handleCompletePacket(packet) + require.NoError(t, err) + require.NotNil(t, responseData) + + // Verify response is STATUS + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(responseData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + assert.Equal(t, "complete1234", statusPacket.ConversationId) + + // Verify conversation was removed + _, ok := r.conversations.Load("complete1234") + assert.False(t, ok, "conversation should be removed after COMPLETE") + + // Verify counter decremented + assert.Equal(t, int32(0), 
atomic.LoadInt32(&r.conversationCount)) + }) + + t.Run("duplicate COMPLETE returns success idempotently", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "dupcomp1234", + MethodPath: "/c2.C2/ClaimTasks", + ResponseData: []byte("response"), + LastActivity: time.Now(), + } + r.conversations.Store("dupcomp1234", conv) + atomic.StoreInt32(&r.conversationCount, 1) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, + ConversationId: "dupcomp1234", + } + + // First COMPLETE removes conversation + resp1, err := r.handleCompletePacket(packet) + require.NoError(t, err) + require.NotNil(t, resp1) + assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount)) + + // Verify conversation removed + _, ok := r.conversations.Load("dupcomp1234") + assert.False(t, ok) + + // Second COMPLETE (duplicate from resolver) should succeed, not error + resp2, err := r.handleCompletePacket(packet) + require.NoError(t, err, "duplicate COMPLETE should not error") + require.NotNil(t, resp2) + + // Counter should still be 0 (no double-decrement) + assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount), "duplicate COMPLETE should not double-decrement") + + // Verify response is also STATUS + var status dnspb.DNSPacket + err = proto.Unmarshal(resp2, &status) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status.Type) + }) + + t.Run("COMPLETE for never-existed conversation returns success", func(t *testing.T) { + r := &Redirector{} + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, + ConversationId: "nonexistent", + } + + // Should succeed (not error) for idempotency + responseData, err := r.handleCompletePacket(packet) + require.NoError(t, err) + require.NotNil(t, responseData) + + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(responseData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + + // Counter should remain unchanged + assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount)) + }) + + t.Run("concurrent COMPLETEs from resolvers", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "concurrent-complete", + MethodPath: "/c2.C2/ClaimTasks", + ResponseData: []byte("response"), + LastActivity: time.Now(), + } + r.conversations.Store("concurrent-complete", conv) + atomic.StoreInt32(&r.conversationCount, 1) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, + ConversationId: "concurrent-complete", + } + + // Simulate 10 concurrent COMPLETEs from different resolver nodes + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := r.handleCompletePacket(packet) + assert.NoError(t, err) + }() + } + wg.Wait() + + // Counter should be exactly 0 (no negative values from double-decrement) + assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount), "concurrent COMPLETEs should not cause counter underflow") + + // Conversation should be removed + _, ok := r.conversations.Load("concurrent-complete") + assert.False(t, ok) + }) +} + // TestParseDomainNameAndType tests DNS query parsing func TestParseDomainNameAndType(t *testing.T) { r := &Redirector{} @@ -655,7 +938,7 @@ func TestConversationCleanup(t *testing.T) { LastActivity: time.Now().Add(-20 * time.Minute), } r.conversations.Store("stale", staleConv) - r.conversationCount = 1 + atomic.StoreInt32(&r.conversationCount, 1) // Create fresh conversation freshConv 
:= &Conversation{ @@ -663,16 +946,20 @@ func TestConversationCleanup(t *testing.T) { LastActivity: time.Now(), } r.conversations.Store("fresh", freshConv) - r.conversationCount = 2 + atomic.StoreInt32(&r.conversationCount, 2) - // Run cleanup + // Run cleanup (mirrors cleanupConversations logic) now := time.Now() r.conversations.Range(func(key, value any) bool { conv := value.(*Conversation) conv.mu.Lock() - if now.Sub(conv.LastActivity) > r.conversationTimeout { + timeout := r.conversationTimeout + if conv.ResponseServed || conv.Failed { + timeout = ServedConversationTimeout + } + if now.Sub(conv.LastActivity) > timeout { r.conversations.Delete(key) - r.conversationCount-- + atomic.AddInt32(&r.conversationCount, -1) } conv.mu.Unlock() return true @@ -686,7 +973,70 @@ func TestConversationCleanup(t *testing.T) { _, ok = r.conversations.Load("fresh") assert.True(t, ok, "fresh conversation should remain") - assert.Equal(t, int32(1), r.conversationCount) + assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) +} + +// TestServedConversationCleanup tests that conversations with ResponseServed +// or Failed flags are cleaned up with the shorter ServedConversationTimeout +func TestServedConversationCleanup(t *testing.T) { + r := &Redirector{ + conversationTimeout: 15 * time.Minute, + } + + // Served conversation: response was delivered 3 minutes ago (> 2 min served timeout) + servedConv := &Conversation{ + ID: "served", + ResponseServed: true, + LastActivity: time.Now().Add(-3 * time.Minute), + } + r.conversations.Store("served", servedConv) + + // Failed conversation: failed 3 minutes ago (> 2 min served timeout) + failedConv := &Conversation{ + ID: "failed", + Failed: true, + LastActivity: time.Now().Add(-3 * time.Minute), + } + r.conversations.Store("failed", failedConv) + + // In-progress conversation: 3 minutes old but not served/failed (< 15 min normal timeout) + activeConv := &Conversation{ + ID: "active", + LastActivity: time.Now().Add(-3 * time.Minute), + } + r.conversations.Store("active", activeConv) + + atomic.StoreInt32(&r.conversationCount, 3) + + // Run cleanup + now := time.Now() + r.conversations.Range(func(key, value any) bool { + conv := value.(*Conversation) + conv.mu.Lock() + timeout := r.conversationTimeout + if conv.ResponseServed || conv.Failed { + timeout = ServedConversationTimeout + } + if now.Sub(conv.LastActivity) > timeout { + r.conversations.Delete(key) + atomic.AddInt32(&r.conversationCount, -1) + } + conv.mu.Unlock() + return true + }) + + // Served and failed conversations should be cleaned (3 min > 2 min served timeout) + _, ok := r.conversations.Load("served") + assert.False(t, ok, "served conversation should be cleaned up with shorter timeout") + + _, ok = r.conversations.Load("failed") + assert.False(t, ok, "failed conversation should be cleaned up with shorter timeout") + + // Active conversation should remain (3 min < 15 min normal timeout) + _, ok = r.conversations.Load("active") + assert.True(t, ok, "in-progress conversation should remain") + + assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) } // TestConcurrentConversationAccess tests thread safety of conversation handling @@ -928,6 +1278,81 @@ func TestHandleDataPacket(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "sequence out of bounds") }) + + t.Run("short-circuit for completed conversation", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + // Create a conversation that is already completed + conv := &Conversation{ + ID: 
"completed1", + TotalChunks: 3, + Completed: true, + Chunks: map[uint32][]byte{ + 1: {0x01}, + 2: {0x02}, + 3: {0x03}, + }, + LastActivity: time.Now(), + } + r.conversations.Store("completed1", conv) + + // Send a duplicate DATA to completed conversation + dataPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "completed1", + Sequence: 1, + Data: []byte{0xFF}, // Different data + } + + statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + require.NoError(t, err) + + // Should get full ack range without recomputation + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(statusData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + require.Len(t, statusPacket.Acks, 1) + assert.Equal(t, uint32(1), statusPacket.Acks[0].StartSeq) + assert.Equal(t, uint32(3), statusPacket.Acks[0].EndSeq) + assert.Empty(t, statusPacket.Nacks) + + // Original chunk data should NOT be overwritten + assert.Equal(t, []byte{0x01}, conv.Chunks[1]) + }) + + t.Run("short-circuit for failed conversation", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + conv := &Conversation{ + ID: "failed1", + TotalChunks: 2, + Failed: true, + Chunks: map[uint32][]byte{1: nil, 2: nil}, + LastActivity: time.Now(), + } + r.conversations.Store("failed1", conv) + + dataPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "failed1", + Sequence: 1, + Data: []byte{0x01}, + } + + statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + require.NoError(t, err) + + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(statusData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + require.Len(t, statusPacket.Acks, 1) + assert.Equal(t, uint32(1), statusPacket.Acks[0].StartSeq) + assert.Equal(t, uint32(2), statusPacket.Acks[0].EndSeq) + }) } // TestProcessCompletedConversation tests data reassembly and CRC validation @@ -991,3 +1416,99 @@ func TestProcessCompletedConversation(t *testing.T) { assert.NotEqual(t, wrongCRC, actualCRC, "CRC should mismatch") }) } + +// TestConversationNotFoundError verifies the error message for missing conversations +func TestConversationNotFoundError(t *testing.T) { + r := &Redirector{} + + t.Run("DATA returns conversation not found error", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "missing123", + Sequence: 1, + Data: []byte{0x01}, + } + + _, err := r.handleDataPacket(context.Background(), nil, packet, txtRecordType) + require.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") + assert.Contains(t, err.Error(), "missing123") + }) + + t.Run("FETCH returns conversation not found error", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "missing456", + } + + _, err := r.handleFetchPacket(packet) + require.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") + assert.Contains(t, err.Error(), "missing456") + }) +} + +// TestActiveHandlersCounter verifies atomic counter operations work correctly +func TestActiveHandlersCounter(t *testing.T) { + t.Run("starts at zero", func(t *testing.T) { + r := &Redirector{} + assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers)) + }) + + t.Run("concurrent increment and decrement", func(t *testing.T) { + r := &Redirector{} + + // 
Simulate concurrent handler goroutines incrementing and decrementing + var wg sync.WaitGroup + iterations := 100 + + for i := 0; i < iterations; i++ { + wg.Add(1) + go func() { + defer wg.Done() + // Simulate handler lifecycle: increment at start, decrement at end + atomic.AddInt32(&r.activeHandlers, 1) + time.Sleep(time.Microsecond) // Small delay to increase contention + atomic.AddInt32(&r.activeHandlers, -1) + }() + } + + wg.Wait() + + // After all handlers complete, counter should be back to zero + assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers), "counter should return to zero after all handlers complete") + }) + + t.Run("peak tracking under load", func(t *testing.T) { + r := &Redirector{} + + var peak int32 + var peakMu sync.Mutex + var wg sync.WaitGroup + + // Start handlers that overlap in time + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + current := atomic.AddInt32(&r.activeHandlers, 1) + + peakMu.Lock() + if current > peak { + peak = current + } + peakMu.Unlock() + + time.Sleep(time.Millisecond) + atomic.AddInt32(&r.activeHandlers, -1) + }() + } + + wg.Wait() + + // Peak should be > 1 (some concurrency achieved) + assert.Greater(t, peak, int32(1), "peak should show concurrent handlers") + // Final value should be zero + assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers)) + }) +} From 7e9122d919c770cf36cbf47a1c876e36f25c39ae Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Mon, 16 Feb 2026 22:31:15 -0600 Subject: [PATCH 2/2] PR review fixes, I do not like sync code. --- tavern/internal/redirectors/dns/dns.go | 177 +++----- tavern/internal/redirectors/dns/dns_test.go | 444 +++++++++----------- 2 files changed, 267 insertions(+), 354 deletions(-) diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index b55b1f40a..26080d42e 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -15,9 +15,9 @@ import ( "sort" "strings" "sync" - "sync/atomic" "time" + lru "github.com/hashicorp/golang-lru/v2" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" @@ -49,29 +49,30 @@ const ( benignARecordIP = "0.0.0.0" // Async protocol configuration - MaxActiveConversations = 200000 - NormalConversationTimeout = 15 * time.Minute - ReducedConversationTimeout = 5 * time.Minute - ServedConversationTimeout = 2 * time.Minute - CapacityRecoveryThreshold = 0.5 // 50% - MaxAckRangesInResponse = 20 - MaxNacksInResponse = 50 - MaxDataSize = 50 * 1024 * 1024 // 50MB max data size - - MaxConcurrentHandlers = 256 // Max concurrent handler goroutines + MaxActiveConversations = 200000 + NormalConversationTimeout = 15 * time.Minute + ServedConversationTimeout = 2 * time.Minute + MaxAckRangesInResponse = 20 + MaxNacksInResponse = 50 + MaxDataSize = 50 * 1024 * 1024 // 50MB max data size ) func init() { - redirectors.Register("dns", &Redirector{}) + cache, err := lru.New[string, *Conversation](MaxActiveConversations) + if err != nil { + slog.Error("dns redirector: failed to create conversation cache") + } + redirectors.Register("dns", &Redirector{ + conversations: cache, + }) } // Redirector handles DNS-based C2 communication type Redirector struct { - conversations sync.Map + conversationsMu sync.Mutex + conversations *lru.Cache[string, *Conversation] baseDomains []string - conversationCount int32 conversationTimeout time.Duration - activeHandlers int32 } // Conversation tracks state for a 
request-response exchange @@ -88,7 +89,6 @@ type Conversation struct { ResponseChunks [][]byte // Split response for multi-fetch ResponseCRC uint32 Completed bool // Set to true when all chunks received - Failed bool // Set when processing fails ResponseServed bool // Set after first successful FETCH } @@ -120,8 +120,6 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr go r.cleanupConversations(ctx) - sem := make(chan struct{}, MaxConcurrentHandlers) - buf := make([]byte, 4096) for { select { @@ -143,17 +141,7 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr queryCopy := make([]byte, n) copy(queryCopy, buf[:n]) - // Acquire sem slot - sem <- struct{}{} - - go func(query []byte, remoteAddr *net.UDPAddr) { - defer func() { - atomic.AddInt32(&r.activeHandlers, -1) - <-sem - }() - atomic.AddInt32(&r.activeHandlers, 1) - r.handleDNSQuery(ctx, conn, remoteAddr, query, upstream) - }(queryCopy, addr) + go r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream) } } } @@ -205,51 +193,31 @@ func (r *Redirector) cleanupConversations(ctx context.Context) { return case <-ticker.C: now := time.Now() - count := atomic.LoadInt32(&r.conversationCount) - - // Adjust timeout based on capacity - if count >= MaxActiveConversations { - r.conversationTimeout = ReducedConversationTimeout - } else if float64(count) < float64(MaxActiveConversations)*CapacityRecoveryThreshold { - r.conversationTimeout = NormalConversationTimeout - } + var cleaned int - var cleaned int32 - r.conversations.Range(func(key, value interface{}) bool { - conv := value.(*Conversation) + for _, key := range r.conversations.Keys() { + conv, ok := r.conversations.Peek(key) + if !ok { + continue + } conv.mu.Lock() - // Use shorter timeout for conversations that already served their response or failed. 
- timeout := r.conversationTimeout - if conv.ResponseServed || conv.Failed { - timeout = ServedConversationTimeout + shouldDelete := false + if conv.ResponseServed { + shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout + } else { + shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout } - if now.Sub(conv.LastActivity) > timeout { - r.conversations.Delete(key) - atomic.AddInt32(&r.conversationCount, -1) + if shouldDelete { + r.conversations.Remove(key) cleaned++ } conv.mu.Unlock() - return true - }) - - activeConvs := atomic.LoadInt32(&r.conversationCount) - activeHdlrs := atomic.LoadInt32(&r.activeHandlers) - - // Warn if handler pool is near capacity - if activeHdlrs >= int32(MaxConcurrentHandlers*4/5) { - slog.Warn("handler pool near capacity", - "active_handlers", activeHdlrs, - "max_handlers", MaxConcurrentHandlers, - "active_conversations", activeConvs) } - // Log stats to have better visibility to current load on the redirector slog.Info("dns redirector stats", - "active_conversations", activeConvs, - "active_handlers", activeHdlrs, - "max_handlers", MaxConcurrentHandlers, + "active_conversations", r.conversations.Len(), "cleaned", cleaned, "timeout", r.conversationTimeout) } @@ -412,25 +380,13 @@ func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { // handleInitPacket processes INIT packet func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { - for { - current := atomic.LoadInt32(&r.conversationCount) - if current >= MaxActiveConversations { - return nil, fmt.Errorf("max active conversations reached: %d", current) - } - if atomic.CompareAndSwapInt32(&r.conversationCount, current, current+1) { - break - } - } - var initPayload dnspb.InitPayload if err := proto.Unmarshal(packet.Data, &initPayload); err != nil { - atomic.AddInt32(&r.conversationCount, -1) return nil, fmt.Errorf("failed to unmarshal init payload: %w", err) } // Validate file size from client if initPayload.FileSize > MaxDataSize { - atomic.AddInt32(&r.conversationCount, -1) return nil, fmt.Errorf("data size exceeds maximum: %d > %d bytes", initPayload.FileSize, MaxDataSize) } @@ -449,22 +405,21 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { Completed: false, } - // Use LoadOrStore to handle duplicate INITs from DNS resolvers idempotently. + // Use conversationsMu to atomically check-then-store, handling duplicate INITs + // from DNS recursive resolvers idempotently. // DNS recursive resolvers may forward the same query from multiple nodes, // causing duplicate INIT packets. Thanks AWS. 
- actual, loaded := r.conversations.LoadOrStore(packet.ConversationId, conv) - if loaded { - // Conversation already exists - atomic.AddInt32(&r.conversationCount, -1) + r.conversationsMu.Lock() + if existing, ok := r.conversations.Get(packet.ConversationId); ok { + r.conversationsMu.Unlock() - existingConv := actual.(*Conversation) - existingConv.mu.Lock() - defer existingConv.mu.Unlock() - existingConv.LastActivity = time.Now() + existing.mu.Lock() + defer existing.mu.Unlock() + existing.LastActivity = time.Now() slog.Debug("duplicate INIT for existing conversation", "conv_id", packet.ConversationId) - acks, nacks := r.computeAcksNacks(existingConv) + acks, nacks := r.computeAcksNacks(existing) statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, ConversationId: packet.ConversationId, @@ -478,6 +433,13 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { return statusData, nil } + evicted := r.conversations.Add(packet.ConversationId, conv) + r.conversationsMu.Unlock() + + if evicted { + slog.Debug("LRU evicted oldest conversation to make room", "conv_id", conv.ID) + } + slog.Debug("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize) @@ -489,8 +451,7 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { } statusData, err := proto.Marshal(statusPacket) if err != nil { - atomic.AddInt32(&r.conversationCount, -1) - r.conversations.Delete(packet.ConversationId) + r.conversations.Remove(packet.ConversationId) return nil, fmt.Errorf("failed to marshal init status: %w", err) } return statusData, nil @@ -498,17 +459,16 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { // handleDataPacket processes DATA packet func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { - val, ok := r.conversations.Load(packet.ConversationId) + conv, ok := r.conversations.Get(packet.ConversationId) if !ok { return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } - conv := val.(*Conversation) conv.mu.Lock() defer conv.mu.Unlock() // Once the conversation has been forwarded to upstream, return full ack range immediately. 
- if conv.Completed || conv.Failed { + if conv.Completed { conv.LastActivity = time.Now() statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, @@ -585,23 +545,20 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream actualCRC := crc32.ChecksumIEEE(fullData) if actualCRC != conv.ExpectedCRC { - conv.Failed = true - conv.ResponseData = []byte{} + r.conversations.Remove(conv.ID) return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) } slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize { - conv.Failed = true - conv.ResponseData = []byte{} + r.conversations.Remove(conv.ID) return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData)) } responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) if err != nil { - conv.Failed = true - conv.ResponseData = []byte{} + r.conversations.Remove(conv.ID) return fmt.Errorf("failed to forward to upstream: %w", err) } @@ -701,23 +658,14 @@ func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, [] // handleFetchPacket processes FETCH packet func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) { - val, ok := r.conversations.Load(packet.ConversationId) + conv, ok := r.conversations.Get(packet.ConversationId) if !ok { return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } - conv := val.(*Conversation) conv.mu.Lock() defer conv.mu.Unlock() - // If conversation failed, return empty response - if conv.Failed { - slog.Debug("returning empty response for failed conversation", "conv_id", conv.ID) - conv.ResponseServed = true - conv.LastActivity = time.Now() - return []byte{}, nil - } - if conv.ResponseData == nil { // Response not ready yet, the upstream gRPC call is still in progress. 
slog.Debug("response not ready yet - upstream call in progress", "conv_id", conv.ID) @@ -771,8 +719,6 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) // handleCompletePacket processes COMPLETE packet and cleans up conversation func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, error) { - val, loaded := r.conversations.LoadAndDelete(packet.ConversationId) - // Build success status statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, @@ -785,20 +731,23 @@ func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, erro return nil, fmt.Errorf("failed to marshal complete status: %w", err) } + // Atomically check-then-delete to avoid race with concurrent COMPLETEs + r.conversationsMu.Lock() + conv, loaded := r.conversations.Get(packet.ConversationId) + if loaded { + r.conversations.Remove(packet.ConversationId) + } + r.conversationsMu.Unlock() + if !loaded { // Conversation already cleaned up by a prior COMPLETE - return success slog.Debug("duplicate COMPLETE for already-completed conversation", "conv_id", packet.ConversationId) return statusData, nil } - conv := val.(*Conversation) conv.mu.Lock() - defer conv.mu.Unlock() - slog.Debug("C2 conversation completed and confirmed by client", "conv_id", conv.ID, "method", conv.MethodPath) - - // Decrement counter - atomic.AddInt32(&r.conversationCount, -1) + conv.mu.Unlock() return statusData, nil } diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index ced1b31ca..fbbc652cd 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -3,20 +3,29 @@ package dns import ( "context" "encoding/base32" + "fmt" "hash/crc32" "net" "sort" "sync" - "sync/atomic" "testing" "time" + lru "github.com/hashicorp/golang-lru/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "realm.pub/tavern/internal/c2/dnspb" ) +func newTestRedirector() *Redirector { + cache, _ := lru.New[string, *Conversation](MaxActiveConversations) + return &Redirector{ + conversations: cache, + conversationTimeout: NormalConversationTimeout, + } +} + // TestParseListenAddr tests the ParseListenAddr function func TestParseListenAddr(t *testing.T) { tests := []struct { @@ -81,9 +90,8 @@ func TestParseListenAddr(t *testing.T) { // TestExtractSubdomain tests subdomain extraction from full domain names func TestExtractSubdomain(t *testing.T) { - r := &Redirector{ - baseDomains: []string{"dnsc2.realm.pub", "foo.bar.com"}, - } + r := newTestRedirector() + r.baseDomains = []string{"dnsc2.realm.pub", "foo.bar.com"} tests := []struct { name string @@ -140,7 +148,7 @@ func TestExtractSubdomain(t *testing.T) { // TestDecodePacket tests Base32 decoding and protobuf unmarshaling func TestDecodePacket(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() t.Run("valid INIT packet", func(t *testing.T) { packet := &dnspb.DNSPacket{ @@ -232,7 +240,7 @@ func TestDecodePacket(t *testing.T) { // TestComputeAcksNacks tests the ACK range and NACK computation func TestComputeAcksNacks(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() tests := []struct { name string @@ -320,7 +328,7 @@ func TestComputeAcksNacks(t *testing.T) { // TestHandleInitPacket tests INIT packet processing func TestHandleInitPacket(t *testing.T) { t.Run("valid init packet", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := 
&dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -349,9 +357,8 @@ func TestHandleInitPacket(t *testing.T) { assert.Equal(t, "conv1234", statusPacket.ConversationId) // Verify conversation was created - val, ok := r.conversations.Load("conv1234") + conv, ok := r.conversations.Get("conv1234") require.True(t, ok) - conv := val.(*Conversation) assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) assert.Equal(t, uint32(5), conv.TotalChunks) assert.Equal(t, uint32(0x12345678), conv.ExpectedCRC) @@ -359,7 +366,7 @@ func TestHandleInitPacket(t *testing.T) { }) t.Run("invalid init payload", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_INIT, @@ -372,7 +379,7 @@ func TestHandleInitPacket(t *testing.T) { }) t.Run("data size exceeds maximum", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -393,9 +400,12 @@ func TestHandleInitPacket(t *testing.T) { assert.Contains(t, err.Error(), "exceeds maximum") }) - t.Run("max conversations reached", func(t *testing.T) { + t.Run("max conversations triggers LRU eviction", func(t *testing.T) { + // Create a small LRU to test eviction + cache, _ := lru.New[string, *Conversation](2) r := &Redirector{ - conversationCount: MaxActiveConversations, + conversations: cache, + conversationTimeout: NormalConversationTimeout, } initPayload := &dnspb.InitPayload{ @@ -405,19 +415,40 @@ func TestHandleInitPacket(t *testing.T) { payloadBytes, err := proto.Marshal(initPayload) require.NoError(t, err) + // Fill the LRU + for i := 0; i < 2; i++ { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: fmt.Sprintf("conv%d", i), + Data: payloadBytes, + } + _, err = r.handleInitPacket(packet) + require.NoError(t, err) + } + + assert.Equal(t, 2, r.conversations.Len()) + + // Third conversation should evict the oldest packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_INIT, - ConversationId: "conv1234", + ConversationId: "conv2", Data: payloadBytes, } - _, err = r.handleInitPacket(packet) - assert.Error(t, err) - assert.Contains(t, err.Error(), "max active conversations") + require.NoError(t, err) + + // LRU should still be at capacity (oldest evicted) + assert.Equal(t, 2, r.conversations.Len()) + // conv0 should have been evicted + _, ok := r.conversations.Get("conv0") + assert.False(t, ok, "oldest conversation should be evicted") + // conv2 (newest) should exist + _, ok = r.conversations.Get("conv2") + assert.True(t, ok) }) - t.Run("duplicate INIT returns status without counter leak", func(t *testing.T) { - r := &Redirector{} + t.Run("duplicate INIT returns status without leaking state", func(t *testing.T) { + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -438,7 +469,7 @@ func TestHandleInitPacket(t *testing.T) { resp1, err := r.handleInitPacket(packet) require.NoError(t, err) require.NotNil(t, resp1) - assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) + assert.Equal(t, 1, r.conversations.Len()) // Verify first response is STATUS var status1 dnspb.DNSPacket @@ -452,7 +483,7 @@ func TestHandleInitPacket(t *testing.T) { require.NotNil(t, resp2) // Counter should NOT increment (no leak) - assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount), "duplicate INIT should not increment counter") + assert.Equal(t, 1, r.conversations.Len(), "duplicate INIT should not create new conversation") 
// Verify duplicate response is also STATUS var status2 dnspb.DNSPacket @@ -462,15 +493,14 @@ func TestHandleInitPacket(t *testing.T) { assert.Equal(t, "dupinit1234", status2.ConversationId) // Conversation should still exist and be unchanged - val, ok := r.conversations.Load("dupinit1234") + conv, ok := r.conversations.Get("dupinit1234") require.True(t, ok) - conv := val.(*Conversation) assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) assert.Equal(t, uint32(3), conv.TotalChunks) }) t.Run("concurrent duplicate INITs from resolvers", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -499,11 +529,11 @@ func TestHandleInitPacket(t *testing.T) { } wg.Wait() - // Counter should be exactly 1 (no leaks) - assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount), "concurrent INITs should not cause counter leak") + // Should have exactly 1 conversation (no duplicates) + assert.Equal(t, 1, r.conversations.Len(), "concurrent INITs should not create duplicates") // Conversation should exist - _, ok := r.conversations.Load("concurrent-init") + _, ok := r.conversations.Get("concurrent-init") assert.True(t, ok) }) } @@ -511,7 +541,7 @@ func TestHandleInitPacket(t *testing.T) { // TestHandleFetchPacket tests FETCH packet processing func TestHandleFetchPacket(t *testing.T) { t.Run("fetch single response", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() responseData := []byte("test response data") conv := &Conversation{ @@ -519,7 +549,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseData: responseData, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, @@ -532,7 +562,7 @@ func TestHandleFetchPacket(t *testing.T) { }) t.Run("fetch chunked response metadata", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() responseData := []byte("full response") responseCRC := crc32.ChecksumIEEE(responseData) @@ -543,7 +573,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseCRC: responseCRC, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, @@ -562,7 +592,7 @@ func TestHandleFetchPacket(t *testing.T) { }) t.Run("fetch specific chunk", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "conv1234", @@ -570,7 +600,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseChunks: [][]byte{[]byte("chunk0"), []byte("chunk1"), []byte("chunk2")}, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) fetchPayload := &dnspb.FetchPayload{ChunkIndex: 2} // 1-indexed payloadBytes, err := proto.Marshal(fetchPayload) @@ -588,7 +618,7 @@ func TestHandleFetchPacket(t *testing.T) { }) t.Run("fetch unknown conversation", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, @@ -600,65 +630,47 @@ func TestHandleFetchPacket(t *testing.T) { assert.Contains(t, err.Error(), "conversation not found") }) - t.Run("fetch on failed conversation returns empty response", func(t *testing.T) { - r := &Redirector{} - - conv := &Conversation{ - ID: "failconv", - ResponseData: []byte{}, - Failed: true, - LastActivity: time.Now(), - } - r.conversations.Store("failconv", 
conv) + t.Run("fetch on failed conversation returns not found", func(t *testing.T) { + r := newTestRedirector() + // Failed conversations are immediately removed from the cache, + // so a FETCH should get "conversation not found" packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, ConversationId: "failconv", } - data, err := r.handleFetchPacket(packet) - require.NoError(t, err) - assert.Equal(t, []byte{}, data) + _, err := r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") }) - t.Run("fetch on failed conversation does not spam errors", func(t *testing.T) { - r := &Redirector{} - - conv := &Conversation{ - ID: "failconv2", - ResponseData: []byte{}, - Failed: true, - LastActivity: time.Now(), - } - r.conversations.Store("failconv2", conv) + t.Run("fetch after failure returns not found consistently", func(t *testing.T) { + r := newTestRedirector() - // Multiple FETCH requests should all succeed (no error) instead of - // returning "conversation not found" after deletion + // Multiple FETCH requests for a removed conversation should all + // consistently return "conversation not found" for i := 0; i < 10; i++ { packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, ConversationId: "failconv2", } - data, err := r.handleFetchPacket(packet) - require.NoError(t, err, "FETCH attempt %d should not error", i) - assert.Equal(t, []byte{}, data) + _, err := r.handleFetchPacket(packet) + assert.Error(t, err, "FETCH attempt %d should error with not found", i) + assert.Contains(t, err.Error(), "conversation not found") } - - // Conversation should still exist in the map (not deleted) - _, ok := r.conversations.Load("failconv2") - assert.True(t, ok, "failed conversation should remain in map for cleanup") }) t.Run("fetch with no response ready returns empty (upstream in progress)", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "conv1234", ResponseData: nil, // upstream call still in progress LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, @@ -672,7 +684,7 @@ func TestHandleFetchPacket(t *testing.T) { }) t.Run("fetch chunk out of bounds", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "conv1234", @@ -680,7 +692,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseChunks: [][]byte{[]byte("chunk0")}, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) fetchPayload := &dnspb.FetchPayload{ChunkIndex: 10} // Out of bounds payloadBytes, err := proto.Marshal(fetchPayload) @@ -701,7 +713,7 @@ func TestHandleFetchPacket(t *testing.T) { // TestHandleCompletePacket tests COMPLETE packet processing func TestHandleCompletePacket(t *testing.T) { t.Run("successful complete cleans up conversation", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "complete1234", @@ -709,8 +721,7 @@ func TestHandleCompletePacket(t *testing.T) { ResponseData: []byte("response"), LastActivity: time.Now(), } - r.conversations.Store("complete1234", conv) - atomic.StoreInt32(&r.conversationCount, 1) + r.conversations.Add("complete1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, @@ -729,15 +740,12 @@ func TestHandleCompletePacket(t *testing.T) { assert.Equal(t, "complete1234", 
statusPacket.ConversationId) // Verify conversation was removed - _, ok := r.conversations.Load("complete1234") + _, ok := r.conversations.Get("complete1234") assert.False(t, ok, "conversation should be removed after COMPLETE") - - // Verify counter decremented - assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount)) }) t.Run("duplicate COMPLETE returns success idempotently", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "dupcomp1234", @@ -745,8 +753,7 @@ func TestHandleCompletePacket(t *testing.T) { ResponseData: []byte("response"), LastActivity: time.Now(), } - r.conversations.Store("dupcomp1234", conv) - atomic.StoreInt32(&r.conversationCount, 1) + r.conversations.Add("dupcomp1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, @@ -757,10 +764,9 @@ func TestHandleCompletePacket(t *testing.T) { resp1, err := r.handleCompletePacket(packet) require.NoError(t, err) require.NotNil(t, resp1) - assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount)) // Verify conversation removed - _, ok := r.conversations.Load("dupcomp1234") + _, ok := r.conversations.Get("dupcomp1234") assert.False(t, ok) // Second COMPLETE (duplicate from resolver) should succeed, not error @@ -768,9 +774,6 @@ func TestHandleCompletePacket(t *testing.T) { require.NoError(t, err, "duplicate COMPLETE should not error") require.NotNil(t, resp2) - // Counter should still be 0 (no double-decrement) - assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount), "duplicate COMPLETE should not double-decrement") - // Verify response is also STATUS var status dnspb.DNSPacket err = proto.Unmarshal(resp2, &status) @@ -779,7 +782,7 @@ func TestHandleCompletePacket(t *testing.T) { }) t.Run("COMPLETE for never-existed conversation returns success", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, @@ -795,13 +798,10 @@ func TestHandleCompletePacket(t *testing.T) { err = proto.Unmarshal(responseData, &statusPacket) require.NoError(t, err) assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) - - // Counter should remain unchanged - assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount)) }) t.Run("concurrent COMPLETEs from resolvers", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() conv := &Conversation{ ID: "concurrent-complete", @@ -809,8 +809,7 @@ func TestHandleCompletePacket(t *testing.T) { ResponseData: []byte("response"), LastActivity: time.Now(), } - r.conversations.Store("concurrent-complete", conv) - atomic.StoreInt32(&r.conversationCount, 1) + r.conversations.Add("concurrent-complete", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_COMPLETE, @@ -829,18 +828,15 @@ func TestHandleCompletePacket(t *testing.T) { } wg.Wait() - // Counter should be exactly 0 (no negative values from double-decrement) - assert.Equal(t, int32(0), atomic.LoadInt32(&r.conversationCount), "concurrent COMPLETEs should not cause counter underflow") - // Conversation should be removed - _, ok := r.conversations.Load("concurrent-complete") + _, ok := r.conversations.Get("concurrent-complete") assert.False(t, ok) }) } // TestParseDomainNameAndType tests DNS query parsing func TestParseDomainNameAndType(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() tests := []struct { name string @@ -928,60 +924,57 @@ func TestParseDomainNameAndType(t *testing.T) { // TestConversationCleanup 
tests cleanup of stale conversations func TestConversationCleanup(t *testing.T) { - r := &Redirector{ - conversationTimeout: 15 * time.Minute, - } + r := newTestRedirector() // Create stale conversation staleConv := &Conversation{ ID: "stale", LastActivity: time.Now().Add(-20 * time.Minute), } - r.conversations.Store("stale", staleConv) - atomic.StoreInt32(&r.conversationCount, 1) + r.conversations.Add("stale", staleConv) // Create fresh conversation freshConv := &Conversation{ ID: "fresh", LastActivity: time.Now(), } - r.conversations.Store("fresh", freshConv) - atomic.StoreInt32(&r.conversationCount, 2) + r.conversations.Add("fresh", freshConv) // Run cleanup (mirrors cleanupConversations logic) now := time.Now() - r.conversations.Range(func(key, value any) bool { - conv := value.(*Conversation) + for _, key := range r.conversations.Keys() { + conv, ok := r.conversations.Peek(key) + if !ok { + continue + } conv.mu.Lock() - timeout := r.conversationTimeout - if conv.ResponseServed || conv.Failed { - timeout = ServedConversationTimeout + shouldDelete := false + if conv.ResponseServed { + shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout + } else { + shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout } - if now.Sub(conv.LastActivity) > timeout { - r.conversations.Delete(key) - atomic.AddInt32(&r.conversationCount, -1) + if shouldDelete { + r.conversations.Remove(key) } conv.mu.Unlock() - return true - }) + } // Verify stale was removed - _, ok := r.conversations.Load("stale") + _, ok := r.conversations.Get("stale") assert.False(t, ok, "stale conversation should be removed") // Verify fresh remains - _, ok = r.conversations.Load("fresh") + _, ok = r.conversations.Get("fresh") assert.True(t, ok, "fresh conversation should remain") - assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) + assert.Equal(t, 1, r.conversations.Len()) } // TestServedConversationCleanup tests that conversations with ResponseServed -// or Failed flags are cleaned up with the shorter ServedConversationTimeout +// flag are cleaned up with the shorter ServedConversationTimeout. 
func TestServedConversationCleanup(t *testing.T) { - r := &Redirector{ - conversationTimeout: 15 * time.Minute, - } + r := newTestRedirector() // Served conversation: response was delivered 3 minutes ago (> 2 min served timeout) servedConv := &Conversation{ @@ -989,59 +982,49 @@ func TestServedConversationCleanup(t *testing.T) { ResponseServed: true, LastActivity: time.Now().Add(-3 * time.Minute), } - r.conversations.Store("served", servedConv) + r.conversations.Add("served", servedConv) - // Failed conversation: failed 3 minutes ago (> 2 min served timeout) - failedConv := &Conversation{ - ID: "failed", - Failed: true, - LastActivity: time.Now().Add(-3 * time.Minute), - } - r.conversations.Store("failed", failedConv) - - // In-progress conversation: 3 minutes old but not served/failed (< 15 min normal timeout) + // In-progress conversation: 3 minutes old but not served (< 15 min normal timeout) activeConv := &Conversation{ ID: "active", LastActivity: time.Now().Add(-3 * time.Minute), } - r.conversations.Store("active", activeConv) - - atomic.StoreInt32(&r.conversationCount, 3) + r.conversations.Add("active", activeConv) // Run cleanup now := time.Now() - r.conversations.Range(func(key, value any) bool { - conv := value.(*Conversation) + for _, key := range r.conversations.Keys() { + conv, ok := r.conversations.Peek(key) + if !ok { + continue + } conv.mu.Lock() - timeout := r.conversationTimeout - if conv.ResponseServed || conv.Failed { - timeout = ServedConversationTimeout + shouldDelete := false + if conv.ResponseServed { + shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout + } else { + shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout } - if now.Sub(conv.LastActivity) > timeout { - r.conversations.Delete(key) - atomic.AddInt32(&r.conversationCount, -1) + if shouldDelete { + r.conversations.Remove(key) } conv.mu.Unlock() - return true - }) + } - // Served and failed conversations should be cleaned (3 min > 2 min served timeout) - _, ok := r.conversations.Load("served") + // Served conversation should be cleaned (3 min > 2 min served timeout) + _, ok := r.conversations.Get("served") assert.False(t, ok, "served conversation should be cleaned up with shorter timeout") - _, ok = r.conversations.Load("failed") - assert.False(t, ok, "failed conversation should be cleaned up with shorter timeout") - // Active conversation should remain (3 min < 15 min normal timeout) - _, ok = r.conversations.Load("active") + _, ok = r.conversations.Get("active") assert.True(t, ok, "in-progress conversation should remain") - assert.Equal(t, int32(1), atomic.LoadInt32(&r.conversationCount)) + assert.Equal(t, 1, r.conversations.Len()) } // TestConcurrentConversationAccess tests thread safety of conversation handling func TestConcurrentConversationAccess(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -1068,11 +1051,10 @@ func TestConcurrentConversationAccess(t *testing.T) { go func(seq uint32) { defer wg.Done() - val, ok := r.conversations.Load("concurrent") + conv, ok := r.conversations.Get("concurrent") if !ok { return } - conv := val.(*Conversation) conv.mu.Lock() conv.Chunks[seq] = []byte{byte(seq)} conv.mu.Unlock() @@ -1081,15 +1063,14 @@ func TestConcurrentConversationAccess(t *testing.T) { wg.Wait() // Verify all chunks stored - val, ok := r.conversations.Load("concurrent") + conv, ok := r.conversations.Get("concurrent") require.True(t, ok) - conv := val.(*Conversation) assert.Len(t, 
conv.Chunks, 100) } // TestBuildDNSResponse tests DNS response packet construction func TestBuildDNSResponse(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() // Create a mock UDP connection for testing serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") @@ -1127,7 +1108,7 @@ func TestBuildDNSResponse(t *testing.T) { // TestHandleDataPacket tests DATA packet processing and chunk storage func TestHandleDataPacket(t *testing.T) { t.Run("store single chunk", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() ctx := context.Background() // Create conversation first with INIT - set TotalChunks > 1 to avoid completion @@ -1167,15 +1148,14 @@ func TestHandleDataPacket(t *testing.T) { assert.Equal(t, "data1234", statusPacket.ConversationId) // Verify chunk was stored - val, ok := r.conversations.Load("data1234") + conv, ok := r.conversations.Get("data1234") require.True(t, ok) - conv := val.(*Conversation) assert.Len(t, conv.Chunks, 1) assert.Equal(t, []byte{0x01}, conv.Chunks[1]) }) t.Run("store multiple chunks with gaps", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() ctx := context.Background() // Create conversation @@ -1219,9 +1199,8 @@ func TestHandleDataPacket(t *testing.T) { } // Verify chunks stored - val, ok := r.conversations.Load("gaps1234") + conv, ok := r.conversations.Get("gaps1234") require.True(t, ok) - conv := val.(*Conversation) assert.Len(t, conv.Chunks, 3) assert.Equal(t, []byte{1}, conv.Chunks[1]) assert.Equal(t, []byte{3}, conv.Chunks[3]) @@ -1230,7 +1209,7 @@ func TestHandleDataPacket(t *testing.T) { }) t.Run("unknown conversation", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() ctx := context.Background() dataPacket := &dnspb.DNSPacket{ @@ -1246,7 +1225,7 @@ func TestHandleDataPacket(t *testing.T) { }) t.Run("sequence out of bounds", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() ctx := context.Background() // Create conversation @@ -1280,7 +1259,7 @@ func TestHandleDataPacket(t *testing.T) { }) t.Run("short-circuit for completed conversation", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() ctx := context.Background() // Create a conversation that is already completed @@ -1295,7 +1274,7 @@ func TestHandleDataPacket(t *testing.T) { }, LastActivity: time.Now(), } - r.conversations.Store("completed1", conv) + r.conversations.Add("completed1", conv) // Send a duplicate DATA to completed conversation dataPacket := &dnspb.DNSPacket{ @@ -1322,19 +1301,12 @@ func TestHandleDataPacket(t *testing.T) { assert.Equal(t, []byte{0x01}, conv.Chunks[1]) }) - t.Run("short-circuit for failed conversation", func(t *testing.T) { - r := &Redirector{} + t.Run("data packet for removed conversation returns not found", func(t *testing.T) { + r := newTestRedirector() ctx := context.Background() - conv := &Conversation{ - ID: "failed1", - TotalChunks: 2, - Failed: true, - Chunks: map[uint32][]byte{1: nil, 2: nil}, - LastActivity: time.Now(), - } - r.conversations.Store("failed1", conv) - + // Failed conversations are immediately removed from the cache, + // so a DATA packet should get "conversation not found" dataPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_DATA, ConversationId: "failed1", @@ -1342,16 +1314,9 @@ func TestHandleDataPacket(t *testing.T) { Data: []byte{0x01}, } - statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) - require.NoError(t, err) - - var statusPacket dnspb.DNSPacket - err = proto.Unmarshal(statusData, 
&statusPacket) - require.NoError(t, err) - assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) - require.Len(t, statusPacket.Acks, 1) - assert.Equal(t, uint32(1), statusPacket.Acks[0].StartSeq) - assert.Equal(t, uint32(2), statusPacket.Acks[0].EndSeq) + _, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") }) } @@ -1419,7 +1384,7 @@ func TestProcessCompletedConversation(t *testing.T) { // TestConversationNotFoundError verifies the error message for missing conversations func TestConversationNotFoundError(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() t.Run("DATA returns conversation not found error", func(t *testing.T) { packet := &dnspb.DNSPacket{ @@ -1448,67 +1413,66 @@ func TestConversationNotFoundError(t *testing.T) { }) } -// TestActiveHandlersCounter verifies atomic counter operations work correctly -func TestActiveHandlersCounter(t *testing.T) { - t.Run("starts at zero", func(t *testing.T) { - r := &Redirector{} - assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers)) - }) - - t.Run("concurrent increment and decrement", func(t *testing.T) { - r := &Redirector{} - - // Simulate concurrent handler goroutines incrementing and decrementing - var wg sync.WaitGroup - iterations := 100 +// TestLRUEvictionBehavior verifies the LRU cache evicts oldest conversations when full +func TestLRUEvictionBehavior(t *testing.T) { + t.Run("evicts oldest when at capacity", func(t *testing.T) { + cache, _ := lru.New[string, *Conversation](3) + r := &Redirector{ + conversations: cache, + conversationTimeout: NormalConversationTimeout, + } - for i := 0; i < iterations; i++ { - wg.Add(1) - go func() { - defer wg.Done() - // Simulate handler lifecycle: increment at start, decrement at end - atomic.AddInt32(&r.activeHandlers, 1) - time.Sleep(time.Microsecond) // Small delay to increase contention - atomic.AddInt32(&r.activeHandlers, -1) - }() + // Add 3 conversations to fill capacity + for i := 0; i < 3; i++ { + r.conversations.Add(fmt.Sprintf("conv%d", i), &Conversation{ + ID: fmt.Sprintf("conv%d", i), + LastActivity: time.Now(), + }) } + assert.Equal(t, 3, r.conversations.Len()) - wg.Wait() + // Adding a 4th should evict the oldest (conv0) + r.conversations.Add("conv3", &Conversation{ + ID: "conv3", + LastActivity: time.Now(), + }) + assert.Equal(t, 3, r.conversations.Len()) - // After all handlers complete, counter should be back to zero - assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers), "counter should return to zero after all handlers complete") - }) + _, ok := r.conversations.Get("conv0") + assert.False(t, ok, "oldest conversation should be evicted") - t.Run("peak tracking under load", func(t *testing.T) { - r := &Redirector{} + _, ok = r.conversations.Get("conv3") + assert.True(t, ok, "newest conversation should exist") + }) - var peak int32 - var peakMu sync.Mutex - var wg sync.WaitGroup + t.Run("Get refreshes recency", func(t *testing.T) { + cache, _ := lru.New[string, *Conversation](3) + r := &Redirector{ + conversations: cache, + conversationTimeout: NormalConversationTimeout, + } - // Start handlers that overlap in time - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - defer wg.Done() - current := atomic.AddInt32(&r.activeHandlers, 1) + // Add 3 conversations + for i := 0; i < 3; i++ { + r.conversations.Add(fmt.Sprintf("conv%d", i), &Conversation{ + ID: fmt.Sprintf("conv%d", i), + LastActivity: time.Now(), + }) + } - peakMu.Lock() - if 
current > peak { - peak = current - } - peakMu.Unlock() + // Access conv0 to refresh its recency + r.conversations.Get("conv0") - time.Sleep(time.Millisecond) - atomic.AddInt32(&r.activeHandlers, -1) - }() - } + // Adding conv3 should evict conv1 (now the oldest) instead of conv0 + r.conversations.Add("conv3", &Conversation{ + ID: "conv3", + LastActivity: time.Now(), + }) - wg.Wait() + _, ok := r.conversations.Get("conv0") + assert.True(t, ok, "conv0 should survive due to recent Get") - // Peak should be > 1 (some concurrency achieved) - assert.Greater(t, peak, int32(1), "peak should show concurrent handlers") - // Final value should be zero - assert.Equal(t, int32(0), atomic.LoadInt32(&r.activeHandlers)) + _, ok = r.conversations.Get("conv1") + assert.False(t, ok, "conv1 should be evicted as the oldest") }) }
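
The eviction and recency semantics that TestLRUEvictionBehavior exercises come from the lru package the tests construct with lru.New[string, *Conversation](3). Assuming that package is hashicorp/golang-lru/v2, which is what the generic constructor and the Add/Get/Peek/Keys/Remove/Len calls match, a minimal standalone illustration of the behavior the assertions depend on is sketched below; the keys and capacity are illustrative only.

package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

func main() {
	// Capacity of 3, mirroring the test setup.
	cache, _ := lru.New[string, int](3)
	cache.Add("a", 1)
	cache.Add("b", 2)
	cache.Add("c", 3)

	cache.Get("a")  // Get refreshes recency: "a" becomes the most recently used entry
	cache.Peek("b") // Peek reads without touching recency

	cache.Add("d", 4) // at capacity, so the least recently used entry ("b") is evicted

	_, okA := cache.Get("a")
	_, okB := cache.Get("b")
	fmt.Println(okA, okB) // true false
}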
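The newTestRedirector helper the subtests call is not visible in this hunk. Based on how the tests use the value it returns, it presumably wires a Redirector to an LRU conversation cache with the normal timeout; the sketch below is a guess at that helper under the same hashicorp/golang-lru/v2 assumption, not the actual definition.

// Hypothetical test helper, reconstructed from how the subtests use it.
func newTestRedirector() *Redirector {
	cache, _ := lru.New[string, *Conversation](MaxActiveConversations)
	return &Redirector{
		conversations:       cache,
		conversationTimeout: NormalConversationTimeout,
	}
}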
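TestConversationCleanup and TestServedConversationCleanup both inline the same sweep over the cache rather than calling the production cleanup goroutine. The loop they mirror walks Keys and reads entries with Peek so that sweeping does not refresh LRU recency, and applies ServedConversationTimeout only to conversations whose response was already served. A sketch of that logic, with a hypothetical function name:

// sweepConversations is an illustrative stand-in for the cleanup pass the
// tests mirror; it is not the production function.
func (r *Redirector) sweepConversations(now time.Time) {
	for _, key := range r.conversations.Keys() {
		conv, ok := r.conversations.Peek(key) // Peek: do not refresh LRU recency
		if !ok {
			continue
		}
		conv.mu.Lock()
		timeout := r.conversationTimeout
		if conv.ResponseServed {
			// Delivered responses can be dropped on the shorter timeout.
			timeout = ServedConversationTimeout
		}
		if now.Sub(conv.LastActivity) > timeout {
			r.conversations.Remove(key)
		}
		conv.mu.Unlock()
	}
}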
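Several subtests now expect FETCH and DATA packets for failed or already-removed conversations to surface a "conversation not found" error rather than an empty success, since failed conversations are dropped from the cache immediately. The lookup pattern those assertions imply is a plain cache miss mapped to that error; the helper below is illustrative, not the handlers' actual structure.

// Illustrative lookup: a cache miss becomes the error the tests assert on.
func (r *Redirector) lookupConversation(id string) (*Conversation, error) {
	conv, ok := r.conversations.Get(id) // refreshes recency for live conversations
	if !ok {
		return nil, fmt.Errorf("conversation not found: %s", id)
	}
	return conv, nil
}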