diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java index f60ff3eea..3d340af54 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java @@ -93,6 +93,11 @@ public class MysqlSession implements Session { private final String databaseName; private final String tableName; + @FunctionalInterface + private interface SqlOperation { + void execute() throws Exception; + } + /** * Create a MysqlSession with default settings. * @@ -285,6 +290,38 @@ private String getFullTableName() { return "`" + databaseName + "`.`" + tableName + "`"; } + /** + * Execute a write operation in an explicit transaction. + * + *
MysqlSession obtains and owns a fresh JDBC connection for each write method call. This + * helper makes write semantics consistent even when the underlying DataSource defaults to + * {@code autoCommit=false}, and restores the connection's original auto-commit mode before + * returning it to the pool. + */ + private void executeInWriteTransaction(Connection conn, SqlOperation operation) + throws Exception { + boolean originalAutoCommit = conn.getAutoCommit(); + if (originalAutoCommit) { + conn.setAutoCommit(false); + } + + try { + operation.execute(); + conn.commit(); + } catch (Exception e) { + try { + conn.rollback(); + } catch (SQLException rollbackException) { + e.addSuppressed(rollbackException); + } + throw e; + } finally { + if (conn.getAutoCommit() != originalAutoCommit) { + conn.setAutoCommit(originalAutoCommit); + } + } + } + @Override public void save(SessionKey sessionKey, String key, State value) { String sessionId = sessionKey.toIdentifier(); @@ -298,18 +335,21 @@ public void save(SessionKey sessionKey, String key, State value) { + " VALUES (?, ?, ?, ?)" + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(upsertSql)) { - - String json = JsonUtils.getJsonCodec().toJson(value); - - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, SINGLE_STATE_INDEX); - stmt.setString(4, json); - - stmt.executeUpdate(); - + try (Connection conn = dataSource.getConnection()) { + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) { + String json = JsonUtils.getJsonCodec().toJson(value); + + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, SINGLE_STATE_INDEX); + stmt.setString(4, json); + + stmt.executeUpdate(); + } + }); } catch (Exception e) { throw new RuntimeException("Failed to save state: " + key, e); } @@ -344,42 +384,35 @@ public void save(SessionKey sessionKey, String key, List extends State> values String hashKey = key + HASH_KEY_SUFFIX; try (Connection conn = dataSource.getConnection()) { - // Compute current hash - String currentHash = ListHashUtil.computeHash(values); - - // Get stored hash - String storedHash = getStoredHash(conn, sessionId, hashKey); - - // Get existing count - int existingCount = getListCount(conn, sessionId, key); - - // Determine if full rewrite is needed - boolean needsFullRewrite = - ListHashUtil.needsFullRewrite( - currentHash, storedHash, values.size(), existingCount); - - if (needsFullRewrite) { - // Transaction: delete all + insert all - conn.setAutoCommit(false); - try { - deleteListItems(conn, sessionId, key); - insertAllItems(conn, sessionId, key, values); - saveHash(conn, sessionId, hashKey, currentHash); - conn.commit(); - } catch (Exception e) { - conn.rollback(); - throw e; - } finally { - conn.setAutoCommit(true); - } - } else if (values.size() > existingCount) { - // Incremental append - List extends State> newItems = values.subList(existingCount, values.size()); - insertItems(conn, sessionId, key, newItems, existingCount); - saveHash(conn, sessionId, hashKey, currentHash); - } - // else: no change, skip - + executeInWriteTransaction( + conn, + () -> { + // Compute current hash + String currentHash = ListHashUtil.computeHash(values); + + // Get stored hash + String storedHash = getStoredHash(conn, sessionId, hashKey); + + // Get existing count + int existingCount = getListCount(conn, sessionId, key); + + // Determine if full rewrite is needed + boolean needsFullRewrite = + ListHashUtil.needsFullRewrite( + currentHash, storedHash, values.size(), existingCount); + + if (needsFullRewrite) { + deleteListItems(conn, sessionId, key); + insertAllItems(conn, sessionId, key, values); + saveHash(conn, sessionId, hashKey, currentHash); + } else if (values.size() > existingCount) { + List extends State> newItems = + values.subList(existingCount, values.size()); + insertItems(conn, sessionId, key, newItems, existingCount); + saveHash(conn, sessionId, hashKey, currentHash); + } + // else: no change, skip + }); } catch (Exception e) { throw new RuntimeException("Failed to save list: " + key, e); } @@ -626,13 +659,16 @@ public void delete(SessionKey sessionKey) { String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?"; - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(deleteSql)) { - - stmt.setString(1, sessionId); - stmt.executeUpdate(); - - } catch (SQLException e) { + try (Connection conn = dataSource.getConnection()) { + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) { + stmt.setString(1, sessionId); + stmt.executeUpdate(); + } + }); + } catch (Exception e) { throw new RuntimeException("Failed to delete session: " + sessionId, e); } } @@ -705,25 +741,34 @@ public DataSource getDataSource() { public int clearAllSessions() { String clearSql = "DELETE FROM " + getFullTableName(); - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(clearSql)) { - - return stmt.executeUpdate(); - - } catch (SQLException e) { + try (Connection conn = dataSource.getConnection()) { + int[] deletedRows = new int[1]; + executeInWriteTransaction( + conn, + () -> { + try (PreparedStatement stmt = conn.prepareStatement(clearSql)) { + deletedRows[0] = stmt.executeUpdate(); + } + }); + return deletedRows[0]; + } catch (Exception e) { throw new RuntimeException("Failed to clear sessions", e); } } /** * Truncate session table from the database (for testing or cleanup). - *
- * This method clears all session records by executing a TRUNCATE TABLE statement on the + * + *
This method clears all session records by executing a TRUNCATE TABLE statement on the * sessions table. TRUNCATE is faster than DELETE as it resets the table without logging * individual row deletions and reclaims storage space immediately. * - *
- * Note: The TRUNCATE operation requires DROP privileges in MySQL. + *
Note: In MySQL, {@code TRUNCATE TABLE} is DDL, triggers an implicit + * commit, and is not rollbackable. For that reason, this method executes the statement + * directly instead of routing it through {@link #executeInWriteTransaction(Connection, + * SqlOperation)}. + * + *
Note: The TRUNCATE operation requires DROP privileges in MySQL.
*
* @return typically 0 if successful
*/
@@ -732,9 +777,7 @@ public int truncateAllSessions() {
try (Connection conn = dataSource.getConnection();
PreparedStatement stmt = conn.prepareStatement(clearSql)) {
-
return stmt.executeUpdate();
-
} catch (SQLException e) {
throw new RuntimeException("Failed to truncate sessions", e);
}
diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
index b720b7376..fd1811df6 100644
--- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
+++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java
@@ -21,6 +21,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -232,6 +233,22 @@ void testSaveAndGetSingleState() throws SQLException {
assertEquals(42, loaded.get().count());
}
+ @Test
+ @DisplayName("Should commit single state save when connection auto-commit is disabled")
+ void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException {
+ when(mockConnection.getAutoCommit()).thenReturn(false);
+ when(mockStatement.execute()).thenReturn(true);
+ when(mockStatement.executeUpdate()).thenReturn(1);
+
+ MysqlSession session = new MysqlSession(mockDataSource, true);
+ SessionKey sessionKey = SimpleSessionKey.of("session_auto_commit_off");
+
+ session.save(sessionKey, "testModule", new TestState("test_value", 42));
+
+ verify(mockConnection).commit();
+ verify(mockConnection, never()).setAutoCommit(true);
+ }
+
@Test
@DisplayName("Should save and get list state correctly")
void testSaveAndGetListState() throws SQLException {
@@ -265,6 +282,49 @@ void testSaveAndGetListState() throws SQLException {
assertEquals("value2", loaded.get(1).value());
}
+ @Test
+ @DisplayName("Should commit incremental list save when connection auto-commit is disabled")
+ void testSaveListIncrementalAppendCommitsWhenAutoCommitDisabled() throws SQLException {
+ when(mockConnection.getAutoCommit()).thenReturn(false);
+ when(mockStatement.execute()).thenReturn(true);
+ when(mockStatement.executeQuery()).thenReturn(mockResultSet);
+ when(mockResultSet.next()).thenReturn(false, true);
+ when(mockResultSet.getInt("max_index")).thenReturn(0);
+ when(mockResultSet.wasNull()).thenReturn(true);
+
+ MysqlSession session = new MysqlSession(mockDataSource, true);
+ SessionKey sessionKey = SimpleSessionKey.of("session_list_auto_commit_off");
+ List