diff --git a/.gitignore b/.gitignore index 7700465..bbb3b4e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +build-asan/ +*.idx + # Build directory. build/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5fc54f6 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,67 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What this is + +`mgclient` is a C library implementing the Bolt protocol client for [Memgraph](https://www.memgraph.com) (also compatible with Neo4j Bolt). The core library is C11. A header-only C++17 wrapper (`mgclient_cpp`) sits on top, and it can also be compiled to WebAssembly via Emscripten. + +## Build & test + +Standard build (produces `libmgclient.a` + `libmgclient.so`/`.dylib`): + +``` +mkdir build && cd build +cmake .. +make +``` + +With tests enabled (this also forces `BUILD_CPP_BINDINGS=ON`): + +``` +cmake -DBUILD_TESTING=ON -DBUILD_TESTING_INTEGRATION=ON .. +make +ctest +``` + +- **Run a single test:** `ctest -R encoder` (test names: `value`, `encoder`, `decoder`, `client`, `transport`, `allocator`, `unit_mgclient_value`, plus `integration_basic_c`, `integration_basic_cpp`, `example_*`). +- **Unit tests only (no running Memgraph):** `ctest -E "example|integration"`. The `integration_*` and `example_*` tests require a live Memgraph on `127.0.0.1:7687`. +- **OpenSSL not found:** pass `-DOPENSSL_ROOT_DIR=...` (see README for macOS/Windows specifics). +- **WASM build (Linux only):** `cmake .. -DWASM=ON && make` → emits `mgclient.js` + `mgclient.wasm`. WASM uses WebSocket transport and has no OpenSSL dependency. + +## Formatting + +`./tool/format.sh` runs `clang-format` (Google style, 80-col, right-aligned pointers) over all `*.c/*.h/*.cpp/*.hpp` files **in place** and fails if anything changed. CI runs this on every push/PR as the `clang_check` job — formatting failures break CI, so run it before committing. + +Coverage report: `./tool/coverage.sh` (requires a build with `-DENABLE_COVERAGE=ON`; uses `llvm-profdata`/`llvm-cov`). + +## Architecture + +The library is layered. Public symbols are exported via the `MGCLIENT_EXPORT` macro (generated `mgclient-export.h`); everything in `src/*.h` is internal. + +- **`include/mgclient.h`** — the entire public C API and its Doxygen documentation. The big comment block at the top is the authoritative spec for the **ownership model** (read it before touching value/container code): non-const pointer returns transfer ownership to the caller; const pointer returns are read-only views valid only while the owner lives; insert functions steal ownership of inserted values. Getting this wrong causes double-frees. + +- **Session layer** (`mgsession.c`, `mgsession.h`) — `mg_session` is the connection object and is *single-command-at-a-time*: you `mg_session_run` a query, then `mg_session_pull` rows until it returns 0 before running anything else. `mg_connect` performs the Bolt handshake and HELLO. The session struct holds the in/out buffers, the negotiated Bolt `version`, transaction state (`explicit_transaction`), and two allocators (one general, one scoped to decoding). + +- **Encoder / decoder** (`mgsession-encoder.c`, `mgsession-decoder.c`) — serialize/deserialize Bolt messages and values over the chunked Bolt framing. All Bolt markers, struct signatures, and message signatures live in `src/mgconstants.h` — this is the reference when adding a new value type or Bolt message. Note that Bolt has multiple protocol versions and some value types (temporal types, ZonedDateTime) are version-gated; check how `session->version` is consulted. + +- **Transport layer** (`mgtransport.c`, `mgtransport.h`) — polymorphic `mg_transport` struct of function pointers (`send`/`recv`/`destroy`/suspend hooks). Three implementations: `mg_raw_transport` (plain socket), `mg_secure_transport` (OpenSSL/SSL, supports peer pubkey fingerprint verification via trust callback), and the WASM WebSocket path. The session talks only to the `mg_transport` interface and is agnostic to which one is in use. + +- **Socket layer** — OS-specific, selected at CMake configure time: `src/{linux,apple,windows}/mgsocket.c` (matching `mgcommon.h` per platform). The build picks exactly one based on `MGCLIENT_ON_{LINUX,APPLE,WINDOWS}`. + +- **Values** (`mgvalue.c`, `mgvalue.h`) — implementation of all Bolt data types (`mg_value`, `mg_string`, `mg_list`, `mg_map`, `mg_node`, `mg_relationship`, `mg_path`, temporal types, points). This is the largest file and where the ownership rules from `mgclient.h` are enforced. + +- **Allocator** (`mgallocator.c`, `mgallocator.h`) — pluggable `mg_allocator` interface; the library allocates through it rather than calling `malloc` directly. + +- **C++ wrapper** (`mgclient_cpp/include/`, header-only) — `mg::Client` (RAII connection with `Client::Connect(params)`), `mg::Value`, and an exception hierarchy (`MgException` → `ClientException`/`TransientException`/`DatabaseException`). Pure wrapper over the C API; no separate compiled library. + +## Tests + +- `tests/*.cpp` — unit tests (GTest, fetched via `FetchContent` at `release-1.8.1`). They link against `mgclient-static` and the C++ bindings. +- `tests/integration/` — require a running Memgraph instance; gated behind `BUILD_TESTING_INTEGRATION`. +- `client.cpp` mocks `mg_secure_transport_init` using the linker `--wrap` mechanism (`-Wl,--wrap=` on Linux, `-Wl,-alias,` on Apple) — see `tests/CMakeLists.txt`. If you rename that function, update the wrap flags too. +- `examples/` (`basic.c`, `basic.cpp`, `advanced.cpp`) are also compiled and registered as ctest tests; they double as API usage references. + +## Versioning gotcha + +`CMakeLists.txt` carries two independent version numbers: `project(... VERSION x.y.z)` and `mgclient_SOVERSION`. A minor version bump can mean ABI incompatibility — the SOVERSION must be bumped manually when the ABI changes (it is not derived from the project version). diff --git a/include/mgclient.h b/include/mgclient.h index 9815631..d104658 100644 --- a/include/mgclient.h +++ b/include/mgclient.h @@ -1374,35 +1374,84 @@ MGCLIENT_EXPORT int mg_session_run(mg_session *session, const char *query, const mg_map *extra_run_information, const mg_list **columns, int64_t *qid); +/// Sends a Bolt ROUTE message to a coordinator and returns the routing table. +/// +/// ROUTE is available only on Bolt protocol versions >= 4.3, which must have +/// been negotiated during \ref mg_connect. It is used for client-side routing +/// against a Memgraph high-availability cluster: the client asks a coordinator +/// for the current cluster topology and then connects directly to the +/// appropriate server based on the desired access mode. +/// +/// The session must be in the ready state (not executing/fetching a query and +/// not inside an explicit transaction). +/// +/// \param session A \ref mg_session connected to a coordinator. The Bolt +/// version negotiated for this session must be >= 4.3. +/// \param routing A \ref mg_map with routing context (e.g. the address +/// used to contact the coordinator). Must not be NULL; use +/// an empty map if there is no routing context. +/// \param bookmarks A \ref mg_list of bookmark strings, or NULL for none +/// (treated as an empty list). +/// \param extra A \ref mg_map with extra information. On Bolt 4.4 it is +/// sent verbatim as the third ROUTE field; on Bolt 4.3 +/// only its "db" string entry (if present) is used to +/// populate the separate database-name field. NULL is +/// allowed. +/// \param routing_table On success, a freshly allocated \ref mg_map holding the +/// routing table is stored here (ownership transferred to +/// the caller, who must call \ref mg_map_destroy on it). +/// NULL may be supplied to discard the result. The map has +/// the shape: +/// { +/// "ttl": , +/// "servers": [ +/// { +/// "addresses": ["host:port", ...], +/// "role": "READ" | "WRITE" | "ROUTE" +/// }, +/// ... +/// ] +/// } +/// \return Returns 0 if the routing table was obtained successfully. +/// Returns \ref MG_ERROR_BAD_PARAMETER if \p routing is NULL, +/// \ref MG_ERROR_BAD_CALL if the session is not ready, +/// \ref MG_ERROR_CLIENT_ERROR if the negotiated Bolt version is < 4.3, +/// or another non-zero error code otherwise. +MGCLIENT_EXPORT int mg_session_route(mg_session *session, const mg_map *routing, + const mg_list *bookmarks, + const mg_map *extra, + mg_map **routing_table); + /// Starts an Explicit transaction on the server. /// /// Every run will be part of that transaction until its explicitly ended. /// /// \param session A \ref mg_session on which the transaction /// should be started. \param extra_run_information A \ref mg_map containing -/// extra information that will be used for every statement that is ran as part -/// of the transaction. +/// extra information that will be used for every statement that is ran as +/// part of the transaction. /// It can contain the following information: -/// - bookmarks - list of strings containing some -/// kind of bookmark identification +/// - bookmarks - list of strings containing +/// some kind of bookmark identification /// - tx_timeout - integer that specifies a /// transaction timeout in ms. -/// - tx_metadata - dictionary taht can contain -/// some metadata information, mainly used for -/// logging. -/// - mode - specifies what kind of server is the -/// run targeting. For write access use "w" and -/// for read access use "r". Defaults to write -/// access. +/// - tx_metadata - dictionary taht can +/// contain some metadata information, mainly +/// used for logging. +/// - mode - specifies what kind of server is +/// the run targeting. For write access use +/// "w" and for read access use "r". Defaults +/// to write access. /// - db - specifies the database name for -/// multi-database to select where the transaction -/// takes place. If no `db` is sent or empty -/// string it implies that it is the default -/// database. +/// multi-database to select where the +/// transaction takes place. If no `db` is +/// sent or empty string it implies that it is +/// the default database. /// \return Returns 0 if the transaction was started successfuly. /// Otherwise, a non-zero error code is returned. -MGCLIENT_EXPORT int mg_session_begin_transaction( - mg_session *session, const mg_map *extra_run_information); +MGCLIENT_EXPORT +int mg_session_begin_transaction(mg_session *session, + const mg_map *extra_run_information); /// Commits current Explicit transaction. /// diff --git a/src/mgclient.c b/src/mgclient.c index 58d8284..1d8b57b 100644 --- a/src/mgclient.c +++ b/src/mgclient.c @@ -230,16 +230,21 @@ int validate_session_params(const mg_session_params *params, } static int mg_bolt_handshake(mg_session *session) { - const uint32_t VERSION_NONE = htobe32(0); - const uint32_t VERSION_1 = htobe32(1); + // Advertise supported Bolt versions, highest first. The version word is + // big-endian with the layout 0x0000, so e.g. 4.4 is 0x0404 and + // 1.0 is 0x0001. ROUTE (Bolt >= 4.3) requires negotiating 4.3 or 4.4; 4.1 and + // 1.0 are kept for backward compatibility. + const uint32_t VERSION_4_4 = htobe32(0x0404); + const uint32_t VERSION_4_3 = htobe32(0x0304); const uint32_t VERSION_4_1 = htobe32(0x0104); + const uint32_t VERSION_1 = htobe32(0x0001); mg_transport_suspend_until_ready_to_write(session->transport); if (mg_transport_send(session->transport, MG_HANDSHAKE_MAGIC, strlen(MG_HANDSHAKE_MAGIC)) != 0 || + mg_transport_send(session->transport, (char *)&VERSION_4_4, 4) != 0 || + mg_transport_send(session->transport, (char *)&VERSION_4_3, 4) != 0 || mg_transport_send(session->transport, (char *)&VERSION_4_1, 4) != 0 || - mg_transport_send(session->transport, (char *)&VERSION_1, 4) != 0 || - mg_transport_send(session->transport, (char *)&VERSION_NONE, 4) != 0 || - mg_transport_send(session->transport, (char *)&VERSION_NONE, 4) != 0) { + mg_transport_send(session->transport, (char *)&VERSION_1, 4) != 0) { mg_session_set_error(session, "failed to send handshake data"); return MG_ERROR_SEND_FAILED; } @@ -250,13 +255,18 @@ static int mg_bolt_handshake(mg_session *session) { mg_session_set_error(session, "failed to receive handshake response"); return MG_ERROR_RECV_FAILED; } - if (server_version == VERSION_1) { + uint32_t v = be32toh(server_version); + uint8_t major = (uint8_t)(v & 0xFF); + uint8_t minor = (uint8_t)((v >> 8) & 0xFF); + // Accept exactly the versions we advertised: 1.0, 4.1, 4.3, 4.4. + if (major == 1 && minor == 0) { session->version = 1; - } else if (server_version == VERSION_4_1) { + session->version_minor = 0; + } else if (major == 4 && (minor == 1 || minor == 3 || minor == 4)) { session->version = 4; + session->version_minor = minor; } else { - mg_session_set_error(session, "unsupported protocol version: %" PRIu32, - be32toh(server_version)); + mg_session_set_error(session, "unsupported protocol version: %" PRIu32, v); return MG_ERROR_PROTOCOL_VIOLATION; } return 0; @@ -803,6 +813,148 @@ int mg_session_run(mg_session *session, const char *query, const mg_map *params, return status; } +int mg_session_route(mg_session *session, const mg_map *routing, + const mg_list *bookmarks, const mg_map *extra, + mg_map **routing_table) { + if (!routing) { + mg_session_set_error(session, "routing map must not be NULL"); + return MG_ERROR_BAD_PARAMETER; + } + if (session->status == MG_SESSION_BAD) { + mg_session_set_error(session, "bad session"); + return MG_ERROR_BAD_CALL; + } + if (session->status == MG_SESSION_EXECUTING) { + mg_session_set_error(session, "already executing a query"); + return MG_ERROR_BAD_CALL; + } + if (session->status == MG_SESSION_FETCHING) { + mg_session_set_error(session, "fetching results of a query"); + return MG_ERROR_BAD_CALL; + } + if (session->explicit_transaction) { + mg_session_set_error(session, + "cannot route while in an explicit transaction"); + return MG_ERROR_BAD_CALL; + } + + assert(session->status == MG_SESSION_READY && !session->explicit_transaction); + + if (session->version < 4 || + (session->version == 4 && session->version_minor < 3)) { + mg_session_set_error(session, "ROUTE requires Bolt >= 4.3"); + return MG_ERROR_BAD_CALL; + } + + mg_message_destroy_ca(session->result.message, session->decoder_allocator); + session->result.columns = NULL; + session->result.message = NULL; + + // The encoders dereference the bookmarks list, so a non-NULL empty list must + // be supplied when the caller passes NULL. + mg_list *empty_bookmarks = NULL; + if (!bookmarks) { + empty_bookmarks = mg_list_make_empty(0); + if (!empty_bookmarks) { + mg_session_set_error(session, "out of memory"); + return MG_ERROR_OOM; + } + bookmarks = empty_bookmarks; + } + + int status = 0; + if (session->version_minor >= 4) { + if (!extra) { + extra = &mg_empty_map; + } + status = + mg_session_send_route_message_v4_4(session, routing, bookmarks, extra); + } else { + // Bolt 4.3 carries the database name as a separate string field. Extract it + // from extra["db"] if present, otherwise send an empty string (default db). + // The mg_string data is not null-terminated, so carry its size explicitly. + const char *db = ""; + uint32_t db_size = 0; + if (extra) { + const mg_value *db_tmp = mg_map_at(extra, "db"); + if (db_tmp && mg_value_get_type(db_tmp) == MG_VALUE_TYPE_STRING) { + const mg_string *db_str = mg_value_string(db_tmp); + db = db_str->data; + db_size = db_str->size; + } + } + status = mg_session_send_route_message_v4_3(session, routing, bookmarks, db, + db_size); + } + + if (empty_bookmarks) { + mg_list_destroy(empty_bookmarks); + } + + if (status != 0) { + goto fatal_failure; + } + + mg_transport_suspend_until_ready_to_read(session->transport); + status = mg_session_receive_message(session); + if (status != 0) { + goto fatal_failure; + } + + mg_message *response; + status = mg_session_read_bolt_message(session, &response); + if (status != 0) { + goto fatal_failure; + } + + if (response->type == MG_MESSAGE_TYPE_SUCCESS) { + const mg_value *rt = mg_map_at(response->success_v->metadata, "rt"); + if (!rt || mg_value_get_type(rt) != MG_VALUE_TYPE_MAP) { + status = MG_ERROR_PROTOCOL_VIOLATION; + mg_message_destroy_ca(response, session->decoder_allocator); + mg_session_set_error(session, "invalid response metadata: missing 'rt'"); + goto fatal_failure; + } + // Copy with the system allocator (not session->allocator): ownership is + // transferred to the caller, who releases it with the public + // mg_map_destroy, which itself uses the system allocator. + mg_map *copy = mg_map_copy_ca(mg_value_map(rt), &mg_system_allocator); + mg_message_destroy_ca(response, session->decoder_allocator); + if (!copy) { + mg_session_set_error(session, "out of memory"); + status = MG_ERROR_OOM; + goto fatal_failure; + } + if (routing_table) { + *routing_table = copy; + } else { + mg_map_destroy(copy); + } + return 0; + } + + if (response->type == MG_MESSAGE_TYPE_FAILURE) { + int failure_status = handle_failure_message(session, response->failure_v); + + status = handle_failure(session); + if (status != 0) { + goto fatal_failure; + } + + mg_message_destroy_ca(response, session->decoder_allocator); + return failure_status; + } + + status = MG_ERROR_PROTOCOL_VIOLATION; + mg_message_destroy_ca(response, session->decoder_allocator); + mg_session_set_error(session, "unexpected message type"); + +fatal_failure: + mg_session_invalidate(session); + assert(status != 0); + return status; +} + int mg_session_pull(mg_session *session, const mg_map *pull_information) { if (session->status == MG_SESSION_BAD) { mg_session_set_error(session, "called pull while bad session"); diff --git a/src/mgconstants.h b/src/mgconstants.h index 143ceb6..e0e77e3 100644 --- a/src/mgconstants.h +++ b/src/mgconstants.h @@ -107,6 +107,9 @@ static const uint8_t MG_MARKERS_MAP[] = {MG_MARKER_TINY_MAP, MG_MARKER_MAP_8, #define MG_SIGNATURE_MESSAGE_BEGIN 0x11 #define MG_SIGNATURE_MESSAGE_COMMIT 0x12 #define MG_SIGNATURE_MESSAGE_ROLLBACK 0x13 +// 0x66 also equals MG_SIGNATURE_DATE_TIME_ZONE_ID, but the two are decoded in +// different contexts (struct value vs. message), so the collision is safe. +#define MG_SIGNATURE_MESSAGE_ROUTE 0x66 #ifdef __cplusplus } diff --git a/src/mgsession-encoder.c b/src/mgsession-encoder.c index 8a6e7e1..1d85ba5 100644 --- a/src/mgsession-encoder.c +++ b/src/mgsession-encoder.c @@ -316,6 +316,40 @@ int mg_session_send_run_message(mg_session *session, const char *statement, return mg_session_flush_message(session); } +int mg_session_send_route_message_v4_3(mg_session *session, + const mg_map *routing, + const mg_list *bookmarks, const char *db, + uint32_t db_size) { + int const field_number = 3; + MG_RETURN_IF_FAILED(mg_session_write_uint8( + session, (uint8_t)(MG_MARKER_TINY_STRUCT + field_number))); + MG_RETURN_IF_FAILED( + mg_session_write_uint8(session, MG_SIGNATURE_MESSAGE_ROUTE)); + MG_RETURN_IF_FAILED(mg_session_write_map(session, routing)); + MG_RETURN_IF_FAILED(mg_session_write_list(session, bookmarks)); + // Use the explicit size: db may point at an mg_string's data, which is not + // null-terminated, so strlen-based mg_session_write_string is unsafe here. + MG_RETURN_IF_FAILED(mg_session_write_string2(session, db_size, db)); + + return mg_session_flush_message(session); +} + +int mg_session_send_route_message_v4_4(mg_session *session, + const mg_map *routing, + const mg_list *bookmarks, + const mg_map *extra) { + int const field_number = 3; + MG_RETURN_IF_FAILED(mg_session_write_uint8( + session, (uint8_t)(MG_MARKER_TINY_STRUCT + field_number))); + MG_RETURN_IF_FAILED( + mg_session_write_uint8(session, MG_SIGNATURE_MESSAGE_ROUTE)); + MG_RETURN_IF_FAILED(mg_session_write_map(session, routing)); + MG_RETURN_IF_FAILED(mg_session_write_list(session, bookmarks)); + MG_RETURN_IF_FAILED(mg_session_write_map(session, extra)); + + return mg_session_flush_message(session); +} + int mg_session_send_pull_message(mg_session *session, const mg_map *extra) { uint8_t marker = MG_MARKER_TINY_STRUCT + (session->version == 4); MG_RETURN_IF_FAILED(mg_session_write_uint8(session, marker)); diff --git a/src/mgsession.c b/src/mgsession.c index 7ce4a42..f95359b 100644 --- a/src/mgsession.c +++ b/src/mgsession.c @@ -87,6 +87,8 @@ mg_session *mg_session_init(mg_allocator *allocator) { session->result.message = NULL; session->result.columns = NULL; + session->version = 0; + session->version_minor = 0; session->explicit_transaction = 0; session->query_number = 0; diff --git a/src/mgsession.h b/src/mgsession.h index 2613ba0..00287dc 100644 --- a/src/mgsession.h +++ b/src/mgsession.h @@ -44,6 +44,7 @@ typedef struct mg_session { mg_transport *transport; int version; + int version_minor; char *out_buffer; size_t out_begin; @@ -181,6 +182,19 @@ int mg_session_send_commit_messsage(mg_session *session); int mg_session_send_rollback_messsage(mg_session *session); +// From Bolt 4.4 the database name is carried inside `extra`; on 4.3 it is a +// separate string field. `db`/`db_size` describe that field for the 4.3 form +// (db need not be null-terminated, hence the explicit size). +int mg_session_send_route_message_v4_3(mg_session *session, + const mg_map *routing, + const mg_list *bookmarks, const char *db, + uint32_t db_size); + +int mg_session_send_route_message_v4_4(mg_session *session, + const mg_map *routing, + const mg_list *bookmarks, + const mg_map *extra); + #ifdef __cplusplus } #endif diff --git a/tests/client.cpp b/tests/client.cpp index 59c3839..b81fc5c 100644 --- a/tests/client.cpp +++ b/tests/client.cpp @@ -256,10 +256,10 @@ TEST_F(ConnectTest, HandshakeFail) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); // Send unsupported version to client. uint32_t version = htobe32(2); @@ -286,10 +286,10 @@ TEST_F(ConnectTest, InitFail) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); uint32_t version = htobe32(1); ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); @@ -359,10 +359,10 @@ TEST_F(ConnectTest, InitFail_v4) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); uint32_t version = htobe32(0x0104); ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); @@ -437,10 +437,10 @@ TEST_F(ConnectTest, Success) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); uint32_t version = htobe32(1); ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); @@ -515,10 +515,10 @@ TEST_F(ConnectTest, Success_v4) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); uint32_t version = htobe32(0x0104); ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); @@ -598,10 +598,10 @@ TEST_F(ConnectTest, SuccessWithSSL) { char handshake[20]; ASSERT_EQ(RecvData(sockfd, handshake, 20), 0); ASSERT_EQ(std::string(handshake, 4), "\x60\x60\xB0\x17"s); - ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x01\x04"s); - ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x00\x01"s); - ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x00\x00"s); - ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x00"s); + ASSERT_EQ(std::string(handshake + 4, 4), "\x00\x00\x04\x04"s); + ASSERT_EQ(std::string(handshake + 8, 4), "\x00\x00\x03\x04"s); + ASSERT_EQ(std::string(handshake + 12, 4), "\x00\x00\x01\x04"s); + ASSERT_EQ(std::string(handshake + 16, 4), "\x00\x00\x00\x01"s); uint32_t version = htobe32(1); ASSERT_EQ(SendData(sockfd, (char *)&version, 4), 0); @@ -2144,3 +2144,213 @@ TEST_F(RunTest, TransactionWithMultipleRuns) { StopServer(); ASSERT_MEMORY_OK(); } + +class RouteTest : public ::testing::Test { + protected: + virtual void SetUp() override { + int tmp[2]; + ASSERT_EQ(mg_socket_pair(AF_UNIX, SOCK_STREAM, 0, tmp), 0); + sc = tmp[0]; + ss = tmp[1]; + + mg_init(); + session = mg_session_init((mg_allocator *)&allocator); + mg_raw_transport_init(sc, (mg_raw_transport **)&session->transport, + (mg_allocator *)&allocator); + session->status = MG_SESSION_READY; + } + + void RunServer(const std::function &server_func) { + server_thread = std::thread(server_func, ss); + } + void StopServer() { + if (server_thread.joinable()) { + server_thread.join(); + } + } + + int sc; + int ss; + mg_session *session; + std::thread server_thread; + + tracking_allocator allocator; +}; + +// Builds an "rt" routing-table value: +// {"ttl": , "servers": [{"addresses": [
], "role": "WRITE"}]} +static mg_value *MakeRoutingTable(int64_t ttl, const char *address) { + mg_list *addresses = mg_list_make_empty(1); + mg_list_append(addresses, mg_value_make_string(address)); + + mg_map *server = mg_map_make_empty(2); + mg_map_insert_unsafe(server, "addresses", mg_value_make_list(addresses)); + mg_map_insert_unsafe(server, "role", mg_value_make_string("WRITE")); + + mg_list *servers = mg_list_make_empty(1); + mg_list_append(servers, mg_value_make_map(server)); + + mg_map *rt = mg_map_make_empty(2); + mg_map_insert_unsafe(rt, "ttl", mg_value_make_integer(ttl)); + mg_map_insert_unsafe(rt, "servers", mg_value_make_list(servers)); + + return mg_value_make_map(rt); +} + +TEST_F(RouteTest, Success) { + session->version = 4; + session->version_minor = 4; + + RunServer([](int sockfd) { + mg_session *session = mg_session_init(&mg_system_allocator); + session->version = 4; + session->version_minor = 4; + mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, + &mg_system_allocator); + + // Consume the ROUTE message (the decoder does not decode ROUTE, so we just + // receive its chunks without trying to parse it). + ASSERT_EQ(mg_session_receive_message(session), 0); + + // Reply with a SUCCESS carrying the routing table. + { + mg_map *metadata = mg_map_make_empty(1); + mg_map_insert_unsafe(metadata, "rt", + MakeRoutingTable(/*ttl=*/300, "localhost:7687")); + ASSERT_EQ(mg_session_send_success_message(session, metadata), 0); + mg_map_destroy(metadata); + } + + mg_session_destroy(session); + }); + + mg_map *routing = mg_map_make_empty(1); + mg_map_insert_unsafe(routing, "address", + mg_value_make_string("localhost:7688")); + + mg_map *routing_table = nullptr; + ASSERT_EQ( + mg_session_route(session, routing, nullptr, nullptr, &routing_table), 0); + ASSERT_EQ(mg_session_status(session), MG_SESSION_READY); + ASSERT_TRUE(routing_table); + + const mg_value *ttl = mg_map_at(routing_table, "ttl"); + ASSERT_TRUE(ttl); + ASSERT_EQ(mg_value_get_type(ttl), MG_VALUE_TYPE_INTEGER); + ASSERT_EQ(mg_value_integer(ttl), 300); + + const mg_value *servers = mg_map_at(routing_table, "servers"); + ASSERT_TRUE(servers); + ASSERT_EQ(mg_value_get_type(servers), MG_VALUE_TYPE_LIST); + const mg_list *servers_list = mg_value_list(servers); + ASSERT_EQ(mg_list_size(servers_list), 1u); + + const mg_value *server = mg_list_at(servers_list, 0); + ASSERT_EQ(mg_value_get_type(server), MG_VALUE_TYPE_MAP); + const mg_map *server_map = mg_value_map(server); + + const mg_value *role = mg_map_at(server_map, "role"); + ASSERT_TRUE(role); + ASSERT_EQ( + std::string(mg_value_string(role)->data, mg_value_string(role)->size), + "WRITE"); + + const mg_value *addresses = mg_map_at(server_map, "addresses"); + ASSERT_TRUE(addresses); + ASSERT_EQ(mg_value_get_type(addresses), MG_VALUE_TYPE_LIST); + const mg_list *addr_list = mg_value_list(addresses); + ASSERT_EQ(mg_list_size(addr_list), 1u); + const mg_value *addr = mg_list_at(addr_list, 0); + ASSERT_EQ( + std::string(mg_value_string(addr)->data, mg_value_string(addr)->size), + "localhost:7687"); + + mg_map_destroy(routing_table); + mg_map_destroy(routing); + mg_session_destroy(session); + StopServer(); + ASSERT_MEMORY_OK(); +} + +TEST_F(RouteTest, SuccessV4_3WithDb) { + // On Bolt 4.3 the database name is a separate string field that + // mg_session_route extracts from extra["db"] and re-encodes. This exercises + // that extraction path (untested by RouteTest.Success, which uses 4.4 where + // the whole extra map is sent verbatim). + session->version = 4; + session->version_minor = 3; + + RunServer([](int sockfd) { + mg_session *session = mg_session_init(&mg_system_allocator); + session->version = 4; + session->version_minor = 3; + mg_raw_transport_init(sockfd, (mg_raw_transport **)&session->transport, + &mg_system_allocator); + + // Consume the ROUTE message (the decoder does not decode ROUTE, so we just + // receive its chunks without trying to parse it). + ASSERT_EQ(mg_session_receive_message(session), 0); + + // Reply with a SUCCESS carrying the routing table. + { + mg_map *metadata = mg_map_make_empty(1); + mg_map_insert_unsafe(metadata, "rt", + MakeRoutingTable(/*ttl=*/300, "localhost:7687")); + ASSERT_EQ(mg_session_send_success_message(session, metadata), 0); + mg_map_destroy(metadata); + } + + mg_session_destroy(session); + }); + + mg_map *routing = mg_map_make_empty(1); + mg_map_insert_unsafe(routing, "address", + mg_value_make_string("localhost:7688")); + + // extra = {"db": "memgraph"}; the "db" value is a decoded-style mg_string + // (data is NOT null-terminated), so re-encoding it must use its stored size. + mg_map *extra = mg_map_make_empty(1); + mg_map_insert_unsafe(extra, "db", mg_value_make_string("memgraph")); + + mg_map *routing_table = nullptr; + ASSERT_EQ(mg_session_route(session, routing, nullptr, extra, &routing_table), + 0); + ASSERT_EQ(mg_session_status(session), MG_SESSION_READY); + ASSERT_TRUE(routing_table); + + const mg_value *ttl = mg_map_at(routing_table, "ttl"); + ASSERT_TRUE(ttl); + ASSERT_EQ(mg_value_get_type(ttl), MG_VALUE_TYPE_INTEGER); + ASSERT_EQ(mg_value_integer(ttl), 300); + + mg_map_destroy(routing_table); + mg_map_destroy(extra); + mg_map_destroy(routing); + mg_session_destroy(session); + StopServer(); + ASSERT_MEMORY_OK(); +} + +TEST_F(RouteTest, UnsupportedVersion) { + // Negotiated Bolt 4.1: ROUTE is not available. + session->version = 4; + session->version_minor = 1; + + mg_map *routing = mg_map_make_empty(1); + mg_map_insert_unsafe(routing, "address", + mg_value_make_string("localhost:7688")); + + mg_map *routing_table = nullptr; + ASSERT_EQ( + mg_session_route(session, routing, nullptr, nullptr, &routing_table), + MG_ERROR_BAD_CALL); + // The session must remain usable (no transport interaction happened). + ASSERT_EQ(mg_session_status(session), MG_SESSION_READY); + ASSERT_EQ(routing_table, nullptr); + ASSERT_THAT(std::string(mg_session_error(session)), + HasSubstr("ROUTE requires Bolt >= 4.3")); + + mg_map_destroy(routing); + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} diff --git a/tests/decoder.cpp b/tests/decoder.cpp index c89de0b..0db5a8e 100644 --- a/tests/decoder.cpp +++ b/tests/decoder.cpp @@ -244,6 +244,97 @@ TEST_F(MessageChunkingTest, ManyMessages) { ASSERT_MEMORY_OK(); } +class RouteResponseTest : public DecoderTest {}; + +TEST_F(RouteResponseTest, RoutingTable) { + session = mg_session_init((mg_allocator *)&allocator); + mg_raw_transport_init(sc, (mg_raw_transport **)&session->transport, + (mg_allocator *)&allocator); + ASSERT_TRUE(session); + session->version = 4; + session->version_minor = 4; + + // SUCCESS message (0xB1 0x70) with metadata = {"rt": }. + // rt = {"ttl": 100, "servers": [{"addresses": ["localhost:7687"], + // "role": "WRITE"}]} + std::string msg = + "\xB1\x70" // TINY_STRUCT1, SUCCESS signature + "\xA1\x82" + "rt" // metadata map with single key "rt" + "\xA2" // rt map with 2 entries + "\x83" + "ttl" + "\x64" // "ttl" -> 100 + "\x87" + "servers" + "\x91" // "servers" -> list of 1 + "\xA2" // server map with 2 entries + "\x89" + "addresses" + "\x91" // "addresses" -> list of 1 + "\x8E" + "localhost:7687" // "localhost:7687" + "\x84" + "role" + "\x85" + "WRITE"s; // "role" -> "WRITE" + + client.WriteInChunks(ss, msg); + ASSERT_EQ(mg_session_receive_message(session), 0); + + mg_message *message; + ASSERT_EQ(mg_session_read_bolt_message(session, &message), 0); + ASSERT_EQ(message->type, MG_MESSAGE_TYPE_SUCCESS); + + const mg_map *metadata = message->success_v->metadata; + const mg_value *rt = mg_map_at(metadata, "rt"); + ASSERT_TRUE(rt); + ASSERT_EQ(mg_value_get_type(rt), MG_VALUE_TYPE_MAP); + const mg_map *rt_map = mg_value_map(rt); + + const mg_value *ttl = mg_map_at(rt_map, "ttl"); + ASSERT_TRUE(ttl); + ASSERT_EQ(mg_value_get_type(ttl), MG_VALUE_TYPE_INTEGER); + ASSERT_EQ(mg_value_integer(ttl), 100); + + const mg_value *servers = mg_map_at(rt_map, "servers"); + ASSERT_TRUE(servers); + ASSERT_EQ(mg_value_get_type(servers), MG_VALUE_TYPE_LIST); + const mg_list *servers_list = mg_value_list(servers); + ASSERT_EQ(mg_list_size(servers_list), 1u); + + const mg_value *server = mg_list_at(servers_list, 0); + ASSERT_EQ(mg_value_get_type(server), MG_VALUE_TYPE_MAP); + const mg_map *server_map = mg_value_map(server); + + const mg_value *role = mg_map_at(server_map, "role"); + ASSERT_TRUE(role); + ASSERT_EQ(mg_value_get_type(role), MG_VALUE_TYPE_STRING); + ASSERT_EQ( + std::string(mg_value_string(role)->data, mg_value_string(role)->size), + "WRITE"); + + const mg_value *addresses = mg_map_at(server_map, "addresses"); + ASSERT_TRUE(addresses); + ASSERT_EQ(mg_value_get_type(addresses), MG_VALUE_TYPE_LIST); + const mg_list *addr_list = mg_value_list(addresses); + ASSERT_EQ(mg_list_size(addr_list), 1u); + const mg_value *addr = mg_list_at(addr_list, 0); + ASSERT_EQ(mg_value_get_type(addr), MG_VALUE_TYPE_STRING); + ASSERT_EQ( + std::string(mg_value_string(addr)->data, mg_value_string(addr)->size), + "localhost:7687"); + + mg_message_destroy_ca(message, session->decoder_allocator); + + client.Stop(); + close(ss); + ASSERT_FALSE(client.error); + + mg_session_destroy(session); + ASSERT_MEMORY_OK(); +} + class ValueTest : public DecoderTest, public ::testing::WithParamInterface { protected: diff --git a/tests/encoder.cpp b/tests/encoder.cpp index de4aee0..a45fd2d 100644 --- a/tests/encoder.cpp +++ b/tests/encoder.cpp @@ -109,7 +109,10 @@ void AssertReadRaw(std::stringstream &sstr, const std::string &expected) { } void AssertEnd(std::stringstream &sstr) { - std::stringstream::char_type got = sstr.get(); + // Use int_type (not char_type) to hold the result: on platforms where char is + // unsigned (e.g. aarch64), assigning EOF (-1) to a char yields 0xFF which + // never compares equal to traits::eof(). + std::stringstream::int_type got = sstr.get(); if (got != std::stringstream::traits_type::eof()) { FAIL() << "Expected end of input stream, got character " << ::testing::PrintToString(got); @@ -247,6 +250,102 @@ TEST_F(MessageChunkingTest, ManyMessages) { ASSERT_MEMORY_OK(); } +class RouteMessageTest : public EncoderTest {}; + +TEST_F(RouteMessageTest, V4_3) { + session.version = 4; + session.version_minor = 3; + + mg_map *routing = mg_map_make_empty(0); + mg_list *bookmarks = mg_list_make_empty(0); + + ASSERT_EQ(mg_session_send_route_message_v4_3(&session, routing, bookmarks, + "memgraph", 8), + 0); + mg_raw_transport_destroy(session.transport); + + mg_map_destroy(routing); + mg_list_destroy(bookmarks); + + server.Stop(); + ASSERT_FALSE(server.error); + std::stringstream sstr(server.data); + + // TINY_STRUCT + 3 fields, ROUTE signature, empty map, empty list, + // tiny string "memgraph" (length 8). + ASSERT_READ_MESSAGE(sstr, "\xB3\x66\xA0\x90\x88memgraph"s); + ASSERT_END(sstr); + ASSERT_MEMORY_OK(); +} + +TEST_F(RouteMessageTest, V4_3EmptyDb) { + session.version = 4; + session.version_minor = 3; + + mg_map *routing = mg_map_make_empty(0); + mg_list *bookmarks = mg_list_make_empty(0); + + ASSERT_EQ( + mg_session_send_route_message_v4_3(&session, routing, bookmarks, "", 0), + 0); + mg_raw_transport_destroy(session.transport); + + mg_map_destroy(routing); + mg_list_destroy(bookmarks); + + server.Stop(); + ASSERT_FALSE(server.error); + std::stringstream sstr(server.data); + + // ... empty map, empty list, empty tiny string (0x80). + ASSERT_READ_MESSAGE(sstr, "\xB3\x66\xA0\x90\x80"s); + ASSERT_END(sstr); + ASSERT_MEMORY_OK(); +} + +TEST_F(RouteMessageTest, V4_4) { + session.version = 4; + session.version_minor = 4; + + // routing = {"address": "host:7687"} + mg_map *routing = mg_map_make_empty(1); + ASSERT_EQ( + mg_map_insert(routing, "address", mg_value_make_string("host:7687")), 0); + mg_list *bookmarks = mg_list_make_empty(0); + // extra = {"db": "memgraph"} + mg_map *extra = mg_map_make_empty(1); + ASSERT_EQ(mg_map_insert(extra, "db", mg_value_make_string("memgraph")), 0); + + ASSERT_EQ( + mg_session_send_route_message_v4_4(&session, routing, bookmarks, extra), + 0); + mg_raw_transport_destroy(session.transport); + + mg_map_destroy(routing); + mg_list_destroy(bookmarks); + mg_map_destroy(extra); + + server.Stop(); + ASSERT_FALSE(server.error); + std::stringstream sstr(server.data); + + // TINY_STRUCT + 3, ROUTE signature, + // map {"address"(0x87): "host:7687"(0x89)}, + // empty list (0x90), + // map {"db"(0x82): "memgraph"(0x88)}. + ASSERT_READ_MESSAGE(sstr, + "\xB3\x66\xA1\x87" + "address" + "\x89" + "host:7687" + "\x90\xA1\x82" + "db" + "\x88" + "memgraph"s); + ASSERT_END(sstr); + ASSERT_MEMORY_OK(); +} + class ValueTest : public EncoderTest, public ::testing::WithParamInterface { protected: