From 47fb23f882e1711b8909a2d6675b9b33757c21d1 Mon Sep 17 00:00:00 2001 From: Lewis Marshall Date: Fri, 13 Feb 2026 23:07:54 +0000 Subject: [PATCH] feat: Implement Message Updates, Deletes, and Appends Signed-off-by: Lewis Marshall --- ably/example_message_updates_test.go | 114 +++++++ ably/export_test.go | 2 +- ably/message_updates_integration_test.go | 362 +++++++++++++++++++++ ably/proto_http.go | 17 +- ably/proto_message.go | 164 ++++++++++ ably/proto_message_operations_test.go | 240 ++++++++++++++ ably/proto_protocol_message.go | 42 ++- ably/realtime_channel.go | 222 ++++++++++++- ably/realtime_conn.go | 27 +- ably/realtime_experimental_objects.go | 4 +- ably/realtime_experimental_objects_test.go | 14 +- ably/realtime_presence.go | 15 +- ably/rest_channel.go | 219 ++++++++++++- ably/rest_client.go | 47 ++- ably/state.go | 55 +++- ably/state_test.go | 162 +++++++++ ablytest/sandbox.go | 10 +- 17 files changed, 1634 insertions(+), 82 deletions(-) create mode 100644 ably/example_message_updates_test.go create mode 100644 ably/message_updates_integration_test.go create mode 100644 ably/proto_message_operations_test.go create mode 100644 ably/state_test.go diff --git a/ably/example_message_updates_test.go b/ably/example_message_updates_test.go new file mode 100644 index 000000000..d70320b1e --- /dev/null +++ b/ably/example_message_updates_test.go @@ -0,0 +1,114 @@ +package ably_test + +import ( + "context" + "fmt" + + "github.com/ably/ably-go/ably" +) + +// Example demonstrating how to publish a message and get its serial +func ExampleRESTChannel_PublishWithResult() { + client, err := ably.NewREST(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("example-channel") + + // Publish a message and get its serial + result, err := channel.PublishWithResult(context.Background(), "event-name", "message data") + if err != nil { + panic(err) + } + + if result.Serial == nil { + fmt.Println("Message published but serial not available (discarded by conflation)") + return + } + fmt.Printf("Message published with serial: %s\n", *result.Serial) +} + +// Example demonstrating how to update a message +func ExampleRESTChannel_UpdateMessage() { + client, err := ably.NewREST(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("example-channel") + + // First publish a message to get its serial + result, err := channel.PublishWithResult(context.Background(), "event", "initial data") + if err != nil { + panic(err) + } + + if result.Serial == nil { + fmt.Println("Message published but serial not available (discarded by conflation)") + return + } + + // Update the message + msg := &ably.Message{ + Serial: *result.Serial, + Data: "updated data", + } + + updateResult, err := channel.UpdateMessage( + context.Background(), + msg, + ably.UpdateWithDescription("Fixed typo"), + ably.UpdateWithMetadata(map[string]string{"editor": "alice"}), + ) + if err != nil { + panic(err) + } + + if updateResult.VersionSerial == nil { + fmt.Println("Message updated but version serial not available (superseded)") + return + } + fmt.Printf("Message updated with version serial: %s\n", *updateResult.VersionSerial) +} + +// Example demonstrating async message append for AI streaming +func ExampleRealtimeChannel_AppendMessageAsync() { + client, err := ably.NewRealtime(ably.WithKey("xxx:xxx")) + if err != nil { + panic(err) + } + + channel := client.Channels.Get("chat-channel") + + // Publish initial message + result, err := channel.PublishWithResult(context.Background(), "ai-response", "The answer is") + if err != nil { + panic(err) + } + + if result.Serial == nil { + fmt.Println("Message published but serial not available (discarded by conflation)") + return + } + + // Stream tokens asynchronously without blocking + tokens := []string{" 42", ".", " This", " is", " the", " answer."} + for _, token := range tokens { + msg := &ably.Message{ + Serial: *result.Serial, + Data: token, + } + // Non-blocking append - critical for AI streaming + err := channel.AppendMessageAsync(msg, func(r *ably.UpdateDeleteResult, err error) { + if err != nil { + fmt.Printf("Append failed: %v\n", err) + } + }) + if err != nil { + panic(err) + } + } + + fmt.Println("All tokens queued for append") +} diff --git a/ably/export_test.go b/ably/export_test.go index 9ac12bbb4..020fa4809 100644 --- a/ably/export_test.go +++ b/ably/export_test.go @@ -222,7 +222,7 @@ func (c *Connection) AckAll() { c.mtx.Unlock() c.log().Infof("Ack all %d messages waiting for ACK/NACK", len(cx)) for _, v := range cx { - v.onAck(nil) + v.ackCallback.call(nil, nil) } } diff --git a/ably/message_updates_integration_test.go b/ably/message_updates_integration_test.go new file mode 100644 index 000000000..82a38bd07 --- /dev/null +++ b/ably/message_updates_integration_test.go @@ -0,0 +1,362 @@ +//go:build !unit +// +build !unit + +package ably_test + +import ( + "context" + "testing" + "time" + + "github.com/ably/ably-go/ably" + "github.com/ably/ably-go/ablytest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRESTChannel_MessageUpdates(t *testing.T) { + app, err := ablytest.NewSandbox(nil) + require.NoError(t, err) + defer app.Close() + + client, err := ably.NewREST(app.Options()...) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("PublishWithResult", func(t *testing.T) { + // Use mutable: namespace to enable message operations feature + channel := client.Channels.Get("mutable:test_publish_with_result") + + t.Run("returns serial for published message", func(t *testing.T) { + result, err := channel.PublishWithResult(ctx, "event1", "test data") + require.NoError(t, err) + require.NotNil(t, result.Serial, "Expected non-nil serial") + assert.NotEmpty(t, *result.Serial, "Expected non-empty serial") + }) + + t.Run("PublishMultipleWithResult returns serials for all messages", func(t *testing.T) { + messages := []*ably.Message{ + {Name: "event1", Data: "data1"}, + {Name: "event2", Data: "data2"}, + {Name: "event3", Data: "data3"}, + } + + results, err := channel.PublishMultipleWithResult(ctx, messages) + require.NoError(t, err) + assert.Len(t, results, 3) + + for i, result := range results { + require.NotNil(t, result.Serial, "Expected non-nil serial for message %d", i) + assert.NotEmpty(t, *result.Serial, "Expected non-empty serial for message %d", i) + } + }) + }) + + t.Run("UpdateMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_update_message") + + t.Run("updates a message with new data", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "initial data") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Update the message + msg := &ably.Message{ + Serial: *publishResult.Serial, + Data: "updated data", + } + updateResult, err := channel.UpdateMessage(ctx, msg, + ably.UpdateWithDescription("Fixed typo"), + ably.UpdateWithMetadata(map[string]string{"editor": "test"}), + ) + require.NoError(t, err) + require.NotNil(t, updateResult.VersionSerial, "Expected non-nil version serial") + assert.NotEmpty(t, *updateResult.VersionSerial, "Expected version serial") + assert.NotEqual(t, *publishResult.Serial, *updateResult.VersionSerial, "VersionSerial should differ from original Serial") + + // Verify the update by fetching the message (eventually consistent) + require.Eventually(t, func() bool { + retrieved, err := channel.GetMessage(ctx, *publishResult.Serial) + if err != nil { + return false + } + return retrieved.Data == "updated data" + }, 5*time.Second, 100*time.Millisecond, "Updated message should be retrievable") + }) + + t.Run("returns error when message has no serial", func(t *testing.T) { + msg := &ably.Message{Data: "test"} + _, err := channel.UpdateMessage(ctx, msg) + require.Error(t, err) + + errorInfo, ok := err.(*ably.ErrorInfo) + require.True(t, ok, "Expected ErrorInfo") + assert.Equal(t, ably.ErrorCode(40003), errorInfo.Code) + assert.Contains(t, errorInfo.Message(), "lacks a serial") + }) + }) + + t.Run("DeleteMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_delete_message") + + t.Run("deletes a message", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "data to delete") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Delete the message + msg := &ably.Message{ + Serial: *publishResult.Serial, + } + deleteResult, err := channel.DeleteMessage(ctx, msg, + ably.UpdateWithDescription("Deleted by test"), + ) + require.NoError(t, err) + require.NotNil(t, deleteResult.VersionSerial, "Expected non-nil version serial") + assert.NotEmpty(t, *deleteResult.VersionSerial, "Expected version serial") + }) + }) + + t.Run("AppendMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_append_message") + + t.Run("appends to a message", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "Hello") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Append to the message + msg := &ably.Message{ + Serial: *publishResult.Serial, + Data: " World", + } + appendResult, err := channel.AppendMessage(ctx, msg) + require.NoError(t, err) + require.NotNil(t, appendResult.VersionSerial, "Expected non-nil version serial") + assert.NotEmpty(t, *appendResult.VersionSerial, "Expected version serial") + + // Verify by fetching the message - data should be appended (eventually consistent) + require.Eventually(t, func() bool { + retrieved, err := channel.GetMessage(ctx, *publishResult.Serial) + if err != nil { + return false + } + // Verify the data was appended: "Hello" + " World" = "Hello World" + return retrieved.Data == "Hello World" + }, 5*time.Second, 100*time.Millisecond, "Message data should be appended") + }) + }) + + t.Run("GetMessage", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_get_message") + + t.Run("retrieves a message by serial", func(t *testing.T) { + // Publish a message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "test data") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // GetMessage is eventually consistent - retry until message is available + require.Eventually(t, func() bool { + msg, err := channel.GetMessage(ctx, *publishResult.Serial) + if err != nil { + return false + } + return msg.Data == "test data" && msg.Serial == *publishResult.Serial + }, 5*time.Second, 100*time.Millisecond, "Message should be retrievable") + }) + }) + + t.Run("GetMessageVersions", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_get_message_versions") + + t.Run("retrieves all versions after updates", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "test-event", "version 1") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Update the message twice + msg := &ably.Message{ + Serial: *publishResult.Serial, + Data: "version 2", + } + _, err = channel.UpdateMessage(ctx, msg, ably.UpdateWithDescription("First update")) + require.NoError(t, err) + + msg.Data = "version 3" + _, err = channel.UpdateMessage(ctx, msg, ably.UpdateWithDescription("Second update")) + require.NoError(t, err) + + // GetMessageVersions is eventually consistent - retry until all versions are available + var versions []*ably.Message + require.Eventually(t, func() bool { + page, err := channel.GetMessageVersions(*publishResult.Serial, nil).Pages(ctx) + if err != nil { + return false + } + + // Must call Next() to decode the response body into items + if !page.Next(ctx) { + return false + } + + versions = page.Items() + + // Should have exactly 3 versions: original publish + 2 updates + return len(versions) == 3 + }, 10*time.Second, 200*time.Millisecond, "All three message versions should be retrievable") + + // Verify we have exactly 3 versions in the correct order + require.Equal(t, 3, len(versions)) + assert.Equal(t, ably.MessageActionCreate, versions[0].Action, "First version should be MESSAGE_CREATE") + assert.Equal(t, ably.MessageActionUpdate, versions[1].Action, "Second version should be MESSAGE_UPDATE") + assert.Equal(t, ably.MessageActionUpdate, versions[2].Action, "Third version should be MESSAGE_UPDATE") + }) + }) +} + +func TestRealtimeChannel_MessageUpdates(t *testing.T) { + app, err := ablytest.NewSandbox(nil) + require.NoError(t, err) + defer app.Close() + + client, err := ably.NewRealtime(app.Options()...) + require.NoError(t, err) + defer client.Close() + + ctx := context.Background() + + // Wait for connection + err = ablytest.Wait(ablytest.ConnWaiter(client, nil, ably.ConnectionEventConnected), nil) + require.NoError(t, err) + + t.Run("PublishWithResult", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_realtime_publish_with_result") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("returns serial for published message", func(t *testing.T) { + result, err := channel.PublishWithResult(ctx, "event1", "realtime data") + require.NoError(t, err) + require.NotNil(t, result.Serial, "Expected non-nil serial") + assert.NotEmpty(t, *result.Serial, "Expected non-empty serial") + }) + + t.Run("PublishMultipleWithResult returns serials", func(t *testing.T) { + messages := []*ably.Message{ + {Name: "evt1", Data: "data1"}, + {Name: "evt2", Data: "data2"}, + } + + results, err := channel.PublishMultipleWithResult(ctx, messages) + require.NoError(t, err) + assert.Len(t, results, 2) + + for i, result := range results { + require.NotNil(t, result.Serial, "Expected non-nil serial for message %d", i) + assert.NotEmpty(t, *result.Serial, "Expected non-empty serial for message %d", i) + } + }) + }) + + t.Run("UpdateMessageAsync", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_realtime_update_async") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("updates message asynchronously", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "event", "initial") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Update asynchronously + done := make(chan *ably.UpdateDeleteResult, 1) + errChan := make(chan error, 1) + + msg := &ably.Message{ + Serial: *publishResult.Serial, + Data: "updated async", + } + err = channel.UpdateMessageAsync(msg, func(result *ably.UpdateDeleteResult, err error) { + if err != nil { + errChan <- err + } else { + done <- result + } + }) + require.NoError(t, err) + + // Wait for callback + select { + case result := <-done: + require.NotNil(t, result.VersionSerial) + assert.NotEmpty(t, *result.VersionSerial) + case err := <-errChan: + t.Fatalf("Update failed: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for update callback") + } + }) + }) + + t.Run("AppendMessageAsync", func(t *testing.T) { + channel := client.Channels.Get("mutable:test_ai_streaming") + + // Attach channel + err := channel.Attach(ctx) + require.NoError(t, err) + + t.Run("rapid async appends for AI token streaming", func(t *testing.T) { + // Publish initial message + publishResult, err := channel.PublishWithResult(ctx, "ai-response", "The answer is") + require.NoError(t, err) + require.NotNil(t, publishResult.Serial) + + // Simulate rapid token streaming + tokens := []string{" 42", ".", " This", " is", " correct", "."} + type ack struct { + result *ably.UpdateDeleteResult + err error + } + acks := make(chan ack, len(tokens)) + + for _, token := range tokens { + msg := &ably.Message{ + Serial: *publishResult.Serial, + Data: token, + } + + err := channel.AppendMessageAsync(msg, func(result *ably.UpdateDeleteResult, err error) { + acks <- ack{result, err} + }) + require.NoError(t, err, "Failed to queue append %q", token) + } + + // Wait for all appends to complete (with timeout) + timeout := time.After(10 * time.Second) + ackCount := 0 + for ackCount < len(tokens) { + select { + case ack := <-acks: + require.NoError(t, ack.err) + ackCount++ + case <-timeout: + t.Fatalf("Timeout: Only %d/%d appends completed before timeout", ackCount, len(tokens)) + } + } + + assert.Equal(t, len(tokens), ackCount, "All appends should complete") + }) + }) +} diff --git a/ably/proto_http.go b/ably/proto_http.go index 28dc553bc..787c6f91e 100644 --- a/ably/proto_http.go +++ b/ably/proto_http.go @@ -13,11 +13,18 @@ const ( ablyErrorMessageHeader = "X-Ably-Errormessage" clientLibraryVersion = "1.3.0" clientRuntimeName = "go" - ablyProtocolVersion = "2" // CSV2 - ablyClientIDHeader = "X-Ably-ClientId" - hostHeader = "Host" - ablyAgentHeader = "Ably-Agent" // RSC7d - ablySDKIdentifier = "ably-go/" + clientLibraryVersion // RSC7d1 + // ablyProtocolVersion is the default Ably protocol version used for all requests. + // Protocol v5 is required for message operations (publish/update/delete/append) to return + // message serials and version information. + // + // Note: Stats requests explicitly override this to use protocol v2 to maintain compatibility + // with the existing nested Stats structure. Migrating stats to the flattened v3+ format + // requires breaking API changes and is planned for ably-go v2.0. + ablyProtocolVersion = "5" // CSV2 + ablyClientIDHeader = "X-Ably-ClientId" + hostHeader = "Host" + ablyAgentHeader = "Ably-Agent" // RSC7d + ablySDKIdentifier = "ably-go/" + clientLibraryVersion // RSC7d1 ) var goRuntimeIdentifier = func() string { diff --git a/ably/proto_message.go b/ably/proto_message.go index 3dd2e4a45..e1e515f57 100644 --- a/ably/proto_message.go +++ b/ably/proto_message.go @@ -7,6 +7,8 @@ import ( "fmt" "strings" "unicode/utf8" + + "github.com/ugorji/go/codec" ) // encodings @@ -18,6 +20,150 @@ const ( encVCDiff = "vcdiff" ) +// MessageAction represents the type of message operation (TM5). +type MessageAction string + +const ( + MessageActionUnknown MessageAction = "UNKNOWN" + MessageActionCreate MessageAction = "MESSAGE_CREATE" + MessageActionUpdate MessageAction = "MESSAGE_UPDATE" + MessageActionDelete MessageAction = "MESSAGE_DELETE" + MessageActionMeta MessageAction = "META" + MessageActionMessageSummary MessageAction = "MESSAGE_SUMMARY" + MessageActionAppend MessageAction = "MESSAGE_APPEND" +) + +// messageActions is a slice of MessageAction constants (TM5) where the index +// of a given constant represents the numeric value to use when encoding that +// constant over the wire (see encodeMessageAction). +var messageActions = []MessageAction{ + MessageActionCreate, // 0 = MESSAGE_CREATE + MessageActionUpdate, // 1 = MESSAGE_UPDATE + MessageActionDelete, // 2 = MESSAGE_DELETE + MessageActionMeta, // 3 = META + MessageActionMessageSummary, // 4 = MESSAGE_SUMMARY + MessageActionAppend, // 5 = MESSAGE_APPEND +} + +func encodeMessageAction(action MessageAction) int { + for i, a := range messageActions { + if a == action { + return i + } + } + return 0 // default to create +} + +func decodeMessageAction(num int) MessageAction { + if num >= 0 && num < len(messageActions) { + return messageActions[num] + } + return MessageActionUnknown +} + +// MarshalJSON implements json.Marshaler to encode MessageAction as numeric for wire compatibility. +func (a MessageAction) MarshalJSON() ([]byte, error) { + return json.Marshal(encodeMessageAction(a)) +} + +// UnmarshalJSON implements json.Unmarshaler to decode numeric wire format to MessageAction. +func (a *MessageAction) UnmarshalJSON(data []byte) error { + var num int + if err := json.Unmarshal(data, &num); err != nil { + return err + } + *a = decodeMessageAction(num) + return nil +} + +// CodecEncodeSelf implements codec.Selfer for MessagePack encoding. +func (a MessageAction) CodecEncodeSelf(encoder *codec.Encoder) { + encoder.MustEncode(encodeMessageAction(a)) +} + +// CodecDecodeSelf implements codec.Selfer for MessagePack decoding. +func (a *MessageAction) CodecDecodeSelf(decoder *codec.Decoder) { + var num int + decoder.MustDecode(&num) + *a = decodeMessageAction(num) +} + +// MessageVersion contains version information for a message (TM2s). +// When received from the server, Serial and Timestamp are server-populated. +// When sending an update/delete/append, ClientID, Description, and Metadata +// are user-provided via UpdateOption functions (mapped from MOP2a/MOP2b/MOP2c). +type MessageVersion struct { + // Serial is an opaque version identifier, assigned by the server (TM2s1). Read-only on received messages. + Serial string `json:"serial,omitempty" codec:"serial,omitempty"` + // Timestamp is set by the server when the version is created, ms since epoch (TM2s2). Read-only on received messages. + Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` + // ClientID identifies the client that performed the operation (TM2s3, MOP2a). + ClientID string `json:"clientId,omitempty" codec:"clientId,omitempty"` + // Description is a human-readable description of the operation (TM2s4, MOP2b). + Description string `json:"description,omitempty" codec:"description,omitempty"` + // Metadata contains arbitrary key-value pairs about the operation (TM2s5, MOP2c). + Metadata map[string]string `json:"metadata,omitempty" codec:"metadata,omitempty"` +} + +// PublishResult contains the result of a publish operation with serial tracking. +// The spec (PBR2a) defines PublishResult with a serials array, but per RSL1n1/RTL6j1, +// SDKs may implement alternatives where adding a response value would be a breaking +// API change. This SDK returns one PublishResult per message for ergonomics. +type PublishResult struct { + Serial *string // nil if message was discarded by conflation (PBR2a) +} + +// UpdateDeleteResult contains the result of an update, delete, or append operation (UDR1). +type UpdateDeleteResult struct { + VersionSerial *string // nil if superseded (UDR2a) +} + +// UpdateOption is a functional option for message update operations. +type UpdateOption func(*updateOptions) + +type updateOptions struct { + version *MessageVersion // unexported, built lazily from options + params map[string]string // URL query parameters (RSL15f) / ProtocolMessage params (RTL32e) +} + +// UpdateWithDescription sets a description for the update operation. +func UpdateWithDescription(description string) UpdateOption { + return func(o *updateOptions) { + if o.version == nil { + o.version = &MessageVersion{} + } + o.version.Description = description + } +} + +// UpdateWithClientID sets the client ID for the update operation. +func UpdateWithClientID(clientID string) UpdateOption { + return func(o *updateOptions) { + if o.version == nil { + o.version = &MessageVersion{} + } + o.version.ClientID = clientID + } +} + +// UpdateWithMetadata sets metadata for the update operation. +func UpdateWithMetadata(metadata map[string]string) UpdateOption { + return func(o *updateOptions) { + if o.version == nil { + o.version = &MessageVersion{} + } + o.version.Metadata = metadata + } +} + +// UpdateWithParams sets operation params. When using REST, these are included as URL query +// parameters (RSL15a, RSL15f). When using Realtime, these are set as protocolMessage.Params (RTL32e). +func UpdateWithParams(params map[string]string) UpdateOption { + return func(o *updateOptions) { + o.params = params + } +} + // Message contains an individual message that is sent to, or received from, Ably. type Message struct { // ID is a unique identifier assigned by Ably to this message (TM2a). @@ -42,6 +188,12 @@ type Message struct { // Extras is a JSON object of arbitrary key-value pairs that may contain metadata, and/or ancillary payloads. // Valid payloads include push, deltaExtras, ReferenceExtras and headers (TM2i). Extras map[string]interface{} `json:"extras,omitempty" codec:"extras,omitempty"` + // Serial is a permanent identifier for this message assigned by the server (TM2r). + Serial string `json:"serial,omitempty" codec:"serial,omitempty"` + // Action indicates the type of message operation (TM5). + Action MessageAction `json:"action,omitempty" codec:"action,omitempty"` + // Version contains version information for the message (TM2s). + Version *MessageVersion `json:"version,omitempty" codec:"version,omitempty"` } // DeltaExtras describes a message whose payload is a "vcdiff"-encoded delta generated with respect to a base message (DE1, DE2). @@ -101,6 +253,18 @@ func (p *protocolMessage) updateInnerMessageEmptyFields(m *Message, index int) { if m.Timestamp == 0 { m.Timestamp = p.Timestamp } + // TM2s: Initialize version object if not present on received messages. + if m.Version == nil { + m.Version = &MessageVersion{} + } + // TM2s1: Default version.serial from message.serial. + if empty(m.Version.Serial) && !empty(m.Serial) { + m.Version.Serial = m.Serial + } + // TM2s2: Default version.timestamp from message.timestamp. + if m.Version.Timestamp == 0 && m.Timestamp != 0 { + m.Version.Timestamp = m.Timestamp + } } // updateInnerMessagesEmptyFields updates [Message.ID], [Message.ConnectionID] and [Message.Timestamp] with diff --git a/ably/proto_message_operations_test.go b/ably/proto_message_operations_test.go new file mode 100644 index 000000000..3c7261279 --- /dev/null +++ b/ably/proto_message_operations_test.go @@ -0,0 +1,240 @@ +package ably + +import ( + "encoding/json" + "testing" + + "github.com/ugorji/go/codec" +) + +func TestMessageAction_JSON_Encoding(t *testing.T) { + tests := []struct { + action MessageAction + expected string + }{ + {MessageActionCreate, "0"}, + {MessageActionUpdate, "1"}, + {MessageActionDelete, "2"}, + {MessageActionMeta, "3"}, + {MessageActionMessageSummary, "4"}, + {MessageActionAppend, "5"}, + } + + for _, tt := range tests { + t.Run(string(tt.action), func(t *testing.T) { + data, err := json.Marshal(tt.action) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + if string(data) != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, string(data)) + } + }) + } +} + +func TestMessageAction_JSON_Decoding(t *testing.T) { + tests := []struct { + input string + expected MessageAction + }{ + {"0", MessageActionCreate}, + {"1", MessageActionUpdate}, + {"2", MessageActionDelete}, + {"3", MessageActionMeta}, + {"4", MessageActionMessageSummary}, + {"5", MessageActionAppend}, + {"999", MessageActionUnknown}, // Unknown values decode to MessageActionUnknown + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + var action MessageAction + err := json.Unmarshal([]byte(tt.input), &action) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if action != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, action) + } + }) + } +} + +func TestMessageAction_Codec_Encoding(t *testing.T) { + tests := []struct { + action MessageAction + num int + }{ + {MessageActionCreate, 0}, + {MessageActionUpdate, 1}, + {MessageActionDelete, 2}, + {MessageActionMeta, 3}, + {MessageActionMessageSummary, 4}, + {MessageActionAppend, 5}, + } + + for _, tt := range tests { + t.Run(string(tt.action), func(t *testing.T) { + var buf []byte + enc := codec.NewEncoderBytes(&buf, &codec.MsgpackHandle{}) + tt.action.CodecEncodeSelf(enc) + + // Decode to verify + var result int + dec := codec.NewDecoderBytes(buf, &codec.MsgpackHandle{}) + dec.MustDecode(&result) + + if result != tt.num { + t.Errorf("Expected %d, got %d", tt.num, result) + } + }) + } +} + +func TestMessageAction_Codec_Decoding(t *testing.T) { + tests := []struct { + num int + expected MessageAction + }{ + {0, MessageActionCreate}, + {1, MessageActionUpdate}, + {2, MessageActionDelete}, + {3, MessageActionMeta}, + {4, MessageActionMessageSummary}, + {5, MessageActionAppend}, + {999, MessageActionUnknown}, // Unknown values decode to MessageActionUnknown + } + + for _, tt := range tests { + t.Run(string(tt.expected), func(t *testing.T) { + var buf []byte + enc := codec.NewEncoderBytes(&buf, &codec.MsgpackHandle{}) + enc.MustEncode(tt.num) + + var action MessageAction + dec := codec.NewDecoderBytes(buf, &codec.MsgpackHandle{}) + action.CodecDecodeSelf(dec) + + if action != tt.expected { + t.Errorf("Expected %s, got %s", tt.expected, action) + } + }) + } +} + +func TestMessageVersion_Serialization(t *testing.T) { + version := &MessageVersion{ + Serial: "abc123", + Timestamp: 1234567890, + ClientID: "client1", + Description: "Test update", + Metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + } + + // Test JSON serialization + data, err := json.Marshal(version) + if err != nil { + t.Fatalf("Failed to marshal MessageVersion: %v", err) + } + + var decoded MessageVersion + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("Failed to unmarshal MessageVersion: %v", err) + } + + if decoded.Serial != version.Serial { + t.Errorf("Serial: expected %s, got %s", version.Serial, decoded.Serial) + } + if decoded.Timestamp != version.Timestamp { + t.Errorf("Timestamp: expected %d, got %d", version.Timestamp, decoded.Timestamp) + } + if decoded.ClientID != version.ClientID { + t.Errorf("ClientID: expected %s, got %s", version.ClientID, decoded.ClientID) + } + if decoded.Description != version.Description { + t.Errorf("Description: expected %s, got %s", version.Description, decoded.Description) + } + if len(decoded.Metadata) != len(version.Metadata) { + t.Errorf("Metadata length: expected %d, got %d", len(version.Metadata), len(decoded.Metadata)) + } +} + +func TestMessage_NewFields_Serialization(t *testing.T) { + msg := &Message{ + ID: "msg123", + Data: "test data", + Serial: "serial123", + Action: MessageActionUpdate, + Version: &MessageVersion{ + Serial: "version_serial", + ClientID: "client1", + Description: "Updated message", + }, + } + + // Test JSON serialization + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Failed to marshal Message: %v", err) + } + + var decoded Message + err = json.Unmarshal(data, &decoded) + if err != nil { + t.Fatalf("Failed to unmarshal Message: %v", err) + } + + if decoded.Serial != msg.Serial { + t.Errorf("Serial: expected %s, got %s", msg.Serial, decoded.Serial) + } + if decoded.Action != msg.Action { + t.Errorf("Action: expected %s, got %s", msg.Action, decoded.Action) + } + if decoded.Version == nil { + t.Fatal("Version should not be nil") + } + if decoded.Version.Serial != msg.Version.Serial { + t.Errorf("Version.Serial: expected %s, got %s", msg.Version.Serial, decoded.Version.Serial) + } +} + +func TestValidateMessageSerial(t *testing.T) { + t.Run("nil message", func(t *testing.T) { + err := validateMessageSerial(nil) + if err == nil { + t.Fatal("Expected error for nil message") + } + if code(err) != 40003 { + t.Errorf("Expected error code 40003, got %d", code(err)) + } + }) + + t.Run("empty serial", func(t *testing.T) { + msg := &Message{Data: "test"} + err := validateMessageSerial(msg) + if err == nil { + t.Fatal("Expected error for message without serial") + } + if code(err) != 40003 { + t.Errorf("Expected error code 40003, got %d", code(err)) + } + // Verify exact error message matches TypeScript + expectedMsg := "this message lacks a serial and cannot be updated. Make sure you have enabled \"Message annotations, updates, and deletes\" in channel settings on your dashboard" + if err.(*ErrorInfo).Message() != expectedMsg { + t.Errorf("Error message mismatch.\nExpected: %s\nGot: %s", expectedMsg, err.(*ErrorInfo).Message()) + } + }) + + t.Run("valid serial", func(t *testing.T) { + msg := &Message{Data: "test", Serial: "abc123"} + err := validateMessageSerial(msg) + if err != nil { + t.Errorf("Expected no error for valid message, got %v", err) + } + }) +} diff --git a/ably/proto_protocol_message.go b/ably/proto_protocol_message.go index 809d21195..cff9370cd 100644 --- a/ably/proto_protocol_message.go +++ b/ably/proto_protocol_message.go @@ -117,25 +117,31 @@ func coerceInt64(v interface{}) int64 { } } +// protocolPublishResult matches the wire format for publish results in ACK messages (TR4s, PBR2a). +type protocolPublishResult struct { + Serials []string `json:"serials,omitempty" codec:"serials,omitempty"` +} + type protocolMessage struct { - Messages []*Message `json:"messages,omitempty" codec:"messages,omitempty"` - Presence []*PresenceMessage `json:"presence,omitempty" codec:"presence,omitempty"` - State []*objects.Message `json:"state,omitempty" codec:"state,omitempty"` - ID string `json:"id,omitempty" codec:"id,omitempty"` - ApplicationID string `json:"applicationId,omitempty" codec:"applicationId,omitempty"` - ConnectionID string `json:"connectionId,omitempty" codec:"connectionId,omitempty"` - ConnectionKey string `json:"connectionKey,omitempty" codec:"connectionKey,omitempty"` - Channel string `json:"channel,omitempty" codec:"channel,omitempty"` - ChannelSerial string `json:"channelSerial,omitempty" codec:"channelSerial,omitempty"` - ConnectionDetails *connectionDetails `json:"connectionDetails,omitempty" codec:"connectionDetails,omitempty"` - Error *errorInfo `json:"error,omitempty" codec:"error,omitempty"` - MsgSerial int64 `json:"msgSerial" codec:"msgSerial"` - Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` - Count int `json:"count,omitempty" codec:"count,omitempty"` - Action protoAction `json:"action,omitempty" codec:"action,omitempty"` - Flags protoFlag `json:"flags,omitempty" codec:"flags,omitempty"` - Params channelParams `json:"params,omitempty" codec:"params,omitempty"` - Auth *authDetails `json:"auth,omitempty" codec:"auth,omitempty"` + Messages []*Message `json:"messages,omitempty" codec:"messages,omitempty"` + Presence []*PresenceMessage `json:"presence,omitempty" codec:"presence,omitempty"` + State []*objects.Message `json:"state,omitempty" codec:"state,omitempty"` + ID string `json:"id,omitempty" codec:"id,omitempty"` + ApplicationID string `json:"applicationId,omitempty" codec:"applicationId,omitempty"` + ConnectionID string `json:"connectionId,omitempty" codec:"connectionId,omitempty"` + ConnectionKey string `json:"connectionKey,omitempty" codec:"connectionKey,omitempty"` + Channel string `json:"channel,omitempty" codec:"channel,omitempty"` + ChannelSerial string `json:"channelSerial,omitempty" codec:"channelSerial,omitempty"` + ConnectionDetails *connectionDetails `json:"connectionDetails,omitempty" codec:"connectionDetails,omitempty"` + Error *errorInfo `json:"error,omitempty" codec:"error,omitempty"` + MsgSerial int64 `json:"msgSerial" codec:"msgSerial"` + Timestamp int64 `json:"timestamp,omitempty" codec:"timestamp,omitempty"` + Count int `json:"count,omitempty" codec:"count,omitempty"` + Action protoAction `json:"action,omitempty" codec:"action,omitempty"` + Flags protoFlag `json:"flags,omitempty" codec:"flags,omitempty"` + Params channelParams `json:"params,omitempty" codec:"params,omitempty"` + Auth *authDetails `json:"auth,omitempty" codec:"auth,omitempty"` + Res []*protocolPublishResult `json:"res,omitempty" codec:"res,omitempty"` // TR4s: publish results in ACK } // authDetails contains the token string used to authenticate a client with Ably. diff --git a/ably/realtime_channel.go b/ably/realtime_channel.go index da96dd8ee..d5375494d 100644 --- a/ably/realtime_channel.go +++ b/ably/realtime_channel.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/url" "sort" "sync" @@ -728,7 +729,91 @@ func (c *RealtimeChannel) PublishMultipleAsync(messages []*Message, onAck func(e Channel: c.Name, Messages: messages, } - return c.send(msg, onAck) + return c.send(msg, &msgAckCallback{onAck: onAck}) +} + +// PublishWithResult publishes a single message to the channel and returns the serial assigned +// by the server (RTL6j, RTL6j1 alternative). +// This will block until either the publish is acknowledged or fails to deliver. +func (c *RealtimeChannel) PublishWithResult(ctx context.Context, name string, data interface{}) (*PublishResult, error) { + results, err := c.PublishMultipleWithResult(ctx, []*Message{{Name: name, Data: data}}) + if err != nil { + return nil, err + } + if len(results) == 0 { + return &PublishResult{}, nil + } + return &results[0], nil +} + +// PublishWithResultAsync is the same as PublishWithResult except instead of blocking it calls onAck +// with the result or error. Note onAck must not block as it would block the internal client. +func (c *RealtimeChannel) PublishWithResultAsync(name string, data interface{}, onAck func(*PublishResult, error)) error { + return c.PublishMultipleWithResultAsync([]*Message{{Name: name, Data: data}}, func(results []PublishResult, err error) { + if err != nil { + onAck(nil, err) + return + } + if len(results) == 0 { + onAck(&PublishResult{}, nil) + return + } + onAck(&results[0], nil) + }) +} + +// PublishMultipleWithResult publishes all given messages on the channel and returns the serials +// assigned by the server (RTL6j, RTL6j1 alternative). +func (c *RealtimeChannel) PublishMultipleWithResult(ctx context.Context, messages []*Message) ([]PublishResult, error) { + type resultOrError struct { + results []PublishResult + err error + } + listen := make(chan resultOrError, 1) + onAck := func(results []PublishResult, err error) { + listen <- resultOrError{results, err} + } + if err := c.PublishMultipleWithResultAsync(messages, onAck); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-listen: + return result.results, result.err + } +} + +// PublishMultipleWithResultAsync is the same as PublishMultipleWithResult except it calls onAck +// instead of blocking (RTL6j, RTL6j1 alternative). +func (c *RealtimeChannel) PublishMultipleWithResultAsync(messages []*Message, onAck func([]PublishResult, error)) error { + id := c.client.Auth.clientIDForCheck() + for _, v := range messages { + if v.ClientID != "" && id != wildcardClientID && v.ClientID != id { + // Spec RTL6g3,RTL6g4 + return fmt.Errorf("Unable to publish message containing a clientId (%s) that is incompatible with the library clientId (%s)", v.ClientID, id) + } + } + msg := &protocolMessage{ + Action: actionMessage, + Channel: c.Name, + Messages: messages, + } + return c.sendWithSerialCallback(msg, func(serials []string, err error) { + if err != nil { + onAck(nil, err) + return + } + results := make([]PublishResult, len(messages)) + for i := range results { + if i < len(serials) { + serial := serials[i] + results[i].Serial = &serial + } + } + onAck(results, nil) + }) } // History retrieves a [ably.HistoryRequest] object, containing an array of historical @@ -768,8 +853,128 @@ func (c *RealtimeChannel) HistoryUntilAttach(o ...HistoryOption) (*HistoryReques return &historyRequest, nil } -func (c *RealtimeChannel) send(msg *protocolMessage, onAck func(err error)) error { - if enqueued := c.maybeEnqueue(msg, onAck); enqueued { +// performMessageOperationAsync is a shared helper for UpdateMessageAsync, DeleteMessageAsync, and AppendMessageAsync. +// Implements RTL32a-e: validates serial, applies options, sets action/version, encodes data, and sends the protocol message. +func (c *RealtimeChannel) performMessageOperationAsync(msg *Message, action MessageAction, onAck func(*UpdateDeleteResult, error), options ...UpdateOption) error { + if err := validateMessageSerial(msg); err != nil { + return err + } + + id := c.client.Auth.clientIDForCheck() + if msg.ClientID != "" && id != wildcardClientID && msg.ClientID != id { + // Spec RTL6g3,RTL6g4 + return fmt.Errorf("Unable to publish message containing a clientId (%s) that is incompatible with the library clientId (%s)", msg.ClientID, id) + } + + // Apply options + var opts updateOptions + for _, o := range options { + o(&opts) + } + + // RTL32c: Copy message to avoid mutating user-supplied object. + opMsg := *msg + // RTL32b1: Set action for the operation. + opMsg.Action = action + // RTL32b2: Set version from user-supplied operation metadata. + opMsg.Version = opts.version + + // RTL32d (encoding): Encode message data per RSL4/RSC8. + cipher, _ := (*protoChannelOptions)(c.options).GetCipher() + opMsg, err := opMsg.withEncodedData(cipher) + if err != nil { + return fmt.Errorf("encoding data for message: %w", err) + } + + // RTL32b: Send MESSAGE ProtocolMessage with single Message. + protoMsg := &protocolMessage{ + Action: actionMessage, + Channel: c.Name, + Messages: []*Message{&opMsg}, + Params: opts.params, // RTL32e: Pass params in ProtocolMessage.params. + } + + return c.sendWithSerialCallback(protoMsg, func(serials []string, err error) { + if err != nil { + onAck(nil, err) + return + } + // RTL32d: Extract versionSerial from first element of ACK serials. + result := &UpdateDeleteResult{} + if len(serials) > 0 { + serial := serials[0] + result.VersionSerial = &serial + } + onAck(result, nil) + }) +} + +// performMessageOperation is a shared blocking helper for UpdateMessage, DeleteMessage, and AppendMessage. +// It wraps performMessageOperationAsync with a channel-based blocking pattern. +func (c *RealtimeChannel) performMessageOperation(ctx context.Context, msg *Message, action MessageAction, options ...UpdateOption) (*UpdateDeleteResult, error) { + type resultOrError struct { + result *UpdateDeleteResult + err error + } + listen := make(chan resultOrError, 1) + onAck := func(result *UpdateDeleteResult, err error) { + listen <- resultOrError{result, err} + } + if err := c.performMessageOperationAsync(msg, action, onAck, options...); err != nil { + return nil, err + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-listen: + return result.result, result.err + } +} + +// UpdateMessage updates a previously published message. +func (c *RealtimeChannel) UpdateMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionUpdate, options...) +} + +// UpdateMessageAsync is the same as UpdateMessage except instead of blocking it calls onAck. +func (c *RealtimeChannel) UpdateMessageAsync(msg *Message, onAck func(*UpdateDeleteResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionUpdate, onAck, options...) +} + +// DeleteMessage deletes a previously published message. +func (c *RealtimeChannel) DeleteMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionDelete, options...) +} + +// DeleteMessageAsync is the same as DeleteMessage except instead of blocking it calls onAck. +func (c *RealtimeChannel) DeleteMessageAsync(msg *Message, onAck func(*UpdateDeleteResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionDelete, onAck, options...) +} + +// AppendMessage appends to a previously published message. +func (c *RealtimeChannel) AppendMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionAppend, options...) +} + +// AppendMessageAsync is the same as AppendMessage except instead of blocking it calls onAck. +// This is critical for AI token streaming use cases where rapid appends should not block. +func (c *RealtimeChannel) AppendMessageAsync(msg *Message, onAck func(*UpdateDeleteResult, error), options ...UpdateOption) error { + return c.performMessageOperationAsync(msg, MessageActionAppend, onAck, options...) +} + +// GetMessage retrieves a message by serial, delegates to REST (RTL28, RSL11). +func (c *RealtimeChannel) GetMessage(ctx context.Context, serial string) (*Message, error) { + return c.client.rest.Channels.Get(c.Name).GetMessage(ctx, serial) +} + +// GetMessageVersions retrieves the version history of a message by serial, delegates to REST (RTL31, RSL14). +func (c *RealtimeChannel) GetMessageVersions(serial string, params url.Values) HistoryRequest { + return c.client.rest.Channels.Get(c.Name).GetMessageVersions(serial, params) +} + +func (c *RealtimeChannel) send(msg *protocolMessage, ackCallback *msgAckCallback) error { + if enqueued := c.maybeEnqueue(msg, ackCallback); enqueued { return nil } @@ -777,11 +982,16 @@ func (c *RealtimeChannel) send(msg *protocolMessage, onAck func(err error)) erro return newError(ErrChannelOperationFailedInvalidChannelState, nil) } - c.client.Connection.send(msg, onAck) + c.client.Connection.send(msg, ackCallback) return nil } -func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, onAck func(err error)) bool { +// sendWithSerialCallback sends a message and calls onAck with serials extracted from ACK. +func (c *RealtimeChannel) sendWithSerialCallback(msg *protocolMessage, onAck func(serials []string, err error)) error { + return c.send(msg, &msgAckCallback{onAckWithSerials: onAck}) +} + +func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, ackCallback *msgAckCallback) bool { // RTL6c2 if c.opts().NoQueueing { return false @@ -803,7 +1013,7 @@ func (c *RealtimeChannel) maybeEnqueue(msg *protocolMessage, onAck func(err erro ChannelStateDetaching: } - c.queue.Enqueue(msg, onAck) + c.queue.Enqueue(msg, ackCallback) return true } diff --git a/ably/realtime_conn.go b/ably/realtime_conn.go index b913ab2e4..2dc7b385f 100644 --- a/ably/realtime_conn.go +++ b/ably/realtime_conn.go @@ -614,35 +614,36 @@ func (c *Connection) advanceSerial() { c.msgSerial = (c.msgSerial + 1) % maxint64 } -func (c *Connection) send(msg *protocolMessage, onAck func(err error)) { +// send sends a message with an ackCallback. +func (c *Connection) send(msg *protocolMessage, ackCallback *msgAckCallback) { hasMsgSerial := msg.Action == actionMessage || msg.Action == actionPresence || msg.Action == actionObject c.mtx.Lock() // RTP16a - in case of presence msg send, check for connection status and send accordingly switch state := c.state; state { default: c.mtx.Unlock() - if onAck != nil { + if ackCallback != nil { if c.state == ConnectionStateClosed { - onAck(errClosed) + ackCallback.call(nil, errClosed) } else { - onAck(connStateError(state, nil)) + ackCallback.call(nil, connStateError(state, nil)) } } case ConnectionStateInitialized, ConnectionStateConnecting, ConnectionStateDisconnected: c.mtx.Unlock() if c.opts.NoQueueing { - if onAck != nil { - onAck(connStateError(state, errQueueing)) + if ackCallback != nil { + ackCallback.call(nil, connStateError(state, errQueueing)) } } else { - c.queue.Enqueue(msg, onAck) // RTL4i + c.queue.Enqueue(msg, ackCallback) // RTL4i } case ConnectionStateConnected: if err := c.verifyAndUpdateMessages(msg); err != nil { c.mtx.Unlock() - if onAck != nil { - onAck(err) + if ackCallback != nil { + ackCallback.call(nil, err) } return } @@ -660,13 +661,13 @@ func (c *Connection) send(msg *protocolMessage, onAck func(err error)) { c.log().Warnf("transport level failure while sending message, %v", err) c.conn.Close() c.mtx.Unlock() - c.queue.Enqueue(msg, onAck) + c.queue.Enqueue(msg, ackCallback) } else { if hasMsgSerial { c.advanceSerial() } - if onAck != nil { - c.pending.Enqueue(msg, onAck) + if ackCallback != nil { + c.pending.Enqueue(msg, ackCallback) } c.mtx.Unlock() } @@ -760,7 +761,7 @@ func (c *Connection) resendPending() { c.mtx.Unlock() c.log().Debugf("resending %d messages waiting for ACK/NACK", len(cx)) for _, v := range cx { - c.send(v.msg, v.onAck) + c.send(v.msg, v.ackCallback) } } diff --git a/ably/realtime_experimental_objects.go b/ably/realtime_experimental_objects.go index b40ff5a1e..98ab0bdd3 100644 --- a/ably/realtime_experimental_objects.go +++ b/ably/realtime_experimental_objects.go @@ -50,7 +50,7 @@ func (o *RealtimeExperimentalObjects) PublishObjects(ctx context.Context, msgs . Channel: o.channel.getName(), State: msgs, } - if err := o.channel.send(msg, onAck); err != nil { + if err := o.channel.send(msg, &msgAckCallback{onAck: onAck}); err != nil { return err } @@ -63,7 +63,7 @@ func (o *RealtimeExperimentalObjects) PublishObjects(ctx context.Context, msgs . } type channel interface { - send(msg *protocolMessage, onAck func(error)) error + send(msg *protocolMessage, ackCallback *msgAckCallback) error getClientOptions() *clientOptions getName() string } diff --git a/ably/realtime_experimental_objects_test.go b/ably/realtime_experimental_objects_test.go index 21d1eaabc..8f3eaa4d5 100644 --- a/ably/realtime_experimental_objects_test.go +++ b/ably/realtime_experimental_objects_test.go @@ -167,14 +167,14 @@ func TestRealtimeExperimentalObjects_PublishObjects(t *testing.T) { // Create channel mock channelMock := &channelMock{ - SendFunc: func(msg *protocolMessage, onAck func(error)) error { + SendFunc: func(msg *protocolMessage, ackCallback *msgAckCallback) error { if tt.sendError != nil { return tt.sendError } capturedProtocolMsg = msg // Simulate async ack go func() { - onAck(tt.ackError) + ackCallback.call(nil, tt.ackError) }() return nil }, @@ -234,8 +234,8 @@ func TestRealtimeExperimentalObjects_PublishObjects(t *testing.T) { func TestRealtimeExperimentalObjects_PublishObjectsContextCancellation(t *testing.T) { // Test context cancellation during publish channelMock := &channelMock{ - SendFunc: func(msg *protocolMessage, onAck func(error)) error { - // Don't call onAck to simulate hanging + SendFunc: func(msg *protocolMessage, ackCallback *msgAckCallback) error { + // Don't call ackCallback to simulate hanging return nil }, GetClientOptionsFunc: func() *clientOptions { @@ -275,13 +275,13 @@ func TestRealtimeExperimentalObjects_PublishObjectsContextCancellation(t *testin // channelMock implements the channel interface type channelMock struct { - SendFunc func(msg *protocolMessage, onAck func(error)) error + SendFunc func(msg *protocolMessage, ackCallback *msgAckCallback) error GetClientOptionsFunc func() *clientOptions GetNameFunc func() string } -func (c channelMock) send(msg *protocolMessage, onAck func(error)) error { - return c.SendFunc(msg, onAck) +func (c channelMock) send(msg *protocolMessage, ackCallback *msgAckCallback) error { + return c.SendFunc(msg, ackCallback) } func (c channelMock) getClientOptions() *clientOptions { diff --git a/ably/realtime_presence.go b/ably/realtime_presence.go index e93192bd4..01b08a526 100644 --- a/ably/realtime_presence.go +++ b/ably/realtime_presence.go @@ -73,14 +73,14 @@ func (pres *RealtimePresence) onChannelSuspended(err error) { pres.queue.Fail(err) } -func (pres *RealtimePresence) maybeEnqueue(msg *protocolMessage, onAck func(err error)) bool { +func (pres *RealtimePresence) maybeEnqueue(msg *protocolMessage, ackCallback *msgAckCallback) bool { if pres.channel.opts().NoQueueing { - if onAck != nil { - onAck(errors.New("unable enqueue message because Options.QueueMessages is set to false")) + if ackCallback != nil { + ackCallback.call(nil, errors.New("unable enqueue message because Options.QueueMessages is set to false")) } return false } - pres.queue.Enqueue(msg, onAck) + pres.queue.Enqueue(msg, ackCallback) return true } @@ -98,15 +98,16 @@ func (pres *RealtimePresence) send(msg *PresenceMessage) (result, error) { onAck := func(err error) { listen <- err } + ackCallback := &msgAckCallback{onAck: onAck} switch pres.channel.State() { case ChannelStateInitialized: // RTP16b - if pres.maybeEnqueue(protomsg, onAck) { + if pres.maybeEnqueue(protomsg, ackCallback) { pres.channel.attach() } case ChannelStateAttaching: // RTP16b - pres.maybeEnqueue(protomsg, onAck) + pres.maybeEnqueue(protomsg, ackCallback) case ChannelStateAttached: // RTP16a - pres.channel.client.Connection.send(protomsg, onAck) // RTP16a, RTL6c + pres.channel.client.Connection.send(protomsg, ackCallback) // RTP16a, RTL6c } return resultFunc(func(ctx context.Context) error { diff --git a/ably/rest_channel.go b/ably/rest_channel.go index ffd55804d..b58ef2a3c 100644 --- a/ably/rest_channel.go +++ b/ably/rest_channel.go @@ -3,6 +3,7 @@ package ably import ( "context" "encoding/json" + "errors" "fmt" "net/url" "strconv" @@ -64,6 +65,22 @@ type publishMultipleOptions struct { params map[string]string } +// publishResponse represents the response from the server after publishing messages. +type publishResponse struct { + Serials []string `json:"serials,omitempty" codec:"serials,omitempty"` +} + +// validateMessageSerial validates that a message has a serial for update operations. +func validateMessageSerial(msg *Message) error { + if msg == nil { + return newError(40003, fmt.Errorf("message cannot be nil")) + } + if msg.Serial == "" { + return newError(40003, fmt.Errorf("this message lacks a serial and cannot be updated. Make sure you have enabled \"Message annotations, updates, and deletes\" in channel settings on your dashboard")) + } + return nil +} + // PublishWithConnectionKey allows a message to be published for a specified connectionKey. func PublishWithConnectionKey(connectionKey string) PublishMultipleOption { return func(options *publishMultipleOptions) { @@ -86,8 +103,9 @@ func PublishMultipleWithParams(params map[string]string) PublishMultipleOption { return PublishWithParams(params) } -// PublishMultiple publishes multiple messages in a batch. Returns error if there is a problem publishing message (RSL1). -func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, options ...PublishMultipleOption) error { +// publishMultiple is the internal implementation for publishing multiple messages. +// If out is non-nil, the response body will be decoded into it. +func (c *RESTChannel) publishMultiple(ctx context.Context, messages []*Message, out interface{}, options ...PublishMultipleOption) error { var publishOpts publishMultipleOptions for _, o := range options { o(&publishOpts) @@ -146,13 +164,18 @@ func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, } } - res, err := c.client.post(ctx, c.baseURL+"/messages"+query, messages, nil) + res, err := c.client.post(ctx, c.baseURL+"/messages"+query, messages, out) if err != nil { return err } return res.Body.Close() } +// PublishMultiple publishes multiple messages in a batch. Returns error if there is a problem publishing message (RSL1). +func (c *RESTChannel) PublishMultiple(ctx context.Context, messages []*Message, options ...PublishMultipleOption) error { + return c.publishMultiple(ctx, messages, nil, options...) +} + // PublishMultipleWithOptions is the same as PublishMultiple. // // Deprecated: Use PublishMultiple instead. @@ -162,6 +185,112 @@ func (c *RESTChannel) PublishMultipleWithOptions(ctx context.Context, messages [ return c.PublishMultiple(ctx, messages, options...) } +// PublishWithResult publishes a single message to the channel with the given event name and payload, +// and returns the serial assigned by the server (RSL1n, RSL1n1 alternative). +// Returns error if there is a problem performing message publish. +func (c *RESTChannel) PublishWithResult(ctx context.Context, name string, data interface{}, options ...PublishMultipleOption) (*PublishResult, error) { + results, err := c.PublishMultipleWithResult(ctx, []*Message{{Name: name, Data: data}}, options...) + if err != nil { + return nil, err + } + if len(results) == 0 { + return &PublishResult{}, nil + } + return &results[0], nil +} + +// PublishMultipleWithResult publishes multiple messages in a batch and returns the serials +// assigned by the server (RSL1n, RSL1n1 alternative). +// Returns error if there is a problem publishing messages. +func (c *RESTChannel) PublishMultipleWithResult(ctx context.Context, messages []*Message, options ...PublishMultipleOption) ([]PublishResult, error) { + var response publishResponse + if err := c.publishMultiple(ctx, messages, &response, options...); err != nil { + return nil, err + } + + // Build results from serials + results := make([]PublishResult, len(messages)) + for i := range results { + if i < len(response.Serials) { + serial := response.Serials[i] + results[i].Serial = &serial + } + } + return results, nil +} + +// updateDeleteResult is the wire format for the response from update/delete/append operations (RSL15e, UDR2a). +type updateDeleteResult struct { + VersionSerial *string `json:"versionSerial,omitempty" codec:"versionSerial,omitempty"` +} + +// performMessageOperation is a shared helper for UpdateMessage, DeleteMessage, and AppendMessage. +// It validates the message serial, applies update options, sets the action, encodes data, and sends the request. +// Uses PATCH /channels/{name}/messages/{serial} per RSL15b with a single Message body (not an array). +func (c *RESTChannel) performMessageOperation(ctx context.Context, msg *Message, action MessageAction, options ...UpdateOption) (*UpdateDeleteResult, error) { + if err := validateMessageSerial(msg); err != nil { + return nil, err + } + + // Apply options + var opts updateOptions + for _, o := range options { + o(&opts) + } + + // RSL15c: Copy message to avoid mutating user-supplied object. + opMsg := *msg + // RSL15b1: Set action for the operation. + opMsg.Action = action + // RSL15b7: Set version from user-supplied operation metadata. + opMsg.Version = opts.version + + // RSL15d: Encode message data per RSL4/RSC8. + cipher, _ := c.options.GetCipher() + var err error + opMsg, err = opMsg.withEncodedData(cipher) + if err != nil { + return nil, fmt.Errorf("encoding data for message: %w", err) + } + + // Build URL: PATCH /channels/{name}/messages/{serial} per RSL15b + path := c.baseURL + "/messages/" + url.PathEscape(msg.Serial) + + // Append query params if provided (RSL15a, RSL15f) + if len(opts.params) > 0 { + queryParams := url.Values{} + for k, v := range opts.params { + queryParams.Set(k, v) + } + path += "?" + queryParams.Encode() + } + + // PATCH with single Message body (not an array), parse UpdateDeleteResult (RSL15e) + var response updateDeleteResult + res, err := c.client.patch(ctx, path, &opMsg, &response) + if err != nil { + return nil, err + } + defer res.Body.Close() + + return &UpdateDeleteResult{VersionSerial: response.VersionSerial}, nil +} + +// UpdateMessage updates a previously published message. +func (c *RESTChannel) UpdateMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionUpdate, options...) +} + +// DeleteMessage deletes a previously published message. +func (c *RESTChannel) DeleteMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionDelete, options...) +} + +// AppendMessage appends to a previously published message. +func (c *RESTChannel) AppendMessage(ctx context.Context, msg *Message, options ...UpdateOption) (*UpdateDeleteResult, error) { + return c.performMessageOperation(ctx, msg, MessageActionAppend, options...) +} + // ChannelDetails contains the details of a [ably.RESTChannel] or [ably.RealtimeChannel] object // such as its ID and [ably.ChannelStatus]. type ChannelDetails struct { @@ -300,12 +429,18 @@ func (o *historyOptions) apply(opts ...HistoryOption) url.Values { type HistoryRequest struct { r paginatedRequest channel *RESTChannel + // err is a validation error set at construction time. If non-nil, Pages and + // Items return it immediately without making any requests. + err error } // Pages returns an iterator for whole pages of History. // // See package-level documentation => [ably] Pagination for details about history pagination. func (r HistoryRequest) Pages(ctx context.Context) (*MessagesPaginatedResult, error) { + if r.err != nil { + return nil, r.err + } var res MessagesPaginatedResult return &res, res.load(ctx, r.r) } @@ -352,6 +487,9 @@ func (p *MessagesPaginatedResult) Items() []*Message { // // See package-level documentation => [ably] Pagination for details about history pagination. func (r HistoryRequest) Items(ctx context.Context) (*MessagesPaginatedItems, error) { + if r.err != nil { + return nil, r.err + } var res MessagesPaginatedItems var err error res.next, err = res.loadItems(ctx, r.r, func() (interface{}, func() int) { @@ -370,11 +508,20 @@ func (c *RESTChannel) fullMessagesDecoder(dst *[]*Message) interface{} { return &fullMessagesDecoder{dst: dst, c: c} } +func (c *RESTChannel) fullMessageDecoder(dst *Message) interface{} { + return &fullMessageDecoder{dst: dst, c: c} +} + type fullMessagesDecoder struct { dst *[]*Message c *RESTChannel } +type fullMessageDecoder struct { + dst *Message + c *RESTChannel +} + func (t *fullMessagesDecoder) UnmarshalJSON(b []byte) error { err := json.Unmarshal(b, &t.dst) if err != nil { @@ -405,11 +552,43 @@ func (t *fullMessagesDecoder) decodeMessagesData() { *m, err = m.withDecodedData(cipher) if err != nil { // RSL6b - t.c.log().Errorf("Couldn't fully decode message data from channel %q: %w", t.c.Name, err) + t.c.log().Errorf("Couldn't fully decode message data from channel %q: %v", t.c.Name, err) } } } +func (t *fullMessageDecoder) UnmarshalJSON(b []byte) error { + err := json.Unmarshal(b, t.dst) + if err != nil { + return err + } + t.decodeMessageData() + return nil +} + +func (t *fullMessageDecoder) CodecEncodeSelf(*codec.Encoder) { + panic("fullMessageDecoder cannot be used as encoder") +} + +func (t *fullMessageDecoder) CodecDecodeSelf(decoder *codec.Decoder) { + decoder.MustDecode(t.dst) + t.decodeMessageData() +} + +var _ interface { + json.Unmarshaler + codec.Selfer +} = (*fullMessageDecoder)(nil) + +func (t *fullMessageDecoder) decodeMessageData() { + cipher, _ := t.c.options.GetCipher() + var err error + *t.dst, err = t.dst.withDecodedData(cipher) + if err != nil { + t.c.log().Errorf("Couldn't fully decode message data from channel %q: %v", t.c.Name, err) + } +} + type MessagesPaginatedItems struct { PaginatedResult items []*Message @@ -436,6 +615,38 @@ func (p *MessagesPaginatedItems) Item() *Message { return p.item } +// GetMessage retrieves a message by its serial. +func (c *RESTChannel) GetMessage(ctx context.Context, serial string) (*Message, error) { + if serial == "" { + return nil, newError(40003, errors.New("serial is required to retrieve a message")) + } + var message Message + req := &request{ + Method: "GET", + Path: c.baseURL + "/messages/" + url.PathEscape(serial), + Out: c.fullMessageDecoder(&message), + } + _, err := c.client.do(ctx, req) + if err != nil { + return nil, err + } + return &message, nil +} + +// GetMessageVersions retrieves the version history of a message by its serial. +// Returns a HistoryRequest that can be used to paginate through message versions. +func (c *RESTChannel) GetMessageVersions(serial string, params url.Values) HistoryRequest { + if serial == "" { + return HistoryRequest{err: newError(40003, errors.New("serial is required to retrieve message versions"))} + } + path := c.baseURL + "/messages/" + url.PathEscape(serial) + "/versions" + rawPath := "/channels/" + c.pathName() + "/messages/" + url.PathEscape(serial) + "/versions" + return HistoryRequest{ + r: c.client.newPaginatedRequest(path, rawPath, params), + channel: c, + } +} + func (c *RESTChannel) log() logger { return c.client.log } diff --git a/ably/rest_client.go b/ably/rest_client.go index 233e79a4f..29ce07a5c 100644 --- a/ably/rest_client.go +++ b/ably/rest_client.go @@ -188,9 +188,26 @@ func (c *REST) Time(ctx context.Context) (time.Time, error) { // [ably.PaginatedResult] object, containing an array of [Stats]{@link Stats} objects (RSC6a). // // See package-level documentation => [ably] Pagination for handling stats pagination. +// +// Note: Stats requests use protocol version 2 to maintain compatibility with the existing +// nested Stats structure. Migrating to the flattened protocol v3+ stats format is planned +// for ably-go v2 as it requires breaking API changes. func (c *REST) Stats(o ...StatsOption) StatsRequest { params := (&statsOptions{}).apply(o...) - return StatsRequest{r: c.newPaginatedRequest("/stats", "", params)} + + // Use protocol v2 for stats to maintain compatibility with existing Stats structure. + // Protocol v3+ uses a flattened format that would require breaking API changes. + statsHeader := make(http.Header) + statsHeader.Set(ablyProtocolVersionHeader, "2") + + req := c.newPaginatedRequest("/stats", "", params) + // Override the query function to use getWithHeader with protocol v2 + req.query = func(ctx context.Context, path string) (*http.Response, error) { + // Pass nil for out because pagination decodes the response separately + return c.getWithHeader(ctx, path, nil, statsHeader) + } + + return StatsRequest{r: req} } func (c *REST) setActiveRealtimeHost(realtimeHost string) { @@ -622,6 +639,17 @@ func (c *REST) get(ctx context.Context, path string, out interface{}) (*http.Res return c.do(ctx, r) } +// getWithHeader is like get but allows specifying custom HTTP headers. +func (c *REST) getWithHeader(ctx context.Context, path string, out interface{}, header http.Header) (*http.Response, error) { + r := &request{ + Method: "GET", + Path: path, + Out: out, + header: header, + } + return c.do(ctx, r) +} + func (c *REST) post(ctx context.Context, path string, in, out interface{}) (*http.Response, error) { r := &request{ Method: "POST", @@ -632,6 +660,16 @@ func (c *REST) post(ctx context.Context, path string, in, out interface{}) (*htt return c.do(ctx, r) } +func (c *REST) patch(ctx context.Context, path string, in, out interface{}) (*http.Response, error) { + r := &request{ + Method: "PATCH", + Path: path, + In: in, + Out: out, + } + return c.do(ctx, r) +} + func (c *REST) do(ctx context.Context, r *request) (*http.Response, error) { return c.doWithHandle(ctx, r, c.handleResponse) } @@ -784,8 +822,11 @@ func (c *REST) newHTTPRequest(ctx context.Context, r *request) (*http.Request, e if r.header != nil { copyHeader(req.Header, r.header) } - req.Header.Set("Accept", protocol) // RSC19c - req.Header.Set(ablyProtocolVersionHeader, ablyProtocolVersion) // RSC7a + req.Header.Set("Accept", protocol) // RSC19c + // RSC7a - Only set protocol version if not already set by custom header + if req.Header.Get(ablyProtocolVersionHeader) == "" { + req.Header.Set(ablyProtocolVersionHeader, ablyProtocolVersion) + } req.Header.Set(ablyAgentHeader, ablyAgentIdentifier(c.opts.Agents)) // RSC7d if c.opts.ClientID != "" && c.Auth.method == authBasic { // References RSA7e2 diff --git a/ably/state.go b/ably/state.go index a0e922274..c3c5bf916 100644 --- a/ably/state.go +++ b/ably/state.go @@ -125,14 +125,14 @@ func (q *pendingEmitter) Dismiss() []msgWithAckCallback { return cx } -func (q *pendingEmitter) Enqueue(msg *protocolMessage, onAck func(err error)) { +func (q *pendingEmitter) Enqueue(msg *protocolMessage, ackCallback *msgAckCallback) { if len(q.queue) > 0 { expected := q.queue[len(q.queue)-1].msg.MsgSerial + 1 if got := msg.MsgSerial; expected != got { panic(fmt.Sprintf("protocol violation: expected next enqueued message to have msgSerial %d; got %d", expected, got)) } } - q.queue = append(q.queue, msgWithAckCallback{msg, onAck}) + q.queue = append(q.queue, msgWithAckCallback{msg, ackCallback}) } func (q *pendingEmitter) Ack(msg *protocolMessage, errInfo *ErrorInfo) { @@ -180,15 +180,48 @@ func (q *pendingEmitter) Ack(msg *protocolMessage, errInfo *ErrorInfo) { err = errImplictNACK } q.log.Verbosef("received %v for message serial %d", msg.Action, sch.msg.MsgSerial) - if sch.onAck != nil { - sch.onAck(err) + + // Extract the corresponding result for this message from the res array. + // The res array contains results only for messages that were actually ACKed, + // not for implicitly NACKed messages (those with i < serialShift). + var result []*protocolPublishResult + resIndex := i - serialShift + if msg.Res != nil && resIndex >= 0 && resIndex < len(msg.Res) { + result = []*protocolPublishResult{msg.Res[resIndex]} } + sch.ackCallback.call(result, err) + } +} + +type msgAckCallback struct { + onAck func(err error) + onAckWithSerials func(serials []string, err error) +} + +// call invokes the appropriate ackCallback based on which is set. +// If onAckWithSerials is set, extracts serials from res before calling. +// If onAck is set, calls it with just the error. +func (cb *msgAckCallback) call(res []*protocolPublishResult, err error) { + if cb == nil { + return + } + if cb.onAckWithSerials != nil { + // Extract serials from results + var serials []string + for _, result := range res { + if result != nil && len(result.Serials) > 0 { + serials = append(serials, result.Serials...) + } + } + cb.onAckWithSerials(serials, err) + } else if cb.onAck != nil { + cb.onAck(err) } } type msgWithAckCallback struct { - msg *protocolMessage - onAck func(err error) + msg *protocolMessage + ackCallback *msgAckCallback } type msgQueue struct { @@ -203,17 +236,17 @@ func newMsgQueue(conn *Connection) *msgQueue { } } -func (q *msgQueue) Enqueue(msg *protocolMessage, onAck func(err error)) { +func (q *msgQueue) Enqueue(msg *protocolMessage, ackCallback *msgAckCallback) { q.mtx.Lock() // TODO(rjeczalik): reorder the queue so Presence / Messages can be merged - q.queue = append(q.queue, msgWithAckCallback{msg, onAck}) + q.queue = append(q.queue, msgWithAckCallback{msg, ackCallback}) q.mtx.Unlock() } func (q *msgQueue) Flush() { q.mtx.Lock() for _, queueMsg := range q.queue { - q.conn.send(queueMsg.msg, queueMsg.onAck) + q.conn.send(queueMsg.msg, queueMsg.ackCallback) } q.queue = nil q.mtx.Unlock() @@ -223,9 +256,7 @@ func (q *msgQueue) Fail(err error) { q.mtx.Lock() for _, queueMsg := range q.queue { q.log().Errorf("failure sending message (serial=%d): %v", queueMsg.msg.MsgSerial, err) - if queueMsg.onAck != nil { - queueMsg.onAck(newError(90000, err)) - } + queueMsg.ackCallback.call(nil, newError(90000, err)) } q.queue = nil q.mtx.Unlock() diff --git a/ably/state_test.go b/ably/state_test.go new file mode 100644 index 000000000..5a4f42aee --- /dev/null +++ b/ably/state_test.go @@ -0,0 +1,162 @@ +//go:build !integration +// +build !integration + +package ably + +import ( + "io" + "log" + "testing" + + "github.com/stretchr/testify/assert" +) + +// This test verifies that when an ACK message contains a "res" array with multiple results, +// each result is correctly associated with its corresponding message based on msgSerial. +// +// Scenario: SDK sends multiple messages with consecutive msgSerials +// Server responds with ACK: msgSerial=X, count=N, res=[result1, result2, ..., resultN] +// Expected behavior: Each message should receive serials from only its corresponding res element +func TestPendingEmitter_AckResult(t *testing.T) { + t.Run("two messages with single serial each", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + // Track what serials each message receives + var msg1Serials, msg2Serials []string + + // Create two protocol messages with consecutive msgSerials + protoMsg1 := &protocolMessage{ + MsgSerial: 5, + Action: actionMessage, + Channel: "test-channel", + } + callback1 := &msgAckCallback{ + onAckWithSerials: func(serials []string, err error) { + msg1Serials = serials + }, + } + + protoMsg2 := &protocolMessage{ + MsgSerial: 6, + Action: actionMessage, + Channel: "test-channel", + } + callback2 := &msgAckCallback{ + onAckWithSerials: func(serials []string, err error) { + msg2Serials = serials + }, + } + + // Enqueue both messages + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + + // Simulate receiving an ACK with msgSerial=5, count=2, and res array with two distinct results + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 5, + Count: 2, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-for-msg-5"}}, // Should only go to message 1 + {Serials: []string{"serial-for-msg-6"}}, // Should only go to message 2 + }, + } + + // Process the ACK + emitter.Ack(ackMsg, nil) + + // Verify each message received only its corresponding serials + assert.Equal(t, []string{"serial-for-msg-5"}, msg1Serials, + "Message 1 (msgSerial=5) should only receive serials from res[0]") + assert.Equal(t, []string{"serial-for-msg-6"}, msg2Serials, + "Message 2 (msgSerial=6) should only receive serials from res[1]") + }) + + t.Run("two messages with multiple serials each", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + var msg1Serials, msg2Serials []string + + protoMsg1 := &protocolMessage{ + MsgSerial: 10, + Action: actionMessage, + Channel: "test-channel", + } + callback1 := &msgAckCallback{ + onAckWithSerials: func(serials []string, err error) { + msg1Serials = serials + }, + } + + protoMsg2 := &protocolMessage{ + MsgSerial: 11, + Action: actionMessage, + Channel: "test-channel", + } + callback2 := &msgAckCallback{ + onAckWithSerials: func(serials []string, err error) { + msg2Serials = serials + }, + } + + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + + // ACK with multiple serials per result + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 10, + Count: 2, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-10-a", "serial-10-b"}}, // Both should go to message 1 + {Serials: []string{"serial-11-a", "serial-11-b"}}, // Both should go to message 2 + }, + } + + emitter.Ack(ackMsg, nil) + + assert.Equal(t, []string{"serial-10-a", "serial-10-b"}, msg1Serials, + "Message 1 (msgSerial=10) should receive both serials from res[0]") + assert.Equal(t, []string{"serial-11-a", "serial-11-b"}, msg2Serials, + "Message 2 (msgSerial=11) should receive both serials from res[1]") + }) + + t.Run("three messages", func(t *testing.T) { + testLogger := logger{l: &stdLogger{log.New(io.Discard, "", 0)}} + emitter := newPendingEmitter(testLogger) + + var msg1Serials, msg2Serials, msg3Serials []string + + protoMsg1 := &protocolMessage{MsgSerial: 1, Action: actionMessage, Channel: "test"} + callback1 := &msgAckCallback{onAckWithSerials: func(serials []string, err error) { msg1Serials = serials }} + + protoMsg2 := &protocolMessage{MsgSerial: 2, Action: actionMessage, Channel: "test"} + callback2 := &msgAckCallback{onAckWithSerials: func(serials []string, err error) { msg2Serials = serials }} + + protoMsg3 := &protocolMessage{MsgSerial: 3, Action: actionMessage, Channel: "test"} + callback3 := &msgAckCallback{onAckWithSerials: func(serials []string, err error) { msg3Serials = serials }} + + emitter.Enqueue(protoMsg1, callback1) + emitter.Enqueue(protoMsg2, callback2) + emitter.Enqueue(protoMsg3, callback3) + + ackMsg := &protocolMessage{ + Action: actionAck, + MsgSerial: 1, + Count: 3, + Res: []*protocolPublishResult{ + {Serials: []string{"serial-1"}}, + {Serials: []string{"serial-2"}}, + {Serials: []string{"serial-3"}}, + }, + } + + emitter.Ack(ackMsg, nil) + + assert.Equal(t, []string{"serial-1"}, msg1Serials, "Message 1 should receive serial-1") + assert.Equal(t, []string{"serial-2"}, msg2Serials, "Message 2 should receive serial-2") + assert.Equal(t, []string{"serial-3"}, msg3Serials, "Message 3 should receive serial-3") + }) +} diff --git a/ablytest/sandbox.go b/ablytest/sandbox.go index a2294bd3b..e8898bb26 100644 --- a/ablytest/sandbox.go +++ b/ablytest/sandbox.go @@ -34,10 +34,11 @@ type Key struct { } type Namespace struct { - ID string `json:"id"` - Created int `json:"created,omitempty"` - Modified int `json:"modified,omitempty"` - Persisted bool `json:"persisted,omitempty"` + ID string `json:"id"` + Created int `json:"created,omitempty"` + Modified int `json:"modified,omitempty"` + Persisted bool `json:"persisted,omitempty"` + MutableMessages bool `json:"mutableMessages,omitempty"` } type Presence struct { @@ -80,6 +81,7 @@ func DefaultConfig() *Config { }, Namespaces: []Namespace{ {ID: "persisted", Persisted: true}, + {ID: "mutable", MutableMessages: true}, }, Channels: []Channel{ {