diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 693e4754f6eb..4109bdc55aaf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -106,6 +106,7 @@ add_library(bitcoin_common STATIC EXCLUDE_FROM_ALL common/run_command.cpp common/settings.cpp common/signmessage.cpp + common/sockman.cpp common/system.cpp common/url.cpp compressor.cpp diff --git a/src/common/sockman.cpp b/src/common/sockman.cpp new file mode 100644 index 000000000000..c6ed415a54df --- /dev/null +++ b/src/common/sockman.cpp @@ -0,0 +1,372 @@ +// Copyright (c) 2024-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or https://opensource.org/license/mit/. + +#include // IWYU pragma: keep + +#include +#include +#include +#include +#include + +// The set of sockets cannot be modified while waiting +// The sleep time needs to be small to avoid new sockets stalling +static constexpr auto SELECT_TIMEOUT{50ms}; + +bool SockMan::BindAndStartListening(const CService& to, bilingual_str& err_msg) +{ + // Create socket for listening for incoming connections + sockaddr_storage storage; + socklen_t len{sizeof(storage)}; + if (!to.GetSockAddr(reinterpret_cast(&storage), &len)) { + err_msg = Untranslated(strprintf("Bind address family for %s not supported", to.ToStringAddrPort())); + return false; + } + + std::unique_ptr sock{CreateSock(to.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP)}; + if (!sock) { + err_msg = Untranslated(strprintf("Cannot create %s listen socket: %s", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError()))); + return false; + } + + int one{1}; + + // Allow binding if the port is still in TIME_WAIT state after + // the program was closed and restarted. + if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&one), sizeof(one)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set SO_REUSEADDR on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); + } + + // some systems don't have IPV6_V6ONLY but are always v6only; others do have the option + // and enable it by default or not. Try to enable it, if possible. + if (to.IsIPv6()) { +#ifdef IPV6_V6ONLY + if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&one), sizeof(one)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set IPV6_V6ONLY on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); + } +#endif +#ifdef WIN32 + int prot_level{PROTECTION_LEVEL_UNRESTRICTED}; + if (sock->SetSockOpt(IPPROTO_IPV6, + IPV6_PROTECTION_LEVEL, + reinterpret_cast(&prot_level), + sizeof(prot_level)) == SOCKET_ERROR) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Info, + "Cannot set IPV6_PROTECTION_LEVEL on %s listen socket: %s, continuing anyway\n", + to.ToStringAddrPort(), + NetworkErrorString(WSAGetLastError())); + } +#endif + } + + if (sock->Bind(reinterpret_cast(&storage), len) == SOCKET_ERROR) { + const int err{WSAGetLastError()}; + if (err == WSAEADDRINUSE) { + err_msg = strprintf(_("Unable to bind to %s on this computer. %s is probably already running."), + to.ToStringAddrPort(), + CLIENT_NAME); + } else { + err_msg = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), + to.ToStringAddrPort(), + NetworkErrorString(err)); + } + return false; + } + + // Listen for incoming connections + if (sock->Listen(SOMAXCONN) == SOCKET_ERROR) { + err_msg = strprintf(_("Cannot listen on %s: %s"), to.ToStringAddrPort(), NetworkErrorString(WSAGetLastError())); + return false; + } + + m_listen.emplace_back(std::move(sock)); + + return true; +} + +void SockMan::StartSocketsThreads(const Options& options) +{ + m_thread_socket_handler = std::thread( + &util::TraceThread, options.socket_handler_thread_name, [this] { ThreadSocketHandler(); }); +} + +void SockMan::JoinSocketsThreads() +{ + if (m_thread_socket_handler.joinable()) { + m_thread_socket_handler.join(); + } +} + +std::unique_ptr SockMan::AcceptConnection(const Sock& listen_sock, CService& addr) +{ + sockaddr_storage storage; + socklen_t len{sizeof(storage)}; + + auto sock{listen_sock.Accept(reinterpret_cast(&storage), &len)}; + + if (!sock) { + const int err{WSAGetLastError()}; + if (err != WSAEWOULDBLOCK) { + LogPrintLevel(BCLog::NET, + BCLog::Level::Error, + "Cannot accept new connection: %s\n", + NetworkErrorString(err)); + } + return {}; + } + + if (!addr.SetSockAddr(reinterpret_cast(&storage), len)) { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "Unknown socket family\n"); + } + + return sock; +} + +void SockMan::NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) +{ + AssertLockNotHeld(m_connected_mutex); + + if (!sock->IsSelectable()) { + LogPrintf("connection from %s dropped: non-selectable socket\n", them.ToStringAddrPort()); + return; + } + + // According to the internet TCP_NODELAY is not carried into accepted sockets + // on all platforms. Set it again here just to be sure. + const int on{1}; + if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) { + LogDebug(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n", + them.ToStringAddrPort()); + } + + const Id id{GetNewId()}; + + { + LOCK(m_connected_mutex); + m_connected.emplace(id, std::make_shared(std::move(sock))); + } + + if (!EventNewConnectionAccepted(id, me, them)) { + CloseConnection(id); + } +} + +SockMan::Id SockMan::GetNewId() +{ + return m_next_id.fetch_add(1, std::memory_order_relaxed); +} + +bool SockMan::CloseConnection(Id id) +{ + LOCK(m_connected_mutex); + return m_connected.erase(id) > 0; +} + +ssize_t SockMan::SendBytes(Id id, + std::span data, + bool will_send_more, + std::string& errmsg) const +{ + AssertLockNotHeld(m_connected_mutex); + + if (data.empty()) { + return 0; + } + + auto sockets{GetConnectionSockets(id)}; + if (!sockets) { + // Bail out immediately and just leave things in the caller's send queue. + return 0; + } + + int flags{MSG_NOSIGNAL | MSG_DONTWAIT}; +#ifdef MSG_MORE + if (will_send_more) { + flags |= MSG_MORE; + } +#endif + + const ssize_t sent{WITH_LOCK( + sockets->mutex, + return sockets->sock->Send(reinterpret_cast(data.data()), data.size(), flags);)}; + + if (sent >= 0) { + return sent; + } + + const int err{WSAGetLastError()}; + if (err == WSAEWOULDBLOCK || err == WSAEMSGSIZE || err == WSAEINTR || err == WSAEINPROGRESS) { + return 0; + } + errmsg = NetworkErrorString(err); + return -1; +} + +void SockMan::StopListening() +{ + m_listen.clear(); +} + +bool SockMan::ShouldTryToSend(Id id) const { return true; } + +bool SockMan::ShouldTryToRecv(Id id) const { return true; } + +void SockMan::EventIOLoopCompletedForOne(Id id) {} + +void SockMan::EventIOLoopCompletedForAll() {} + +void SockMan::ThreadSocketHandler() +{ + AssertLockNotHeld(m_connected_mutex); + + while (!interruptNet) { + EventIOLoopCompletedForAll(); + + // Check for the readiness of the already connected sockets and the + // listening sockets in one call ("readiness" as in poll(2) or + // select(2)). If none are ready, wait for a short while and return + // empty sets. + auto io_readiness{GenerateWaitSockets()}; + if (io_readiness.events_per_sock.empty() || + // WaitMany() may as well be a static method, the context of the first Sock in the vector is not relevant. + !io_readiness.events_per_sock.begin()->first->WaitMany(SELECT_TIMEOUT, + io_readiness.events_per_sock)) { + interruptNet.sleep_for(SELECT_TIMEOUT); + } + + // Service (send/receive) each of the already connected sockets. + SocketHandlerConnected(io_readiness); + + // Accept new connections from listening sockets. + SocketHandlerListening(io_readiness.events_per_sock); + } +} + +SockMan::IOReadiness SockMan::GenerateWaitSockets() +{ + AssertLockNotHeld(m_connected_mutex); + + IOReadiness io_readiness; + + for (const auto& sock : m_listen) { + io_readiness.events_per_sock.emplace(sock, Sock::Events{Sock::RECV}); + } + + auto connected_snapshot{WITH_LOCK(m_connected_mutex, return m_connected;)}; + + for (const auto& [id, sockets] : connected_snapshot) { + const bool select_recv{ShouldTryToRecv(id)}; + const bool select_send{ShouldTryToSend(id)}; + if (!select_recv && !select_send) continue; + + Sock::Event event = (select_send ? Sock::SEND : 0) | (select_recv ? Sock::RECV : 0); + io_readiness.events_per_sock.emplace(sockets->sock, Sock::Events{event}); + io_readiness.ids_per_sock.emplace(sockets->sock, id); + } + + return io_readiness; +} + +void SockMan::SocketHandlerConnected(const IOReadiness& io_readiness) +{ + AssertLockNotHeld(m_connected_mutex); + + for (const auto& [sock, events] : io_readiness.events_per_sock) { + if (interruptNet) { + return; + } + + auto it{io_readiness.ids_per_sock.find(sock)}; + if (it == io_readiness.ids_per_sock.end()) { + continue; + } + const Id id{it->second}; + + bool send_ready = events.occurred & Sock::SEND; // Sock::SEND could only be set if ShouldTryToSend() has returned true in GenerateWaitSockets(). + bool recv_ready = events.occurred & Sock::RECV; // Sock::RECV could only be set if ShouldTryToRecv() has returned true in GenerateWaitSockets(). + bool err_ready = events.occurred & Sock::ERR; + + if (send_ready) { + bool cancel_recv; + + EventReadyToSend(id, cancel_recv); + + if (cancel_recv) { + recv_ready = false; + } + } + + if (recv_ready || err_ready) { + uint8_t buf[0x10000]; // typical socket buffer is 8K-64K + + auto sockets{GetConnectionSockets(id)}; + if (!sockets) { + continue; + } + + const ssize_t nrecv{WITH_LOCK( + sockets->mutex, + return sockets->sock->Recv(buf, sizeof(buf), MSG_DONTWAIT);)}; + + if (nrecv < 0) { // In all cases (including -1 and 0) EventIOLoopCompletedForOne() should be executed after this, don't change the code to skip it. + const int err = WSAGetLastError(); + if (err != WSAEWOULDBLOCK && err != WSAEMSGSIZE && err != WSAEINTR && err != WSAEINPROGRESS) { + EventGotPermanentReadError(id, NetworkErrorString(err)); + } + } else if (nrecv == 0) { + EventGotEOF(id); + } else { + EventGotData(id, {buf, static_cast(nrecv)}); + } + } + + EventIOLoopCompletedForOne(id); + } +} + +void SockMan::SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) +{ + AssertLockNotHeld(m_connected_mutex); + + for (const auto& sock : m_listen) { + if (interruptNet) { + return; + } + const auto it = events_per_sock.find(sock); + if (it != events_per_sock.end() && it->second.occurred & Sock::RECV) { + CService addr_accepted; + + auto sock_accepted{AcceptConnection(*sock, addr_accepted)}; + + if (sock_accepted) { + NewSockAccepted(std::move(sock_accepted), GetBindAddress(*sock), addr_accepted); + } + } + } +} + +std::shared_ptr SockMan::GetConnectionSockets(Id id) const +{ + LOCK(m_connected_mutex); + + auto it{m_connected.find(id)}; + if (it == m_connected.end()) { + // There is no socket in case we've already disconnected, or in test cases without + // real connections. + return {}; + } + + return it->second; +} diff --git a/src/common/sockman.h b/src/common/sockman.h new file mode 100644 index 000000000000..5187ed3f0519 --- /dev/null +++ b/src/common/sockman.h @@ -0,0 +1,322 @@ +// Copyright (c) 2024-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or https://opensource.org/license/mit/. + +#ifndef BITCOIN_COMMON_SOCKMAN_H +#define BITCOIN_COMMON_SOCKMAN_H + +#include +#include +#include + +#include +#include +#include + +/** + * A socket manager class which handles socket operations. + * To use this class, inherit from it and implement the pure virtual methods. + * Handled operations: + * - binding and listening on sockets + * - starting of necessary threads to process socket operations + * - accepting incoming connections + * - closing connections + * - waiting for IO readiness on sockets and doing send/recv accordingly + */ +class SockMan +{ +public: + /** + * Each connection is assigned an unique id of this type. + */ + using Id = int64_t; + + virtual ~SockMan() = default; + + // + // Non-virtual functions, to be reused by children classes. + // + + /** + * Bind to a new address:port, start listening and add the listen socket to `m_listen`. + * Should be called before `StartSocketsThreads()`. + * @param[in] to Where to bind. + * @param[out] err_msg Error string if an error occurs. + * @retval true Success. + * @retval false Failure, `err_msg` will be set. + */ + bool BindAndStartListening(const CService& to, bilingual_str& err_msg); + + /** + * Options to influence `StartSocketsThreads()`. + */ + struct Options { + std::string_view socket_handler_thread_name; + }; + + /** + * Start the necessary threads for sockets IO. + */ + void StartSocketsThreads(const Options& options); + + /** + * Join (wait for) the threads started by `StartSocketsThreads()` to exit. + */ + void JoinSocketsThreads(); + + /** + * Accept a connection. + * @param[in] listen_sock Socket on which to accept the connection. + * @param[out] addr Address of the peer that was accepted. + * @return Newly created socket for the accepted connection. + */ + std::unique_ptr AcceptConnection(const Sock& listen_sock, CService& addr); + + /** + * After a new socket with a peer has been created, configure its flags, + * make a new connection id and call `EventNewConnectionAccepted()`. + * @param[in] sock The newly created socket. + * @param[in] me Address at our end of the connection. + * @param[in] them Address of the new peer. + */ + void NewSockAccepted(std::unique_ptr&& sock, const CService& me, const CService& them) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Generate an id for a newly created connection. + */ + Id GetNewId(); + + /** + * Destroy a given connection by closing its socket and release resources occupied by it. + * @param[in] id Connection to destroy. + * @return Whether the connection existed and its socket was closed by this call. + */ + bool CloseConnection(Id id) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Try to send some data over the given connection. + * @param[in] id Identifier of the connection. + * @param[in] data The data to send, it might happen that only a prefix of this is sent. + * @param[in] will_send_more Used as an optimization if the caller knows that they will + * be sending more data soon after this call. + * @param[out] errmsg If <0 is returned then this will contain a human readable message + * explaining the error. + * @retval >=0 The number of bytes actually sent. + * @retval <0 A permanent error has occurred. + */ + ssize_t SendBytes(Id id, + std::span data, + bool will_send_more, + std::string& errmsg) const + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Stop listening by closing all listening sockets. + */ + void StopListening(); + + /** + * This is signaled when network activity should cease. + */ + CThreadInterrupt interruptNet; + + /** + * List of listening sockets. + */ + std::vector> m_listen; + +private: + + // + // Pure virtual functions must be implemented by children classes. + // + + /** + * Be notified when a new connection has been accepted. + * @param[in] id Id of the newly accepted connection. + * @param[in] me The address and port at our side of the connection. + * @param[in] them The address and port at the peer's side of the connection. + * @retval true The new connection was accepted at the higher level. + * @retval false The connection was refused at the higher level, so the + * associated socket and id should be discarded by `SockMan`. + */ + virtual bool EventNewConnectionAccepted(Id id, + const CService& me, + const CService& them) = 0; + + /** + * Called when the socket is ready to send data and `ShouldTryToSend()` has + * returned true. This is where the higher level code serializes its messages + * and calls `SockMan::SendBytes()`. + * @param[in] id Id of the connection whose socket is ready to send. + * @param[out] cancel_recv Should always be set upon return and if it is true, + * then the next attempt to receive data from that connection will be omitted. + */ + virtual void EventReadyToSend(Id id, bool& cancel_recv) = 0; + + /** + * Called when new data has been received. + * @param[in] id Connection for which the data arrived. + * @param[in] data Received data. + */ + virtual void EventGotData(Id id, std::span data) = 0; + + /** + * Called when the remote peer has sent an EOF on the socket. This is a graceful + * close of their writing side, we can still send and they will receive, if it + * makes sense at the application level. + * @param[in] id Connection whose socket got EOF. + */ + virtual void EventGotEOF(Id id) = 0; + + /** + * Called when we get an irrecoverable error trying to read from a socket. + * @param[in] id Connection whose socket got an error. + * @param[in] errmsg Message describing the error. + */ + virtual void EventGotPermanentReadError(Id id, const std::string& errmsg) = 0; + + // + // Non-pure virtual functions can be overridden by children classes or left + // alone to use the default implementation from SockMan. + // + + /** + * Can be used to temporarily pause sends on a connection. + * SockMan would only call Send() if this returns true. + * The implementation in SockMan always returns true. + * @param[in] id Connection for which to confirm or omit the next call to EventReadyToSend(). + */ + virtual bool ShouldTryToSend(Id id) const; + + /** + * SockMan would only call Recv() on a connection's socket if this returns true. + * Can be used to temporarily pause receives on a connection. + * The implementation in SockMan always returns true. + * @param[in] id Connection for which to confirm or omit the next receive. + */ + virtual bool ShouldTryToRecv(Id id) const; + + /** + * SockMan has completed the current send+recv iteration for a given connection. + * It will do another send+recv for this connection after processing all other connections. + * Can be used to execute periodic tasks for a given connection. + * The implementation in SockMan does nothing. + * @param[in] id Connection for which send+recv has been done. + */ + virtual void EventIOLoopCompletedForOne(Id id); + + /** + * SockMan has completed send+recv for all connections. + * Can be used to execute periodic tasks for all connections, like closing + * connections due to higher level logic. + * The implementation in SockMan does nothing. + */ + virtual void EventIOLoopCompletedForAll(); + + /** + * The sockets used by a connection. + */ + struct ConnectionSockets { + explicit ConnectionSockets(std::unique_ptr&& s) + : sock{std::move(s)} + { + } + + /** + * Mutex that serializes the Send() and Recv() calls on `sock`. + */ + Mutex mutex; + + /** + * Underlying socket. + * `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of the + * underlying file descriptor by one thread while another thread is poll(2)-ing + * it for activity. + * @see https://github.com/bitcoin/bitcoin/issues/21744 for details. + */ + std::shared_ptr sock; + }; + + /** + * Info about which socket has which event ready and its connection id. + */ + struct IOReadiness { + /** + * Map of socket -> socket events. For example: + * socket1 -> { requested = SEND|RECV, occurred = RECV } + * socket2 -> { requested = SEND, occurred = SEND } + */ + Sock::EventsPerSock events_per_sock; + + /** + * Map of socket -> connection id (in `m_connected`). For example + * socket1 -> id=23 + * socket2 -> id=56 + */ + std::unordered_map + ids_per_sock; + }; + + /** + * Check connected and listening sockets for IO readiness and process them accordingly. + */ + void ThreadSocketHandler() + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Generate a collection of sockets to check for IO readiness. + * @return Sockets to check for readiness plus an aux map to find the + * corresponding connection id given a socket. + */ + IOReadiness GenerateWaitSockets() + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Do the read/write for connected sockets that are ready for IO. + * @param[in] io_readiness Which sockets are ready and their connection ids. + */ + void SocketHandlerConnected(const IOReadiness& io_readiness) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Accept incoming connections, one from each read-ready listening socket. + * @param[in] events_per_sock Sockets that are ready for IO. + */ + void SocketHandlerListening(const Sock::EventsPerSock& events_per_sock) + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * Retrieve an entry from m_connected. + * @param[in] id Connection id to search for. + * @return ConnectionSockets for the given connection id or empty shared_ptr if not found. + */ + std::shared_ptr GetConnectionSockets(Id id) const + EXCLUSIVE_LOCKS_REQUIRED(!m_connected_mutex); + + /** + * The id to assign to the next created connection. Used to generate ids of connections. + */ + std::atomic m_next_id{0}; + + /** + * Thread that sends to and receives from sockets and accepts connections. + */ + std::thread m_thread_socket_handler; + + mutable Mutex m_connected_mutex; + + /** + * Sockets for existent connections. + * The `shared_ptr` makes it possible to create a snapshot of this by simply copying + * it (under `m_connected_mutex`). + */ + std::unordered_map> m_connected GUARDED_BY(m_connected_mutex); +}; + +#endif // BITCOIN_COMMON_SOCKMAN_H diff --git a/src/net.cpp b/src/net.cpp index aab8782f3d32..4b02c7745db9 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -379,20 +379,6 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) return true; } -/** Get the bind address for a socket as CService. */ -static CService GetBindAddress(const Sock& sock) -{ - CService addr_bind; - struct sockaddr_storage sockaddr_bind; - socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); - if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { - addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind, sockaddr_bind_len); - } else { - LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n"); - } - return addr_bind; -} - CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type, bool use_v2transport) { AssertLockNotHeld(m_unused_i2p_sessions_mutex); diff --git a/src/netbase.cpp b/src/netbase.cpp index 06593312d12f..03eb58ac1f8a 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -951,3 +951,16 @@ CService MaybeFlipIPv6toCJDNS(const CService& service) } return ret; } + +CService GetBindAddress(const Sock& sock) +{ + CService addr_bind; + struct sockaddr_storage sockaddr_bind; + socklen_t sockaddr_bind_len = sizeof(sockaddr_bind); + if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) { + addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind, sockaddr_bind_len); + } else { + LogPrintLevel(BCLog::NET, BCLog::Level::Warning, "getsockname failed\n"); + } + return addr_bind; +} diff --git a/src/netbase.h b/src/netbase.h index 41b3ca8fdb0a..72fc37fef8c4 100644 --- a/src/netbase.h +++ b/src/netbase.h @@ -362,4 +362,7 @@ bool IsBadPort(uint16_t port); */ CService MaybeFlipIPv6toCJDNS(const CService& service); +/** Get the bind address for a socket as CService. */ +CService GetBindAddress(const Sock& sock); + #endif // BITCOIN_NETBASE_H diff --git a/src/sv2/CMakeLists.txt b/src/sv2/CMakeLists.txt index e61f2f356083..a628204612fc 100644 --- a/src/sv2/CMakeLists.txt +++ b/src/sv2/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(bitcoin_sv2 STATIC EXCLUDE_FROM_ALL noise.cpp transport.cpp + connman.cpp ) target_link_libraries(bitcoin_sv2 @@ -12,5 +13,6 @@ target_link_libraries(bitcoin_sv2 core_interface bitcoin_clientversion bitcoin_crypto + bitcoin_common # for SockMan $<$:ws2_32> ) diff --git a/src/sv2/connman.cpp b/src/sv2/connman.cpp new file mode 100644 index 000000000000..dc4c10dcd9ef --- /dev/null +++ b/src/sv2/connman.cpp @@ -0,0 +1,371 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include +#include +#include +#include +#include + +using node::Sv2MsgType; + +Sv2Connman::~Sv2Connman() +{ + AssertLockNotHeld(m_clients_mutex); + + { + LOCK(m_clients_mutex); + for (const auto& client : m_sv2_clients) { + LogTrace(BCLog::SV2, "Disconnecting client id=%zu\n", + client.first); + client.second->m_disconnect_flag = true; + } + DisconnectFlagged(); + } + + Interrupt(); + StopThreads(); +} + +bool Sv2Connman::Start(Sv2EventsInterface* msgproc, std::string host, uint16_t port) +{ + m_msgproc = msgproc; + + if (!Bind(host, port)) return false; + + SockMan::Options sockman_options; + StartSocketsThreads(sockman_options); + + return true; +} + +bool Sv2Connman::Bind(std::string host, uint16_t port) +{ + const CService addr_bind = LookupNumeric(host, port); + + bilingual_str error; + if (!BindAndStartListening(addr_bind, error)) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Template Provider failed to bind to port %d: %s\n", port, error.original); + return false; + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Info, "%s listening on %s:%d\n", SV2_PROTOCOL_NAMES.at(m_subprotocol), host, port); + + return true; +} + + +void Sv2Connman::DisconnectFlagged() +{ + AssertLockHeld(m_clients_mutex); + + // Remove clients that are flagged for disconnection. + auto it = m_sv2_clients.begin(); + while(it != m_sv2_clients.end()) { + std::shared_ptr client{it->second}; + LOCK(client->cs_send); + LOCK(client->cs_status); + if (client->m_send_messages.empty() && client->m_disconnect_flag) { + CloseConnection(it->second->m_id); + it = m_sv2_clients.erase(it); + } else { + it++; + } + } +} + +void Sv2Connman::EventIOLoopCompletedForAll() +{ + LOCK(m_clients_mutex); + DisconnectFlagged(); +} + +void Sv2Connman::Interrupt() +{ + interruptNet(); +} + +void Sv2Connman::StopThreads() +{ + JoinSocketsThreads(); +} + +std::shared_ptr Sv2Connman::GetClientById(NodeId node_id) const +{ + auto it{m_sv2_clients.find(node_id)}; + if (it != m_sv2_clients.end()) { + return it->second; + } + return nullptr; +} + +bool Sv2Connman::EventNewConnectionAccepted(NodeId node_id, + const CService& addr_bind_, + const CService& addr_) +{ + Assume(m_certificate); + LOCK(m_clients_mutex); + std::unique_ptr transport = std::make_unique(m_static_key, m_certificate.value()); + auto client = std::make_shared(node_id, std::move(transport)); + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "New client id=%zu connected\n", node_id); + m_sv2_clients.emplace(node_id, std::move(client)); + return true; +} + +void Sv2Connman::EventReadyToSend(NodeId node_id, bool& cancel_recv) +{ + AssertLockNotHeld(m_clients_mutex); + + auto client{WITH_LOCK(m_clients_mutex, return GetClientById(node_id);)}; + if (client == nullptr) { + cancel_recv = true; + return; + } + + LOCK(client->cs_send); + auto it = client->m_send_messages.begin(); + std::optional expected_more; + + size_t total_sent = 0; + + while(true) { + if (it != client->m_send_messages.end()) { + // If possible, move one message from the send queue to the transport. + // This fails when there is an existing message still being sent, + // or when the handshake has not yet completed. + // + // Wrap Sv2NetMsg inside CSerializedNetMsg for transport + CSerializedNetMsg net_msg{*it}; + if (client->m_transport->SetMessageToSend(net_msg)) { + ++it; + } + } + + const auto& [data, more, _m_message_type] = client->m_transport->GetBytesToSend(/*have_next_message=*/it != client->m_send_messages.end()); + + + // We rely on the 'more' value returned by GetBytesToSend to correctly predict whether more + // bytes are still to be sent, to correctly set the MSG_MORE flag. As a sanity check, + // verify that the previously returned 'more' was correct. + if (expected_more.has_value()) Assume(!data.empty() == *expected_more); + expected_more = more; + + ssize_t sent = 0; + std::string errmsg; + + if (!data.empty()) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Send %d bytes to client id=%zu\n", + data.size() - total_sent, node_id); + + sent = SendBytes(node_id, data, more, errmsg); + } + + if (sent > 0) { + client->m_transport->MarkBytesSent(sent); + if (static_cast(sent) != data.size()) { + // could not send full message; stop sending more + break; + } + } else { + if (sent < 0) { + LogDebug(BCLog::NET, "socket send error for peer=%d: %s\n", node_id, errmsg); + CloseConnection(node_id); + } + break; + } + } + + // Clear messages that have been handed to transport from the queue + client->m_send_messages.erase(client->m_send_messages.begin(), it); + + // If both receiving and (non-optimistic) sending were possible, we first attempt + // sending. If that succeeds, but does not fully drain the send queue, do not + // attempt to receive. This avoids needlessly queueing data if the remote peer + // is slow at receiving data, by means of TCP flow control. We only do this when + // sending actually succeeded to make sure progress is always made; otherwise a + // deadlock would be possible when both sides have data to send, but neither is + // receiving. + // + // TODO: decide if this is useful for Sv2 + cancel_recv = total_sent > 0; // && more; +} + +void Sv2Connman::EventGotData(Id id, std::span data) +{ + AssertLockNotHeld(m_clients_mutex); + + auto client{WITH_LOCK(m_clients_mutex, return GetClientById(id);)}; + if (client == nullptr) { + return; + } + + try { + while (data.size() > 0) { + // absorb network data + if (!client->m_transport->ReceivedBytes(data)) { + // Serious transport problem + LogPrintLevel(BCLog::SV2, BCLog::Level::Trace, "Transport problem, disconnecting client id=%zu\n", + client->m_id); + // TODO: should we even bother with this? + LOCK(client->cs_status); + client->m_disconnect_flag = true; + break; + } + + if (client->m_transport->ReceivedMessageComplete()) { + bool dummy_reject_message = false; + Sv2NetMsg msg = client->m_transport->GetReceivedMessage(std::chrono::milliseconds(0), dummy_reject_message); + ProcessSv2Message(msg, *client.get()); + } + } + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received error when processing client id=%zu message: %s\n", client->m_id, e.what()); + LOCK(client->cs_status); + client->m_disconnect_flag = true; + } + +} + +void Sv2Connman::EventGotEOF(NodeId node_id) +{ + auto client{WITH_LOCK(m_clients_mutex, return GetClientById(node_id);)}; + if (client == nullptr) return; + LOCK(client->cs_status); + client->m_disconnect_flag = true; +} + +void Sv2Connman::EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) +{ + auto client{WITH_LOCK(m_clients_mutex, return GetClientById(node_id);)}; + if (client == nullptr) return; + LOCK(client->cs_status); + client->m_disconnect_flag = true; +} + +void Sv2Connman::ProcessSv2Message(const Sv2NetMsg& sv2_net_msg, Sv2Client& client) +{ + uint8_t msg_type[1] = {uint8_t(sv2_net_msg.m_msg_type)}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Received 0x%s %s from client id=%zu\n", + // After clang-17: + // std::format("{:x}", uint8_t(sv2_net_msg.m_msg_type)), + HexStr(msg_type), + node::SV2_MSG_NAMES.at(sv2_net_msg.m_msg_type), client.m_id); + + DataStream ss (sv2_net_msg.m_msg); + + if (WITH_LOCK(client.cs_status, return client.m_disconnect_flag)) { + // Don't bother processing new messages if we are about to disconnect when the + // send queue empties. This also prevents us from appending to the send queue + // when m_disconnect_flag is set. + return; + } + + switch (sv2_net_msg.m_msg_type) + { + case Sv2MsgType::SETUP_CONNECTION: + { + if (WITH_LOCK(client.cs_status, return client.m_setup_connection_confirmed)) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Client client id=%zu connection has already been confirmed\n", + client.m_id); + return; + } + + node::Sv2SetupConnectionMsg setup_conn; + try { + ss >> setup_conn; + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid SetupConnection message from client id=%zu: %s\n", + client.m_id, e.what()); + LOCK(client.cs_status); + client.m_disconnect_flag = true; + return; + } + + LOCK(client.cs_send); + + // Disconnect a client that connects on the wrong subprotocol. + if (setup_conn.m_protocol != m_subprotocol) { + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"unsupported-protocol"}}; + + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnectionError to client id=%zu\n", + client.m_id); + client.m_send_messages.emplace_back(setup_conn_err); + + LOCK(client.cs_status); + client.m_disconnect_flag = true; + return; + } + + // Disconnect a client if they are not running a compatible protocol version. + if ((m_protocol_version < setup_conn.m_min_version) || (m_protocol_version > setup_conn.m_max_version)) { + node::Sv2SetupConnectionErrorMsg setup_conn_err{setup_conn.m_flags, std::string{"protocol-version-mismatch"}}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x02 SetupConnection.Error to client id=%zu\n", + client.m_id); + client.m_send_messages.emplace_back(setup_conn_err); + + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received a connection from client id=%zu with incompatible protocol_versions: min_version: %d, max_version: %d\n", + client.m_id, setup_conn.m_min_version, setup_conn.m_max_version); + + LOCK(client.cs_status); + client.m_disconnect_flag = true; + return; + } + + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "Send 0x01 SetupConnection.Success to client id=%zu\n", + client.m_id); + node::Sv2SetupConnectionSuccessMsg setup_success{m_protocol_version, m_optional_features}; + client.m_send_messages.emplace_back(setup_success); + + LOCK(client.cs_status); + client.m_setup_connection_confirmed = true; + + break; + } + case Sv2MsgType::COINBASE_OUTPUT_CONSTRAINTS: + { + { + LOCK(client.cs_status); + if (!client.m_setup_connection_confirmed) { + client.m_disconnect_flag = true; + return; + } + } + + node::Sv2CoinbaseOutputConstraintsMsg coinbase_output_constraints; + try { + ss >> coinbase_output_constraints; + client.m_coinbase_output_constraints_recv = true; + } catch (const std::exception& e) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received invalid CoinbaseOutputConstraints message from client id=%zu: %s\n", + client.m_id, e.what()); + + LOCK(client.cs_status); + client.m_disconnect_flag = true; + return; + } + + uint32_t max_additional_size = coinbase_output_constraints.m_coinbase_output_max_additional_size; + LogPrintLevel(BCLog::SV2, BCLog::Level::Debug, "coinbase_output_max_additional_size=%d bytes\n", max_additional_size); + + if (max_additional_size > MAX_BLOCK_WEIGHT) { + LogPrintLevel(BCLog::SV2, BCLog::Level::Error, "Received impossible CoinbaseOutputConstraints from client id=%zu: %d\n", + client.m_id, max_additional_size); + + LOCK(client.cs_status); + client.m_disconnect_flag = true; + return; + } + + client.m_coinbase_tx_outputs_size = coinbase_output_constraints.m_coinbase_output_max_additional_size; + + break; + } + default: { + uint8_t msg_type[1]{uint8_t(sv2_net_msg.m_msg_type)}; + LogPrintLevel(BCLog::SV2, BCLog::Level::Warning, "Received unknown message type 0x%s from client id=%zu\n", + HexStr(msg_type), client.m_id); + break; + } + } +} diff --git a/src/sv2/connman.h b/src/sv2/connman.h new file mode 100644 index 000000000000..eee7c8ab98d2 --- /dev/null +++ b/src/sv2/connman.h @@ -0,0 +1,229 @@ +// Copyright (c) 2023-present The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_SV2_CONNMAN_H +#define BITCOIN_SV2_CONNMAN_H + +#include +#include +#include +#include + +namespace { + /* + * Supported Stratum v2 subprotocols + */ + static constexpr uint8_t TP_SUBPROTOCOL{0x02}; + + static const std::map SV2_PROTOCOL_NAMES{ + {0x02, "Template Provider"}, + }; +} + +struct Sv2Client +{ + /* Ephemeral identifier */ + size_t m_id; + + /** + * Transport + */ + std::unique_ptr m_transport; + + Mutex cs_status; + + /** + * Whether the client has confirmed the connection with a successful SetupConnection. + */ + bool m_setup_connection_confirmed GUARDED_BY(cs_status) = false; + + /** + * Whether the client is a candidate for disconnection. The client's socket will be + * closed after all queued messages have been sent. + */ + bool m_disconnect_flag GUARDED_BY(cs_status) = false; + + Mutex cs_send; + + /** Queue of messages to be sent */ + std::deque m_send_messages GUARDED_BY(cs_send); + + /** + * Whether the client has received CoinbaseOutputConstraints message. + */ + bool m_coinbase_output_constraints_recv = false; + + /** + * Specific additional coinbase tx output size required for the client. + */ + unsigned int m_coinbase_tx_outputs_size; + + explicit Sv2Client(size_t id, std::unique_ptr transport) : + m_id{id}, m_transport{std::move(transport)} {}; + + bool IsFullyConnected() EXCLUSIVE_LOCKS_REQUIRED(cs_status) + { + return !m_disconnect_flag && m_setup_connection_confirmed; + } + + Sv2Client(Sv2Client&) = delete; + Sv2Client& operator=(const Sv2Client&) = delete; +}; + +/** + * Interface for sv2 message handling + */ +class Sv2EventsInterface +{ +public: + virtual ~Sv2EventsInterface() = default; +}; + +/* + * Handle Stratum v2 connections. + * Currently only supports inbound connections. + */ +class Sv2Connman : SockMan +{ +private: + /** Interface to pass events up */ + Sv2EventsInterface* m_msgproc; + + /** + * The current protocol version of stratum v2 supported by the server. Not to be confused + * with byte value of identitying the stratum v2 subprotocol. + */ + const uint16_t m_protocol_version = 2; + + /** + * The currently supported optional features. + */ + const uint16_t m_optional_features = 0; + + /** + * The subprotocol used in setup connection messages. + * An Sv2Connman only recognizes its own subprotocol. + */ + const uint8_t m_subprotocol; + + CKey m_static_key; + + XOnlyPubKey m_authority_pubkey; + + std::optional m_certificate; + + /** + * A map of all connected stratum v2 clients. + */ + using Clients = std::unordered_map>; + Clients m_sv2_clients GUARDED_BY(m_clients_mutex); + + /** + * Creates a socket and binds the port for new stratum v2 connections. + */ + [[nodiscard]] bool Bind(std::string host, uint16_t port); + + void DisconnectFlagged() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex); + + /** + * Create a `Sv2Client` object and add it to the `m_sv2_clients` member. + * @param[in] node_id Id of the newly accepted connection. + * @param[in] me The address and port at our side of the connection. + * @param[in] them The address and port at the peer's side of the connection. + * @retval true on success + * @retval false on failure, meaning that the associated socket and node_id should be discarded + */ + virtual bool EventNewConnectionAccepted(NodeId node_id, + const CService& me, + const CService& them) + EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) override; + + void EventReadyToSend(NodeId node_id, bool& cancel_recv) override + EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex); + + virtual void EventGotData(Id id, std::span data) override + EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex); + + virtual void EventGotEOF(NodeId node_id) override + EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex); + + virtual void EventGotPermanentReadError(NodeId node_id, const std::string& errmsg) override + EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex); + + virtual void EventIOLoopCompletedForAll() EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) override; + + /** + * Encrypt the header and message payload and send it. + * @throws std::runtime_error if encrypting the message fails. + */ + bool EncryptAndSendMessage(Sv2Client& client, node::Sv2NetMsg& net_msg); + + /** + * A helper method to read and decrypt multiple Sv2NetMsgs. + */ + std::vector ReadAndDecryptSv2NetMsgs(Sv2Client& client, std::span buffer); + +public: + Sv2Connman(uint8_t subprotocol, CKey static_key, XOnlyPubKey authority_pubkey, Sv2SignatureNoiseMessage certificate) : + m_subprotocol(subprotocol), m_static_key(static_key), m_authority_pubkey(authority_pubkey), m_certificate(certificate) {}; + + ~Sv2Connman(); + + Mutex m_clients_mutex; + + /** + * Starts the Stratum v2 server and thread. + * returns false if port is unable to bind. + */ + [[nodiscard]] bool Start(Sv2EventsInterface* msgproc, std::string host, uint16_t port); + + /** + * Triggered on interrupt signals to stop the main event loop in ThreadSv2Handler(). + */ + void Interrupt(); + + /** + * Tear down of the connman thread and any other necessary tear down. + */ + void StopThreads(); + + /** + * Main handler for all received stratum v2 messages. + */ + void ProcessSv2Message(const node::Sv2NetMsg& sv2_header, Sv2Client& client); + + std::shared_ptr GetClientById(NodeId node_id) const EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex); + + using Sv2ClientFn = std::function; + /** Perform a function on each fully connected client. */ + void ForEachClient(const Sv2ClientFn& func) EXCLUSIVE_LOCKS_REQUIRED(!m_clients_mutex) + { + LOCK(m_clients_mutex); + for (const auto& client : m_sv2_clients) { + LOCK(client.second->cs_status); + if (client.second->IsFullyConnected()) func(*client.second); + } + }; + + /** Number of clients that are not marked for disconnection, used for tests. */ + size_t ConnectedClients() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex) + { + return std::count_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto& c) { + LOCK(c.second->cs_status); + return !c.second->m_disconnect_flag; + }); + } + + /** Number of clients with m_setup_connection_confirmed, used for tests. */ + size_t FullyConnectedClients() EXCLUSIVE_LOCKS_REQUIRED(m_clients_mutex) + { + return std::count_if(m_sv2_clients.begin(), m_sv2_clients.end(), [](const auto& c) { + LOCK(c.second->cs_status); + return c.second->IsFullyConnected(); + }); + } + +}; + +#endif // BITCOIN_SV2_CONNMAN_H diff --git a/src/sv2/messages.h b/src/sv2/messages.h index 9475f5bbdedc..c04c95c1dd5c 100644 --- a/src/sv2/messages.h +++ b/src/sv2/messages.h @@ -6,11 +6,22 @@ #define BITCOIN_SV2_MESSAGES_H #include // for CSerializedNetMsg and CNetMessage +#include +#include +#include +#include