diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go
index 5b704c683..26080d42e 100644
--- a/tavern/internal/redirectors/dns/dns.go
+++ b/tavern/internal/redirectors/dns/dns.go
@@ -15,9 +15,9 @@ import (
 	"sort"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/protobuf/proto"
@@ -26,12 +26,10 @@ import (
 )
 
 const (
-	convTimeout    = 15 * time.Minute
 	defaultUDPPort = "53"
 
 	// DNS protocol constants
 	dnsHeaderSize  = 12
-	maxLabelLength = 63
 	txtRecordType  = 16
 	aRecordType    = 1
 	aaaaRecordType = 28
@@ -51,24 +49,29 @@ const (
 	benignARecordIP = "0.0.0.0"
 
 	// Async protocol configuration
-	MaxActiveConversations     = 200000
-	NormalConversationTimeout  = 15 * time.Minute
-	ReducedConversationTimeout = 5 * time.Minute
-	CapacityRecoveryThreshold  = 0.5 // 50%
-	MaxAckRangesInResponse     = 20
-	MaxNacksInResponse         = 50
-	MaxDataSize                = 50 * 1024 * 1024 // 50MB max data size
+	MaxActiveConversations    = 200000
+	NormalConversationTimeout = 15 * time.Minute
+	ServedConversationTimeout = 2 * time.Minute
+	MaxAckRangesInResponse    = 20
+	MaxNacksInResponse        = 50
+	MaxDataSize               = 50 * 1024 * 1024 // 50MB max data size
 )
 
 func init() {
-	redirectors.Register("dns", &Redirector{})
+	cache, err := lru.New[string, *Conversation](MaxActiveConversations)
+	if err != nil {
+		slog.Error("dns redirector: failed to create conversation cache", "error", err)
+	}
+	redirectors.Register("dns", &Redirector{
+		conversations: cache,
+	})
 }
 
 // Redirector handles DNS-based C2 communication
 type Redirector struct {
-	conversations       sync.Map
+	conversationsMu     sync.Mutex
+	conversations       *lru.Cache[string, *Conversation]
 	baseDomains         []string
-	conversationCount   int32
 	conversationTimeout time.Duration
 }
@@ -86,6 +89,7 @@ type Conversation struct {
 	ResponseChunks [][]byte // Split response for multi-fetch
 	ResponseCRC    uint32
 	Completed      bool // Set to true when all chunks received
+	ResponseServed bool // Set after first successful FETCH
 }
 
 func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn, _ *tls.Config) error {
@@ -133,10 +137,11 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr
 			continue
 		}
 
-		// Process query synchronously
+		// Copy query data before passing to goroutine
 		queryCopy := make([]byte, n)
 		copy(queryCopy, buf[:n])
-		r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream)
+
+		go r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream)
 	}
 }
@@ -188,25 +193,33 @@ func (r *Redirector) cleanupConversations(ctx context.Context) {
 			return
 		case <-ticker.C:
 			now := time.Now()
-			count := atomic.LoadInt32(&r.conversationCount)
-
-			// Adjust timeout based on capacity
-			if count >= MaxActiveConversations {
-				r.conversationTimeout = ReducedConversationTimeout
-			} else if float64(count) < float64(MaxActiveConversations)*CapacityRecoveryThreshold {
-				r.conversationTimeout = NormalConversationTimeout
-			}
+			var cleaned int
 
-			r.conversations.Range(func(key, value interface{}) bool {
-				conv := value.(*Conversation)
+			for _, key := range r.conversations.Keys() {
+				conv, ok := r.conversations.Peek(key)
+				if !ok {
+					continue
+				}
 				conv.mu.Lock()
-				if now.Sub(conv.LastActivity) > r.conversationTimeout {
-					r.conversations.Delete(key)
-					atomic.AddInt32(&r.conversationCount, -1)
+
+				shouldDelete := false
+				if conv.ResponseServed {
+					shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout
+				} else {
+					shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout
+				}
+
+				if shouldDelete {
+					r.conversations.Remove(key)
+					cleaned++
 				}
 				conv.mu.Unlock()
-				return true
-			})
+			}
+
+			slog.Info("dns redirector stats",
+				"active_conversations", r.conversations.Len(),
+				"cleaned", cleaned,
+				"timeout", r.conversationTimeout)
 		}
 	}
 }
@@ -227,7 +240,6 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr
 
 	domain = strings.ToLower(domain)
 
-	slog.Info("dns redirector: request", "source", addr.String(), "destination", domain)
 	slog.Debug("dns redirector: query details", "domain", domain, "query_type", queryType, "source", addr.String())
 
 	// Extract subdomain
@@ -301,7 +313,12 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr
 	}
 
 	if err != nil {
-		slog.Error("dns redirector: upstream request failed", "type", packet.Type, "conv_id", packet.ConversationId, "source", addr.String(), "error", err)
+		if strings.Contains(err.Error(), "conversation not found") {
+			slog.Debug("packet for unknown conversation",
+				"type", packet.Type, "conv_id", packet.ConversationId)
+		} else {
+			slog.Error("dns redirector: upstream request failed", "type", packet.Type, "conv_id", packet.ConversationId, "source", addr.String(), "error", err)
+		}
 		r.sendErrorResponse(conn, addr, transactionID)
 		return
 	}
@@ -363,25 +380,13 @@ func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) {
 
 // handleInitPacket processes INIT packet
 func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) {
-	for {
-		current := atomic.LoadInt32(&r.conversationCount)
-		if current >= MaxActiveConversations {
-			return nil, fmt.Errorf("max active conversations reached: %d", current)
-		}
-		if atomic.CompareAndSwapInt32(&r.conversationCount, current, current+1) {
-			break
-		}
-	}
-
 	var initPayload dnspb.InitPayload
 	if err := proto.Unmarshal(packet.Data, &initPayload); err != nil {
-		atomic.AddInt32(&r.conversationCount, -1)
 		return nil, fmt.Errorf("failed to unmarshal init payload: %w", err)
 	}
 
 	// Validate file size from client
 	if initPayload.FileSize > MaxDataSize {
-		atomic.AddInt32(&r.conversationCount, -1)
 		return nil, fmt.Errorf("data size exceeds maximum: %d > %d bytes", initPayload.FileSize, MaxDataSize)
 	}
 
@@ -389,9 +394,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) {
 		slog.Warn("INIT packet missing file_size field", "conv_id", packet.ConversationId, "total_chunks", initPayload.TotalChunks)
 	}
 
-	slog.Debug("creating conversation", "conv_id", packet.ConversationId, "method", initPayload.MethodCode,
-		"total_chunks", initPayload.TotalChunks, "file_size", initPayload.FileSize, "crc32", initPayload.DataCrc32)
-
 	conv := &Conversation{
 		ID:         packet.ConversationId,
 		MethodPath: initPayload.MethodCode,
@@ -403,7 +405,40 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) {
 		Completed:  false,
 	}
 
-	r.conversations.Store(packet.ConversationId, conv)
+	// Use conversationsMu to atomically check-then-store so that duplicate
+	// INITs are handled idempotently: DNS recursive resolvers may forward
+	// the same query from multiple nodes. Thanks AWS.
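+	// A duplicate INIT refreshes LastActivity and replies with the current
+	// ACK/NACK status, letting the client resume the transfer in place.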
+	r.conversationsMu.Lock()
+	if existing, ok := r.conversations.Get(packet.ConversationId); ok {
+		r.conversationsMu.Unlock()
+
+		existing.mu.Lock()
+		defer existing.mu.Unlock()
+		existing.LastActivity = time.Now()
+
+		slog.Debug("duplicate INIT for existing conversation", "conv_id", packet.ConversationId)
+
+		acks, nacks := r.computeAcksNacks(existing)
+		statusPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_STATUS,
+			ConversationId: packet.ConversationId,
+			Acks:           acks,
+			Nacks:          nacks,
+		}
+		statusData, err := proto.Marshal(statusPacket)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal duplicate init status: %w", err)
+		}
+		return statusData, nil
+	}
+
+	evicted := r.conversations.Add(packet.ConversationId, conv)
+	r.conversationsMu.Unlock()
+
+	if evicted {
+		slog.Debug("LRU evicted oldest conversation to make room", "conv_id", conv.ID)
+	}
 
 	slog.Debug("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath,
 		"total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize)
@@ -416,8 +451,7 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) {
 	}
 	statusData, err := proto.Marshal(statusPacket)
 	if err != nil {
-		atomic.AddInt32(&r.conversationCount, -1)
-		r.conversations.Delete(packet.ConversationId)
+		r.conversations.Remove(packet.ConversationId)
 		return nil, fmt.Errorf("failed to marshal init status: %w", err)
 	}
 	return statusData, nil
@@ -425,17 +459,30 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) {
 
 // handleDataPacket processes DATA packet
 func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) {
-	val, ok := r.conversations.Load(packet.ConversationId)
+	conv, ok := r.conversations.Get(packet.ConversationId)
 	if !ok {
-		slog.Debug("DATA packet for unknown conversation (INIT may be lost/delayed)",
-			"conv_id", packet.ConversationId, "seq", packet.Sequence)
 		return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId)
 	}
-	conv := val.(*Conversation)
 
 	conv.mu.Lock()
 	defer conv.mu.Unlock()
 
+	// Once the conversation has been forwarded to upstream, return full ack range immediately.
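+	// Resolver retries routinely replay DATA for chunks that are already
+	// reassembled; acking the full range lets the client finish without rework.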
+	if conv.Completed {
+		conv.LastActivity = time.Now()
+		statusPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_STATUS,
+			ConversationId: packet.ConversationId,
+			Acks:           []*dnspb.AckRange{{StartSeq: 1, EndSeq: conv.TotalChunks}},
+			Nacks:          []uint32{},
+		}
+		statusData, err := proto.Marshal(statusPacket)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal status packet: %w", err)
+		}
+		return statusData, nil
+	}
+
 	if packet.Sequence < 1 || packet.Sequence > conv.TotalChunks {
 		return nil, fmt.Errorf("sequence out of bounds: %d (expected 1-%d)", packet.Sequence, conv.TotalChunks)
 	}
@@ -445,7 +492,7 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client
 	slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence,
 		"size", len(packet.Data), "total", len(conv.Chunks))
 
-	if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed {
+	if uint32(len(conv.Chunks)) == conv.TotalChunks {
 		conv.Completed = true
 
 		slog.Debug("C2 request complete, forwarding to upstream", "conv_id", conv.ID,
 			"method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", conv.ExpectedDataSize)
@@ -481,6 +528,9 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream
 
 	// Reassemble data
 	var fullData []byte
+	if conv.ExpectedDataSize > 0 {
+		fullData = make([]byte, 0, conv.ExpectedDataSize)
+	}
 	for i := uint32(1); i <= conv.TotalChunks; i++ {
 		chunk, ok := conv.Chunks[i]
 		if !ok {
@@ -489,25 +539,26 @@
 		fullData = append(fullData, chunk...)
 	}
 
+	for seq := range conv.Chunks {
+		conv.Chunks[seq] = nil
+	}
+
 	actualCRC := crc32.ChecksumIEEE(fullData)
 	if actualCRC != conv.ExpectedCRC {
-		r.conversations.Delete(conv.ID)
-		atomic.AddInt32(&r.conversationCount, -1)
+		r.conversations.Remove(conv.ID)
 		return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC)
 	}
 
 	slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath)
 
 	if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize {
-		r.conversations.Delete(conv.ID)
-		atomic.AddInt32(&r.conversationCount, -1)
+		r.conversations.Remove(conv.ID)
 		return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData))
 	}
 
 	responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData)
 	if err != nil {
-		r.conversations.Delete(conv.ID)
-		atomic.AddInt32(&r.conversationCount, -1)
+		r.conversations.Remove(conv.ID)
 		return fmt.Errorf("failed to forward to upstream: %w", err)
 	}
@@ -607,17 +658,19 @@ func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, []
 
 // handleFetchPacket processes FETCH packet
 func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) {
-	val, ok := r.conversations.Load(packet.ConversationId)
+	conv, ok := r.conversations.Get(packet.ConversationId)
 	if !ok {
 		return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId)
 	}
-	conv := val.(*Conversation)
 
 	conv.mu.Lock()
 	defer conv.mu.Unlock()
 
 	if conv.ResponseData == nil {
-		return nil, fmt.Errorf("no response data available")
+		// Response not ready yet; the upstream gRPC call is still in progress.
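+		// Returning an empty payload rather than an error keeps the client
+		// polling; an error here would surface to the agent as a DNS failure.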
+ slog.Debug("response not ready yet - upstream call in progress", "conv_id", conv.ID) + conv.LastActivity = time.Now() + return []byte{}, nil } conv.LastActivity = time.Now() @@ -654,32 +707,19 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) slog.Debug("returning response chunk", "conv_id", conv.ID, "chunk", fetchPayload.ChunkIndex, "size", len(conv.ResponseChunks[chunkIndex]), "total_chunks", len(conv.ResponseChunks)) + conv.ResponseServed = true return conv.ResponseChunks[chunkIndex], nil } slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) + conv.ResponseServed = true return conv.ResponseData, nil } // handleCompletePacket processes COMPLETE packet and cleans up conversation func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, error) { - val, ok := r.conversations.Load(packet.ConversationId) - if !ok { - return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) - } - - conv := val.(*Conversation) - conv.mu.Lock() - defer conv.mu.Unlock() - - slog.Debug("C2 conversation completed and confirmed by client", "conv_id", conv.ID, "method", conv.MethodPath) - - // Delete conversation and decrement counter - r.conversations.Delete(packet.ConversationId) - atomic.AddInt32(&r.conversationCount, -1) - - // Return empty success status + // Build success status statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, ConversationId: packet.ConversationId, @@ -690,6 +730,25 @@ func (r *Redirector) handleCompletePacket(packet *dnspb.DNSPacket) ([]byte, erro if err != nil { return nil, fmt.Errorf("failed to marshal complete status: %w", err) } + + // Atomically check-then-delete to avoid race with concurrent COMPLETEs + r.conversationsMu.Lock() + conv, loaded := r.conversations.Get(packet.ConversationId) + if loaded { + r.conversations.Remove(packet.ConversationId) + } + r.conversationsMu.Unlock() + + if !loaded { + // Conversation already cleaned up by a prior COMPLETE - return success + slog.Debug("duplicate COMPLETE for already-completed conversation", "conv_id", packet.ConversationId) + return statusData, nil + } + + conv.mu.Lock() + slog.Debug("C2 conversation completed and confirmed by client", "conv_id", conv.ID, "method", conv.MethodPath) + conv.mu.Unlock() + return statusData, nil } diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index 3bec7bc85..fbbc652cd 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "encoding/base32" + "fmt" "hash/crc32" "net" "sort" @@ -10,12 +11,21 @@ import ( "testing" "time" + lru "github.com/hashicorp/golang-lru/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" "realm.pub/tavern/internal/c2/dnspb" ) +func newTestRedirector() *Redirector { + cache, _ := lru.New[string, *Conversation](MaxActiveConversations) + return &Redirector{ + conversations: cache, + conversationTimeout: NormalConversationTimeout, + } +} + // TestParseListenAddr tests the ParseListenAddr function func TestParseListenAddr(t *testing.T) { tests := []struct { @@ -80,9 +90,8 @@ func TestParseListenAddr(t *testing.T) { // TestExtractSubdomain tests subdomain extraction from full domain names func TestExtractSubdomain(t *testing.T) { - r := &Redirector{ - baseDomains: []string{"dnsc2.realm.pub", "foo.bar.com"}, - } + r := newTestRedirector() + 
r.baseDomains = []string{"dnsc2.realm.pub", "foo.bar.com"} tests := []struct { name string @@ -139,7 +148,7 @@ func TestExtractSubdomain(t *testing.T) { // TestDecodePacket tests Base32 decoding and protobuf unmarshaling func TestDecodePacket(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() t.Run("valid INIT packet", func(t *testing.T) { packet := &dnspb.DNSPacket{ @@ -231,7 +240,7 @@ func TestDecodePacket(t *testing.T) { // TestComputeAcksNacks tests the ACK range and NACK computation func TestComputeAcksNacks(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() tests := []struct { name string @@ -319,7 +328,7 @@ func TestComputeAcksNacks(t *testing.T) { // TestHandleInitPacket tests INIT packet processing func TestHandleInitPacket(t *testing.T) { t.Run("valid init packet", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -348,9 +357,8 @@ func TestHandleInitPacket(t *testing.T) { assert.Equal(t, "conv1234", statusPacket.ConversationId) // Verify conversation was created - val, ok := r.conversations.Load("conv1234") + conv, ok := r.conversations.Get("conv1234") require.True(t, ok) - conv := val.(*Conversation) assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) assert.Equal(t, uint32(5), conv.TotalChunks) assert.Equal(t, uint32(0x12345678), conv.ExpectedCRC) @@ -358,7 +366,7 @@ func TestHandleInitPacket(t *testing.T) { }) t.Run("invalid init payload", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_INIT, @@ -371,7 +379,7 @@ func TestHandleInitPacket(t *testing.T) { }) t.Run("data size exceeds maximum", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() initPayload := &dnspb.InitPayload{ MethodCode: "/c2.C2/ClaimTasks", @@ -392,9 +400,12 @@ func TestHandleInitPacket(t *testing.T) { assert.Contains(t, err.Error(), "exceeds maximum") }) - t.Run("max conversations reached", func(t *testing.T) { + t.Run("max conversations triggers LRU eviction", func(t *testing.T) { + // Create a small LRU to test eviction + cache, _ := lru.New[string, *Conversation](2) r := &Redirector{ - conversationCount: MaxActiveConversations, + conversations: cache, + conversationTimeout: NormalConversationTimeout, } initPayload := &dnspb.InitPayload{ @@ -404,22 +415,133 @@ func TestHandleInitPacket(t *testing.T) { payloadBytes, err := proto.Marshal(initPayload) require.NoError(t, err) + // Fill the LRU + for i := 0; i < 2; i++ { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: fmt.Sprintf("conv%d", i), + Data: payloadBytes, + } + _, err = r.handleInitPacket(packet) + require.NoError(t, err) + } + + assert.Equal(t, 2, r.conversations.Len()) + + // Third conversation should evict the oldest packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_INIT, - ConversationId: "conv1234", + ConversationId: "conv2", Data: payloadBytes, } - _, err = r.handleInitPacket(packet) - assert.Error(t, err) - assert.Contains(t, err.Error(), "max active conversations") + require.NoError(t, err) + + // LRU should still be at capacity (oldest evicted) + assert.Equal(t, 2, r.conversations.Len()) + // conv0 should have been evicted + _, ok := r.conversations.Get("conv0") + assert.False(t, ok, "oldest conversation should be evicted") + // conv2 (newest) should exist + _, ok = r.conversations.Get("conv2") + assert.True(t, ok) + }) + + t.Run("duplicate INIT returns status without 
leaking state", func(t *testing.T) { + r := newTestRedirector() + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 3, + DataCrc32: 0xDEADBEEF, + FileSize: 512, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "dupinit1234", + Data: payloadBytes, + } + + // First INIT creates conversation + resp1, err := r.handleInitPacket(packet) + require.NoError(t, err) + require.NotNil(t, resp1) + assert.Equal(t, 1, r.conversations.Len()) + + // Verify first response is STATUS + var status1 dnspb.DNSPacket + err = proto.Unmarshal(resp1, &status1) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status1.Type) + + // Simulate duplicate INIT from DNS resolver + resp2, err := r.handleInitPacket(packet) + require.NoError(t, err) + require.NotNil(t, resp2) + + // Counter should NOT increment (no leak) + assert.Equal(t, 1, r.conversations.Len(), "duplicate INIT should not create new conversation") + + // Verify duplicate response is also STATUS + var status2 dnspb.DNSPacket + err = proto.Unmarshal(resp2, &status2) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status2.Type) + assert.Equal(t, "dupinit1234", status2.ConversationId) + + // Conversation should still exist and be unchanged + conv, ok := r.conversations.Get("dupinit1234") + require.True(t, ok) + assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) + assert.Equal(t, uint32(3), conv.TotalChunks) + }) + + t.Run("concurrent duplicate INITs from resolvers", func(t *testing.T) { + r := newTestRedirector() + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 5, + DataCrc32: 0x12345678, + FileSize: 1024, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "concurrent-init", + Data: payloadBytes, + } + + // Simulate 10 concurrent INITs from different resolver nodes + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := r.handleInitPacket(packet) + assert.NoError(t, err) + }() + } + wg.Wait() + + // Should have exactly 1 conversation (no duplicates) + assert.Equal(t, 1, r.conversations.Len(), "concurrent INITs should not create duplicates") + + // Conversation should exist + _, ok := r.conversations.Get("concurrent-init") + assert.True(t, ok) }) } // TestHandleFetchPacket tests FETCH packet processing func TestHandleFetchPacket(t *testing.T) { t.Run("fetch single response", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() responseData := []byte("test response data") conv := &Conversation{ @@ -427,7 +549,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseData: responseData, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) packet := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_FETCH, @@ -440,7 +562,7 @@ func TestHandleFetchPacket(t *testing.T) { }) t.Run("fetch chunked response metadata", func(t *testing.T) { - r := &Redirector{} + r := newTestRedirector() responseData := []byte("full response") responseCRC := crc32.ChecksumIEEE(responseData) @@ -451,7 +573,7 @@ func TestHandleFetchPacket(t *testing.T) { ResponseCRC: responseCRC, LastActivity: time.Now(), } - r.conversations.Store("conv1234", conv) + r.conversations.Add("conv1234", conv) 
 		packet := &dnspb.DNSPacket{
 			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
@@ -440,7 +562,7 @@ func TestHandleFetchPacket(t *testing.T) {
 	})
 
 	t.Run("fetch chunked response metadata", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 
 		responseData := []byte("full response")
 		responseCRC := crc32.ChecksumIEEE(responseData)
@@ -451,7 +573,7 @@ func TestHandleFetchPacket(t *testing.T) {
 			ResponseCRC:  responseCRC,
 			LastActivity: time.Now(),
 		}
-		r.conversations.Store("conv1234", conv)
+		r.conversations.Add("conv1234", conv)
 
 		packet := &dnspb.DNSPacket{
 			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
@@ -470,7 +592,7 @@ func TestHandleFetchPacket(t *testing.T) {
 	})
 
 	t.Run("fetch specific chunk", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 
 		conv := &Conversation{
 			ID: "conv1234",
@@ -478,7 +600,7 @@ func TestHandleFetchPacket(t *testing.T) {
 			ResponseChunks: [][]byte{[]byte("chunk0"), []byte("chunk1"), []byte("chunk2")},
 			LastActivity:   time.Now(),
 		}
-		r.conversations.Store("conv1234", conv)
+		r.conversations.Add("conv1234", conv)
 
 		fetchPayload := &dnspb.FetchPayload{ChunkIndex: 2} // 1-indexed
 		payloadBytes, err := proto.Marshal(fetchPayload)
@@ -496,7 +618,7 @@ func TestHandleFetchPacket(t *testing.T) {
 	})
 
 	t.Run("fetch unknown conversation", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 
 		packet := &dnspb.DNSPacket{
 			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
@@ -508,28 +630,61 @@ func TestHandleFetchPacket(t *testing.T) {
 		assert.Contains(t, err.Error(), "conversation not found")
 	})
 
-	t.Run("fetch with no response ready", func(t *testing.T) {
-		r := &Redirector{}
+	t.Run("fetch on failed conversation returns not found", func(t *testing.T) {
+		r := newTestRedirector()
+
+		// Failed conversations are immediately removed from the cache,
+		// so a FETCH should get "conversation not found"
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
+			ConversationId: "failconv",
+		}
+
+		_, err := r.handleFetchPacket(packet)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "conversation not found")
+	})
+
+	t.Run("fetch after failure returns not found consistently", func(t *testing.T) {
+		r := newTestRedirector()
+
+		// Multiple FETCH requests for a removed conversation should all
+		// consistently return "conversation not found"
+		for i := 0; i < 10; i++ {
+			packet := &dnspb.DNSPacket{
+				Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
+				ConversationId: "failconv2",
+			}
+
+			_, err := r.handleFetchPacket(packet)
+			assert.Error(t, err, "FETCH attempt %d should error with not found", i)
+			assert.Contains(t, err.Error(), "conversation not found")
+		}
+	})
+
+	t.Run("fetch with no response ready returns empty (upstream in progress)", func(t *testing.T) {
+		r := newTestRedirector()
 
 		conv := &Conversation{
 			ID:           "conv1234",
-			ResponseData: nil, // No response yet
+			ResponseData: nil, // upstream call still in progress
 			LastActivity: time.Now(),
 		}
-		r.conversations.Store("conv1234", conv)
+		r.conversations.Add("conv1234", conv)
 
 		packet := &dnspb.DNSPacket{
 			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
 			ConversationId: "conv1234",
 		}
 
-		_, err := r.handleFetchPacket(packet)
-		assert.Error(t, err)
-		assert.Contains(t, err.Error(), "no response data")
+		// Should return empty response (not error) to avoid NXDOMAIN
+		data, err := r.handleFetchPacket(packet)
+		require.NoError(t, err)
+		assert.Equal(t, []byte{}, data)
 	})
 
 	t.Run("fetch chunk out of bounds", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 
 		conv := &Conversation{
 			ID: "conv1234",
@@ -537,7 +692,7 @@ func TestHandleFetchPacket(t *testing.T) {
 			ResponseChunks: [][]byte{[]byte("chunk0")},
 			LastActivity:   time.Now(),
 		}
-		r.conversations.Store("conv1234", conv)
+		r.conversations.Add("conv1234", conv)
 
 		fetchPayload := &dnspb.FetchPayload{ChunkIndex: 10} // Out of bounds
 		payloadBytes, err := proto.Marshal(fetchPayload)
@@ -555,9 +710,133 @@ func TestHandleFetchPacket(t *testing.T) {
 	})
 }
 
+// TestHandleCompletePacket tests COMPLETE packet processing
+func TestHandleCompletePacket(t *testing.T) {
+	t.Run("successful complete cleans up conversation", func(t *testing.T) {
+		r := newTestRedirector()
+
+		conv := &Conversation{
+			ID:           "complete1234",
+			MethodPath:   "/c2.C2/ClaimTasks",
+			ResponseData: []byte("response"),
+			LastActivity: time.Now(),
+		}
+		r.conversations.Add("complete1234", conv)
+
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_COMPLETE,
+			ConversationId: "complete1234",
+		}
+
+		responseData, err := r.handleCompletePacket(packet)
+		require.NoError(t, err)
+		require.NotNil(t, responseData)
+
+		// Verify response is STATUS
+		var statusPacket dnspb.DNSPacket
+		err = proto.Unmarshal(responseData, &statusPacket)
+		require.NoError(t, err)
+		assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type)
+		assert.Equal(t, "complete1234", statusPacket.ConversationId)
+
+		// Verify conversation was removed
+		_, ok := r.conversations.Get("complete1234")
+		assert.False(t, ok, "conversation should be removed after COMPLETE")
+	})
+
+	t.Run("duplicate COMPLETE returns success idempotently", func(t *testing.T) {
+		r := newTestRedirector()
+
+		conv := &Conversation{
+			ID:           "dupcomp1234",
+			MethodPath:   "/c2.C2/ClaimTasks",
+			ResponseData: []byte("response"),
+			LastActivity: time.Now(),
+		}
+		r.conversations.Add("dupcomp1234", conv)
+
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_COMPLETE,
+			ConversationId: "dupcomp1234",
+		}
+
+		// First COMPLETE removes conversation
+		resp1, err := r.handleCompletePacket(packet)
+		require.NoError(t, err)
+		require.NotNil(t, resp1)
+
+		// Verify conversation removed
+		_, ok := r.conversations.Get("dupcomp1234")
+		assert.False(t, ok)
+
+		// Second COMPLETE (duplicate from resolver) should succeed, not error
+		resp2, err := r.handleCompletePacket(packet)
+		require.NoError(t, err, "duplicate COMPLETE should not error")
+		require.NotNil(t, resp2)
+
+		// Verify response is also STATUS
+		var status dnspb.DNSPacket
+		err = proto.Unmarshal(resp2, &status)
+		require.NoError(t, err)
+		assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, status.Type)
+	})
+
+	t.Run("COMPLETE for never-existed conversation returns success", func(t *testing.T) {
+		r := newTestRedirector()
+
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_COMPLETE,
+			ConversationId: "nonexistent",
+		}
+
+		// Should succeed (not error) for idempotency
+		responseData, err := r.handleCompletePacket(packet)
+		require.NoError(t, err)
+		require.NotNil(t, responseData)
+
+		var statusPacket dnspb.DNSPacket
+		err = proto.Unmarshal(responseData, &statusPacket)
+		require.NoError(t, err)
+		assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type)
+	})
+
+	t.Run("concurrent COMPLETEs from resolvers", func(t *testing.T) {
+		r := newTestRedirector()
+
+		conv := &Conversation{
+			ID:           "concurrent-complete",
+			MethodPath:   "/c2.C2/ClaimTasks",
+			ResponseData: []byte("response"),
+			LastActivity: time.Now(),
+		}
+		r.conversations.Add("concurrent-complete", conv)
+
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_COMPLETE,
+			ConversationId: "concurrent-complete",
+		}
+
+		// Simulate 10 concurrent COMPLETEs from different resolver nodes
+		var wg sync.WaitGroup
+		for i := 0; i < 10; i++ {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				_, err := r.handleCompletePacket(packet)
+				assert.NoError(t, err)
+			}()
+		}
+		wg.Wait()
+
+		// Conversation should be removed
+		_, ok := r.conversations.Get("concurrent-complete")
+		assert.False(t, ok)
+	})
+}
+
 // TestParseDomainNameAndType tests DNS query parsing
 func TestParseDomainNameAndType(t *testing.T) {
-	r := &Redirector{}
+	r := newTestRedirector()
 
 	tests := []struct {
 		name string
@@ -645,53 +924,107 @@ func TestParseDomainNameAndType(t *testing.T) {
 
 // TestConversationCleanup tests cleanup of stale conversations
 func TestConversationCleanup(t *testing.T) {
-	r := &Redirector{
-		conversationTimeout: 15 * time.Minute,
-	}
+	r := newTestRedirector()
 
 	// Create stale conversation
 	staleConv := &Conversation{
 		ID:           "stale",
 		LastActivity: time.Now().Add(-20 * time.Minute),
 	}
-	r.conversations.Store("stale", staleConv)
-	r.conversationCount = 1
+	r.conversations.Add("stale", staleConv)
 
 	// Create fresh conversation
 	freshConv := &Conversation{
 		ID:           "fresh",
 		LastActivity: time.Now(),
 	}
-	r.conversations.Store("fresh", freshConv)
-	r.conversationCount = 2
+	r.conversations.Add("fresh", freshConv)
 
-	// Run cleanup
+	// Run cleanup (mirrors cleanupConversations logic)
 	now := time.Now()
-	r.conversations.Range(func(key, value any) bool {
-		conv := value.(*Conversation)
+	for _, key := range r.conversations.Keys() {
+		conv, ok := r.conversations.Peek(key)
+		if !ok {
+			continue
+		}
 		conv.mu.Lock()
-		if now.Sub(conv.LastActivity) > r.conversationTimeout {
-			r.conversations.Delete(key)
-			r.conversationCount--
+		shouldDelete := false
+		if conv.ResponseServed {
+			shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout
+		} else {
+			shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout
+		}
+		if shouldDelete {
+			r.conversations.Remove(key)
 		}
 		conv.mu.Unlock()
-		return true
-	})
+	}
 
 	// Verify stale was removed
-	_, ok := r.conversations.Load("stale")
+	_, ok := r.conversations.Get("stale")
 	assert.False(t, ok, "stale conversation should be removed")
 
 	// Verify fresh remains
-	_, ok = r.conversations.Load("fresh")
+	_, ok = r.conversations.Get("fresh")
 	assert.True(t, ok, "fresh conversation should remain")
-	assert.Equal(t, int32(1), r.conversationCount)
+	assert.Equal(t, 1, r.conversations.Len())
+}
+
+// TestServedConversationCleanup tests that conversations with the ResponseServed
+// flag set are cleaned up with the shorter ServedConversationTimeout.
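+// Once a response has been served, the conversation lingers only to absorb
+// resolver retries, so it can be reaped sooner than an in-flight upload.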
+func TestServedConversationCleanup(t *testing.T) {
+	r := newTestRedirector()
+
+	// Served conversation: response was delivered 3 minutes ago (> 2 min served timeout)
+	servedConv := &Conversation{
+		ID:             "served",
+		ResponseServed: true,
+		LastActivity:   time.Now().Add(-3 * time.Minute),
+	}
+	r.conversations.Add("served", servedConv)
+
+	// In-progress conversation: 3 minutes old but not served (< 15 min normal timeout)
+	activeConv := &Conversation{
+		ID:           "active",
+		LastActivity: time.Now().Add(-3 * time.Minute),
+	}
+	r.conversations.Add("active", activeConv)
+
+	// Run cleanup
+	now := time.Now()
+	for _, key := range r.conversations.Keys() {
+		conv, ok := r.conversations.Peek(key)
+		if !ok {
+			continue
+		}
+		conv.mu.Lock()
+		shouldDelete := false
+		if conv.ResponseServed {
+			shouldDelete = now.Sub(conv.LastActivity) > ServedConversationTimeout
+		} else {
+			shouldDelete = now.Sub(conv.LastActivity) > r.conversationTimeout
+		}
+		if shouldDelete {
+			r.conversations.Remove(key)
+		}
+		conv.mu.Unlock()
+	}
+
+	// Served conversation should be cleaned (3 min > 2 min served timeout)
+	_, ok := r.conversations.Get("served")
+	assert.False(t, ok, "served conversation should be cleaned up with shorter timeout")
+
+	// Active conversation should remain (3 min < 15 min normal timeout)
+	_, ok = r.conversations.Get("active")
+	assert.True(t, ok, "in-progress conversation should remain")
+
+	assert.Equal(t, 1, r.conversations.Len())
 }
 
 // TestConcurrentConversationAccess tests thread safety of conversation handling
 func TestConcurrentConversationAccess(t *testing.T) {
-	r := &Redirector{}
+	r := newTestRedirector()
 
 	initPayload := &dnspb.InitPayload{
 		MethodCode: "/c2.C2/ClaimTasks",
@@ -718,11 +1051,10 @@ func TestConcurrentConversationAccess(t *testing.T) {
 		go func(seq uint32) {
 			defer wg.Done()
 
-			val, ok := r.conversations.Load("concurrent")
+			conv, ok := r.conversations.Get("concurrent")
 			if !ok {
 				return
 			}
-			conv := val.(*Conversation)
 
 			conv.mu.Lock()
 			conv.Chunks[seq] = []byte{byte(seq)}
 			conv.mu.Unlock()
@@ -731,15 +1063,14 @@ func TestConcurrentConversationAccess(t *testing.T) {
 	wg.Wait()
 
 	// Verify all chunks stored
-	val, ok := r.conversations.Load("concurrent")
+	conv, ok := r.conversations.Get("concurrent")
 	require.True(t, ok)
-	conv := val.(*Conversation)
 	assert.Len(t, conv.Chunks, 100)
 }
 
 // TestBuildDNSResponse tests DNS response packet construction
 func TestBuildDNSResponse(t *testing.T) {
-	r := &Redirector{}
+	r := newTestRedirector()
 
 	// Create a mock UDP connection for testing
 	serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
@@ -777,7 +1108,7 @@ func TestBuildDNSResponse(t *testing.T) {
 
 // TestHandleDataPacket tests DATA packet processing and chunk storage
 func TestHandleDataPacket(t *testing.T) {
 	t.Run("store single chunk", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 		ctx := context.Background()
 
 		// Create conversation first with INIT - set TotalChunks > 1 to avoid completion
@@ -817,15 +1148,14 @@ func TestHandleDataPacket(t *testing.T) {
 		assert.Equal(t, "data1234", statusPacket.ConversationId)
 
 		// Verify chunk was stored
-		val, ok := r.conversations.Load("data1234")
+		conv, ok := r.conversations.Get("data1234")
 		require.True(t, ok)
-		conv := val.(*Conversation)
 		assert.Len(t, conv.Chunks, 1)
 		assert.Equal(t, []byte{0x01}, conv.Chunks[1])
 	})
 
 	t.Run("store multiple chunks with gaps", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 		ctx := context.Background()
 
 		// Create conversation
@@ -869,9 +1199,8 @@ func TestHandleDataPacket(t *testing.T) {
 		}
 
 		// Verify chunks stored
-		val, ok := r.conversations.Load("gaps1234")
+		conv, ok := r.conversations.Get("gaps1234")
 		require.True(t, ok)
-		conv := val.(*Conversation)
 		assert.Len(t, conv.Chunks, 3)
 		assert.Equal(t, []byte{1}, conv.Chunks[1])
 		assert.Equal(t, []byte{3}, conv.Chunks[3])
@@ -880,7 +1209,7 @@ func TestHandleDataPacket(t *testing.T) {
 	})
 
 	t.Run("unknown conversation", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 		ctx := context.Background()
 
 		dataPacket := &dnspb.DNSPacket{
@@ -896,7 +1225,7 @@ func TestHandleDataPacket(t *testing.T) {
 	})
 
 	t.Run("sequence out of bounds", func(t *testing.T) {
-		r := &Redirector{}
+		r := newTestRedirector()
 		ctx := context.Background()
 
 		// Create conversation
@@ -928,6 +1257,67 @@ func TestHandleDataPacket(t *testing.T) {
 		assert.Error(t, err)
 		assert.Contains(t, err.Error(), "sequence out of bounds")
 	})
+
+	t.Run("short-circuit for completed conversation", func(t *testing.T) {
+		r := newTestRedirector()
+		ctx := context.Background()
+
+		// Create a conversation that is already completed
+		conv := &Conversation{
+			ID:          "completed1",
+			TotalChunks: 3,
+			Completed:   true,
+			Chunks: map[uint32][]byte{
+				1: {0x01},
+				2: {0x02},
+				3: {0x03},
+			},
+			LastActivity: time.Now(),
+		}
+		r.conversations.Add("completed1", conv)
+
+		// Send a duplicate DATA to completed conversation
+		dataPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "completed1",
+			Sequence:       1,
+			Data:           []byte{0xFF}, // Different data
+		}
+
+		statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+		require.NoError(t, err)
+
+		// Should get full ack range without recomputation
+		var statusPacket dnspb.DNSPacket
+		err = proto.Unmarshal(statusData, &statusPacket)
+		require.NoError(t, err)
+		assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type)
+		require.Len(t, statusPacket.Acks, 1)
+		assert.Equal(t, uint32(1), statusPacket.Acks[0].StartSeq)
+		assert.Equal(t, uint32(3), statusPacket.Acks[0].EndSeq)
+		assert.Empty(t, statusPacket.Nacks)
+
+		// Original chunk data should NOT be overwritten
+		assert.Equal(t, []byte{0x01}, conv.Chunks[1])
+	})
+
+	t.Run("data packet for removed conversation returns not found", func(t *testing.T) {
+		r := newTestRedirector()
+		ctx := context.Background()
+
+		// Failed conversations are immediately removed from the cache,
+		// so a DATA packet should get "conversation not found"
+		dataPacket := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "failed1",
+			Sequence:       1,
+			Data:           []byte{0x01},
+		}
+
+		_, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "conversation not found")
+	})
 }
 
 // TestProcessCompletedConversation tests data reassembly and CRC validation
@@ -991,3 +1381,98 @@ func TestProcessCompletedConversation(t *testing.T) {
 		assert.NotEqual(t, wrongCRC, actualCRC, "CRC should mismatch")
 	})
 }
+
+// TestConversationNotFoundError verifies the error message for missing conversations
+func TestConversationNotFoundError(t *testing.T) {
+	r := newTestRedirector()
+
+	t.Run("DATA returns conversation not found error", func(t *testing.T) {
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_DATA,
+			ConversationId: "missing123",
+			Sequence:       1,
+			Data:           []byte{0x01},
+		}
+
+		_, err := r.handleDataPacket(context.Background(), nil, packet, txtRecordType)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "conversation not found")
+		assert.Contains(t, err.Error(), "missing123")
+	})
+
+	t.Run("FETCH returns conversation not found error", func(t *testing.T) {
+		packet := &dnspb.DNSPacket{
+			Type:           dnspb.PacketType_PACKET_TYPE_FETCH,
+			ConversationId: "missing456",
+		}
+
+		_, err := r.handleFetchPacket(packet)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "conversation not found")
+		assert.Contains(t, err.Error(), "missing456")
+	})
+}
+
+// TestLRUEvictionBehavior verifies the LRU cache evicts oldest conversations when full
+func TestLRUEvictionBehavior(t *testing.T) {
+	t.Run("evicts oldest when at capacity", func(t *testing.T) {
+		cache, _ := lru.New[string, *Conversation](3)
+		r := &Redirector{
+			conversations:       cache,
+			conversationTimeout: NormalConversationTimeout,
+		}
+
+		// Add 3 conversations to fill capacity
+		for i := 0; i < 3; i++ {
+			r.conversations.Add(fmt.Sprintf("conv%d", i), &Conversation{
+				ID:           fmt.Sprintf("conv%d", i),
+				LastActivity: time.Now(),
+			})
+		}
+		assert.Equal(t, 3, r.conversations.Len())
+
+		// Adding a 4th should evict the oldest (conv0)
+		r.conversations.Add("conv3", &Conversation{
+			ID:           "conv3",
+			LastActivity: time.Now(),
+		})
+		assert.Equal(t, 3, r.conversations.Len())
+
+		_, ok := r.conversations.Get("conv0")
+		assert.False(t, ok, "oldest conversation should be evicted")
+
+		_, ok = r.conversations.Get("conv3")
+		assert.True(t, ok, "newest conversation should exist")
+	})
+
+	t.Run("Get refreshes recency", func(t *testing.T) {
+		cache, _ := lru.New[string, *Conversation](3)
+		r := &Redirector{
+			conversations:       cache,
+			conversationTimeout: NormalConversationTimeout,
+		}
+
+		// Add 3 conversations
+		for i := 0; i < 3; i++ {
+			r.conversations.Add(fmt.Sprintf("conv%d", i), &Conversation{
+				ID:           fmt.Sprintf("conv%d", i),
+				LastActivity: time.Now(),
+			})
+		}
+
+		// Access conv0 to refresh its recency
+		r.conversations.Get("conv0")
+
+		// Adding conv3 should evict conv1 (now the oldest) instead of conv0
+		r.conversations.Add("conv3", &Conversation{
+			ID:           "conv3",
+			LastActivity: time.Now(),
+		})
+
+		_, ok := r.conversations.Get("conv0")
+		assert.True(t, ok, "conv0 should survive due to recent Get")
+
+		_, ok = r.conversations.Get("conv1")
+		assert.False(t, ok, "conv1 should be evicted as the oldest")
+	})
+}