Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 125 additions & 15 deletions src/bedrock_agentcore/memory/integrations/strands/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,36 +628,108 @@ def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs
Optional[SessionMessage]: The message if found, None otherwise.

Note:
This should not be called as (as of now) only the `update_message` method calls this method and
updating messages is not supported in AgentCore Memory.
This reads a single event by ID from AgentCore Memory.
"""
result = self.memory_client.gmdp_client.get_event(
memoryId=self.config.memory_id, actorId=self.config.actor_id, sessionId=session_id, eventId=message_id
)
return SessionMessage.from_dict(result) if result else None

def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any) -> None:
"""Update message data.
"""Update message data in AgentCore Memory.

Note: AgentCore Memory doesn't support updating events,
so this is primarily for validation and logging.
Since AgentCore Memory events are immutable, this method performs an update by
creating a new event with the updated content and deleting the old event.
Comment thread
notgitika marked this conversation as resolved.
This enables features like guardrail redaction via Strands' redact_latest_message().

If the message has not yet been persisted (e.g., still in the message buffer when
batch_size > 1), the buffered message is replaced in-place instead.

Args:
session_id (str): The session ID containing the message.
agent_id (str): The agent ID associated with the message.
session_message (SessionMessage): The message to update.
session_message (SessionMessage): The message to update (with updated content
and the original message_id/eventId).
**kwargs (Any): Additional keyword arguments.

Raises:
SessionException: If session ID doesn't match configuration.
SessionException: If session ID doesn't match configuration or update fails.
"""
if session_id != self.config.session_id:
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")

logger.debug(
"Message update requested for message: %s (AgentCore Memory doesn't support updates)",
{session_message.message_id},
)
old_message_id = session_message.message_id

# If message hasn't been persisted yet (still in buffer), update it there
if old_message_id is None:
if self._update_buffered_message(session_message):
logger.debug("Updated buffered message (not yet persisted to AgentCore Memory)")
return
logger.debug("Message has no event ID and was not found in buffer - skipping update")
return

# Create a new event with the updated message content
try:
updated_message = SessionMessage(
Comment thread
notgitika marked this conversation as resolved.
message=session_message.message,
message_id=0,
created_at=session_message.created_at,
)
new_event = self.create_message(session_id, agent_id, updated_message)
except Exception as e:
Comment thread
notgitika marked this conversation as resolved.
logger.error("Failed to update message in AgentCore Memory: %s", e)
raise SessionException(f"Failed to update message: {e}") from e

new_event_id = new_event.get("eventId") if new_event else None
if not new_event_id:
logger.warning("create_message did not return an eventId — skipping delete of old event %s", old_message_id)
return

# Delete the old event; if this fails, roll back the newly created event
try:
self.memory_client.gmdp_client.delete_event(
Comment thread
notgitika marked this conversation as resolved.
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
eventId=old_message_id,
)
except Exception as delete_error:
logger.warning(
"Failed to delete old event %s after creating replacement: %s. Attempting rollback.",
old_message_id,
delete_error,
)
try:
self.memory_client.gmdp_client.delete_event(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
eventId=new_event_id,
)
logger.info("Rolled back new event %s after failed delete of old event", new_event_id)
except Exception as rollback_error:
logger.error(
"Rollback failed: could not delete new event %s: %s. Both old (%s) and new events may exist.",
new_event_id,
rollback_error,
old_message_id,
)
raise SessionException(
f"Failed to update message: could not delete old event: {delete_error}"
) from delete_error

# Update _latest_agent_message so it doesn't hold a stale eventId
latest_messages = getattr(self, "_latest_agent_message", None)
if latest_messages and agent_id in latest_messages:
old_latest = self._latest_agent_message[agent_id]
if old_latest.message_id == old_message_id:
self._latest_agent_message[agent_id] = SessionMessage(
message=session_message.message,
message_id=new_event_id,
created_at=session_message.created_at,
)

logger.info("Updated message in AgentCore Memory: replaced event %s", old_message_id)

def list_messages(
self,
Expand Down Expand Up @@ -857,6 +929,44 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:

# region Batching support

def _update_buffered_message(self, session_message: SessionMessage) -> bool:
"""Attempt to update a message that is still in the send buffer.

When batch_size > 1, messages may not yet be persisted to AgentCore Memory.
This method finds the most recent buffered message matching the session_message's
content role and replaces it with the updated content.

Args:
session_message (SessionMessage): The message with updated content.

Comment thread
notgitika marked this conversation as resolved.
Returns:
bool: True if a buffered message was found and updated, False otherwise.
"""
updated_messages = self.converter.message_to_payload(session_message)
if not updated_messages:
return False

is_blob = self.converter.exceeds_conversational_limit(updated_messages[0])

with self._message_lock:
# Search from the end (most recent) to find the message to update
for i in range(len(self._message_buffer) - 1, -1, -1):
buf = self._message_buffer[i]
if buf.session_id == self.config.session_id and buf.messages:
# Match by role - the most recent message with the same role
existing_role = buf.messages[0][1] if not buf.is_blob else None
new_role = updated_messages[0][1] if not is_blob else None
if existing_role == new_role:
self._message_buffer[i] = BufferedMessage(
session_id=buf.session_id,
messages=updated_messages,
is_blob=is_blob,
timestamp=buf.timestamp,
metadata=buf.metadata,
)
return True
return False

def _flush_messages_only(self) -> list[dict[str, Any]]:
"""Flush only buffered messages to AgentCore Memory.

Expand All @@ -878,6 +988,7 @@ def _flush_messages_only(self) -> list[dict[str, Any]]:

with self._message_lock:
messages_to_send = list(self._message_buffer)
self._message_buffer.clear()

if not messages_to_send:
return []
Expand Down Expand Up @@ -930,11 +1041,10 @@ def _flush_messages_only(self) -> list[dict[str, Any]]:
event.get("eventId"),
)

# Clear message buffer only after ALL messages succeed
with self._message_lock:
self._message_buffer.clear()

except Exception as e:
# Restore messages to buffer so they aren't lost
with self._message_lock:
self._message_buffer.extend(messages_to_send)
logger.error("Failed to flush messages to AgentCore Memory: %s", e)
raise SessionException(f"Failed to flush messages: {e}") from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,20 +496,120 @@ def test_read_message_not_found(self, session_manager, mock_memory_client):

assert result is None

def test_update_message(self, session_manager):
"""Test updating a message."""
message = SessionMessage(message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1)
def test_update_message(self, session_manager, mock_memory_client):
"""Test updating a persisted message creates new event and deletes old one."""
mock_memory_client.create_event.return_value = {"eventId": "new_event_456"}

message = SessionMessage(
message={"role": "user", "content": [{"text": "redacted"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)

session_manager.update_message("test-session-456", "test-agent-123", message)

# Verify new event was created with correct content
mock_memory_client.create_event.assert_called_once()
create_kwargs = mock_memory_client.create_event.call_args.kwargs
assert "redacted" in str(create_kwargs["messages"])

Comment thread
notgitika marked this conversation as resolved.
# Verify old event was deleted
mock_memory_client.gmdp_client.delete_event.assert_called_once()
delete_kwargs = mock_memory_client.gmdp_client.delete_event.call_args.kwargs
assert delete_kwargs["eventId"] == "old_event_123"
assert delete_kwargs["memoryId"] == "test-memory-123"
assert delete_kwargs["actorId"] == "test-actor-789"
assert delete_kwargs["sessionId"] == "test-session-456"

def test_update_message_updates_latest_agent_message(self, session_manager, mock_memory_client):
"""Test that _latest_agent_message is updated with the new eventId after replacement."""
mock_memory_client.create_event.return_value = {"eventId": "new_event_456"}

# Initialize and pre-populate _latest_agent_message with the old event
session_manager._latest_agent_message = {}
session_manager._latest_agent_message["test-agent-123"] = SessionMessage(
message={"role": "assistant", "content": [{"text": "original"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)

message = SessionMessage(
message={"role": "assistant", "content": [{"text": "redacted"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)

# Should not raise any exceptions
session_manager.update_message("test-session-456", "test-agent-123", message)

assert session_manager._latest_agent_message["test-agent-123"].message_id == "new_event_456"

def test_update_message_wrong_session(self, session_manager):
"""Test updating a message with wrong session ID."""
message = SessionMessage(message={"role": "user", "content": [{"text": "Hello"}]}, message_id=1)
Comment thread
notgitika marked this conversation as resolved.

with pytest.raises(SessionException, match="Session ID mismatch"):
session_manager.update_message("wrong-session-id", "test-agent-123", message)

def test_update_message_no_message_id(self, session_manager):
"""Test updating a message with no message_id (not yet persisted) skips gracefully."""
message = SessionMessage(
message={"role": "user", "content": [{"text": "redacted"}]},
message_id=None,
created_at="2024-01-01T12:00:00Z",
)

# Should not raise - just skips since message isn't persisted and buffer is empty
session_manager.update_message("test-session-456", "test-agent-123", message)

def test_update_message_create_fails(self, session_manager, mock_memory_client):
"""Test update_message raises SessionException when create fails and does not delete."""
mock_memory_client.create_event.side_effect = Exception("API Error")

message = SessionMessage(
message={"role": "user", "content": [{"text": "redacted"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)
Comment thread
notgitika marked this conversation as resolved.

with pytest.raises(SessionException, match="Failed to update message"):
session_manager.update_message("test-session-456", "test-agent-123", message)

mock_memory_client.gmdp_client.delete_event.assert_not_called()

def test_update_message_delete_fails_rollback_succeeds(self, session_manager, mock_memory_client):
"""Test that when delete of old event fails, the new event is rolled back."""
mock_memory_client.create_event.return_value = {"eventId": "new_event_456"}
# First call (delete old) fails, second call (rollback new) succeeds
mock_memory_client.gmdp_client.delete_event.side_effect = [Exception("Delete failed"), None]

message = SessionMessage(
message={"role": "user", "content": [{"text": "redacted"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)

with pytest.raises(SessionException, match="Failed to update message"):
session_manager.update_message("test-session-456", "test-agent-123", message)

# Verify delete was called twice: once for old event, once for rollback of new event
assert mock_memory_client.gmdp_client.delete_event.call_count == 2
rollback_kwargs = mock_memory_client.gmdp_client.delete_event.call_args_list[1].kwargs
assert rollback_kwargs["eventId"] == "new_event_456"

def test_update_message_delete_fails_rollback_fails(self, session_manager, mock_memory_client):
"""Test that when both delete and rollback fail, exception is still raised."""
mock_memory_client.create_event.return_value = {"eventId": "new_event_456"}
mock_memory_client.gmdp_client.delete_event.side_effect = Exception("Delete failed")

Comment thread
notgitika marked this conversation as resolved.
message = SessionMessage(
message={"role": "user", "content": [{"text": "redacted"}]},
message_id="old_event_123",
created_at="2024-01-01T12:00:00Z",
)

with pytest.raises(SessionException, match="Failed to update message"):
session_manager.update_message("test-session-456", "test-agent-123", message)

def test_list_messages_with_limit(self, session_manager, mock_memory_client):
"""Test listing messages with limit."""
mock_memory_client.list_events.return_value = [
Expand Down Expand Up @@ -1366,6 +1466,35 @@ def test_pending_message_count_with_buffered_messages(self, batching_session_man
# Verify no events were sent (still buffered)
mock_memory_client.create_event.assert_not_called()

def test_update_buffered_message(self, batching_session_manager, mock_memory_client):
"""Test update_message replaces a buffered message in-place when message_id is None."""
# Add a user message to buffer
message = SessionMessage(
message={"role": "user", "content": [{"text": "offensive content"}]},
message_id=0,
created_at="2024-01-01T12:00:00Z",
)
batching_session_manager.create_message("test-session-456", "test-agent", message)
assert batching_session_manager.pending_message_count() == 1

# Update with redacted content (message_id=None simulates unbatched message)
redacted = SessionMessage(
message={"role": "user", "content": [{"text": "Message redacted by guardrail"}]},
message_id=None,
created_at="2024-01-01T12:00:00Z",
)
batching_session_manager.update_message("test-session-456", "test-agent", redacted)

# Buffer should still have 1 message but with updated content
assert batching_session_manager.pending_message_count() == 1
# Verify the buffered content was actually replaced
buffered = batching_session_manager._message_buffer[0]
assert "redacted" in str(buffered.messages) or "Message redacted by guardrail" in str(buffered.messages)
assert "offensive content" not in str(buffered.messages)
# No API calls should have been made (still buffered)
mock_memory_client.create_event.assert_not_called()
mock_memory_client.gmdp_client.delete_event.assert_not_called()

def test_buffer_auto_flushes_at_batch_size(self, batching_session_manager, mock_memory_client):
"""Test buffer automatically flushes when reaching batch_size."""
mock_memory_client.gmdp_client.create_event.return_value = {"eventId": "event_123"}
Expand Down
Loading