From 4dad9b0cdbf9a49dc650a75f4e96361357c35f20 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Sun, 26 Apr 2026 23:21:57 +0200 Subject: [PATCH 01/11] Add asynchronous HTTP handling with HttpAsync module --- CMakeLists.txt | 2 + include/HttpAsync.h | 73 +++++++ include/TLuaEngine.h | 2 +- src/HttpAsync.cpp | 477 +++++++++++++++++++++++++++++++++++++++++++ src/TLuaEngine.cpp | 16 ++ 5 files changed, 569 insertions(+), 1 deletion(-) create mode 100644 include/HttpAsync.h create mode 100644 src/HttpAsync.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 67e758d6..62496c64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ set(PRJ_HEADERS include/Defer.h include/Environment.h include/Http.h + include/HttpAsync.h include/IThreaded.h include/Json.h include/LuaAPI.h @@ -60,6 +61,7 @@ set(PRJ_SOURCES src/Common.cpp src/Compat.cpp src/Http.cpp + src/HttpAsync.cpp src/LuaAPI.cpp src/SignalHandling.cpp src/TConfig.cpp diff --git a/include/HttpAsync.h b/include/HttpAsync.h new file mode 100644 index 00000000..b281d67d --- /dev/null +++ b/include/HttpAsync.h @@ -0,0 +1,73 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#pragma once + +#include +#include +#include +#include + +namespace HttpAsync { + + struct HttpResult { + enum class Type { COMPLETE, PROGRESS } type; + + uint64_t requestId; + + int status; + std::string body; + std::map headers; + + long long current; + long long total; + }; + + class AsyncHttpProxy { + public: + AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders); + ~AsyncHttpProxy() = default; + + void SetTimeout(int seconds); + + void Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); + void Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + void Put(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + void Patch(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + void Delete(std::string endpoint, sol::object headers, sol::function cb); + void Head(std::string endpoint, sol::object headers, sol::function cb); + + void Download(std::string endpoint, std::string savePath, sol::function cb, sol::object prog); + void PostFile(std::string endpoint, std::string fieldName, std::string filePath, sol::object headers, sol::function cb); + + private: + std::map PrepareHeaders(sol::object overrides); + void PreparePayload(sol::object data, sol::object overrides, std::string& outBody, std::map& outHeaders); + + std::string mBaseUrl; + std::map mDefaultHeaders; + int mTimeoutSeconds = 30; + }; + + void Init(); + void Shutdown(); + void Update(sol::state_view& lua); + void RegisterBindings(sol::state_view& lua); + void CleanupState(lua_State* L); + +} // namespace HttpAsync \ No newline at end of file diff --git a/include/TLuaEngine.h b/include/TLuaEngine.h index 47a5c760..d4a2a2ae 100644 --- a/include/TLuaEngine.h +++ b/include/TLuaEngine.h @@ -241,7 +241,7 @@ class TLuaEngine : public std::enable_shared_from_this, IThreaded { public: StateThreadData(const std::string& Name, TLuaStateId StateId, TLuaEngine& Engine); StateThreadData(const StateThreadData&) = delete; - virtual ~StateThreadData() noexcept { beammp_debug("\"" + mStateId + "\" destroyed"); } + virtual ~StateThreadData() noexcept; [[nodiscard]] std::shared_ptr EnqueueScript(const TLuaChunk& Script); [[nodiscard]] std::shared_ptr EnqueueFunctionCall(const std::string& FunctionName, const std::vector& Args, const std::string& EventName); [[nodiscard]] std::shared_ptr EnqueueFunctionCallFromCustomEvent(const std::string& FunctionName, const std::vector& Args, const std::string& EventName, CallStrategy Strategy); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp new file mode 100644 index 00000000..3ecdc41a --- /dev/null +++ b/src/HttpAsync.cpp @@ -0,0 +1,477 @@ +// BeamMP, the BeamNG.drive multiplayer mod. +// Copyright (C) 2024 BeamMP Ltd., BeamMP team and contributors. +// +// BeamMP Ltd. can be contacted by electronic mail via contact@beammp.com. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +#include "HttpAsync.h" +#include "httplib.h" +#include "Common.h" +#include "LuaAPI.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace HttpAsync { + +struct PendingRequest { + lua_State* L; + int callbackRef = LUA_REFNIL; + int progressRef = LUA_REFNIL; + std::atomic abandoned{false}; +}; + +static std::deque g_Results; +static std::mutex g_Mutex; +static std::atomic g_ShuttingDown{false}; +static std::unique_ptr g_ThreadPool; +static std::atomic g_NextRequestId{1}; +static std::map> g_PendingRequests; + +const int THREAD_POOL_SIZE = 8; + +static std::string ToLower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(),[](unsigned char c){ return static_cast(std::tolower(c)); }); + return s; +} + +static void PushResult(HttpResult res) { + if (g_ShuttingDown.load()) return; + std::lock_guard lock(g_Mutex); + g_Results.push_back(std::move(res)); +} + +static int MakeRef(sol::object obj) { + if (!obj.is()) return LUA_REFNIL; + lua_State* L = obj.lua_state(); + obj.push(); + return luaL_ref(L, LUA_REGISTRYINDEX); +} + +static void ExtractHeaders(const httplib::Headers& source, std::map& dest) { + for (const auto& [k, v] : source) { + if (dest.count(k)) dest[k] += ", " + v; + else dest[k] = v; + } +} + +static bool SetupClient(const std::string& url, int timeout, std::unique_ptr& outClient, std::string& outPath) { + static const std::regex url_regex(R"(^(https?://[^/]+)(/.*)?$)", std::regex::extended); + std::smatch match; + if (!std::regex_match(url, match, url_regex)) return false; + + outClient = std::make_unique(match[1].str()); + outPath = match[2].length() == 0 ? "/" : match[2].str(); + + outClient->set_connection_timeout(timeout, 0); + outClient->set_read_timeout(timeout, 0); + outClient->set_follow_location(true); + outClient->enable_server_certificate_verification(true); + return true; +} + +template +static void EnqueueTask(lua_State* L, int cbRef, int progRef, Func&& task) { + if (g_ShuttingDown.load() || !g_ThreadPool) { + if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); + if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); + return; + } + + uint64_t reqId = g_NextRequestId++; + auto info = std::make_shared(); + info->L = L; + info->callbackRef = cbRef; + info->progressRef = progRef; + + { + std::lock_guard lock(g_Mutex); + g_PendingRequests[reqId] = info; + } + + g_ThreadPool->enqueue([reqId, info, task = std::forward(task)]() { + try { + task(reqId, info); + } catch (const std::exception& e) { + HttpResult res; + res.type = HttpResult::Type::COMPLETE; + res.requestId = reqId; + res.body = std::string("Internal Error: ") + e.what(); + PushResult(std::move(res)); + } + }); +} + +static void Dispatch(std::string method, std::string url, std::map headers, + std::string body, int timeout, lua_State* L, int cbRef, int progRef) { + + EnqueueTask(L, cbRef, progRef,[=, b = std::move(body), hMap = std::move(headers)] + (uint64_t reqId, std::shared_ptr info) { + std::string path; + std::unique_ptr cli; + if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL Format"); + + httplib::Headers h; + bool hasUA = false; + std::string cType = "application/json"; + + for (auto const&[key, val] : hMap) { + if (ToLower(key) == "user-agent") hasUA = true; + if (ToLower(key) == "content-type") { + cType = val; + if (method == "POST" || method == "PUT" || method == "PATCH") continue; + } + h.emplace(key, val); + } + if (!hasUA) h.emplace("User-Agent", "BeamMP-Server/1.0"); + + auto lastProg = std::chrono::steady_clock::now(); + auto prog_func = [&](uint64_t len, uint64_t total) { + if (g_ShuttingDown.load() || info->abandoned.load()) return false; + if (info->progressRef != LUA_REFNIL) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { + HttpResult res; res.type = HttpResult::Type::PROGRESS; res.requestId = reqId; + res.current = static_cast(len); res.total = static_cast(total); + PushResult(std::move(res)); lastProg = now; + } + } + return true; + }; + + httplib::Result response; + if (method == "POST") response = cli->Post(path.c_str(), h, b, cType.c_str()); + else if (method == "PUT") response = cli->Put(path.c_str(), h, b, cType.c_str()); + else if (method == "PATCH") response = cli->Patch(path.c_str(), h, b, cType.c_str()); + else if (method == "DELETE") response = cli->Delete(path.c_str(), h); + else if (method == "HEAD") response = cli->Head(path.c_str(), h); + else response = cli->Get(path.c_str(), h, prog_func); + + HttpResult res; + res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; + if (response) { + res.status = response->status; + res.body = std::move(response->body); + ExtractHeaders(response->headers, res.headers); + } else { + res.status = 0; + res.body = "Network Error: " + httplib::to_string(response.error()); + } + PushResult(std::move(res)); + }); +} + +AsyncHttpProxy::AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders) : mBaseUrl(std::move(baseUrl)) { + if (defaultHeaders != sol::lua_nil && defaultHeaders.valid()) { + for (auto const& pair : defaultHeaders) { + if (pair.first.is() && pair.second.is()) + mDefaultHeaders[pair.first.as()] = pair.second.as(); + } + } +} + +void AsyncHttpProxy::SetTimeout(int seconds) { mTimeoutSeconds = seconds; } + +std::map AsyncHttpProxy::PrepareHeaders(sol::object overrides) { + auto finalHeaders = mDefaultHeaders; + if (overrides.is()) { + for (auto const& pair : overrides.as()) { + if (pair.first.is() && pair.second.is()) + finalHeaders[pair.first.as()] = pair.second.as(); + } + } + return finalHeaders; +} + +void AsyncHttpProxy::PreparePayload(sol::object data, sol::object overrides, std::string& outBody, std::map& outHeaders) { + outHeaders = PrepareHeaders(overrides); + bool hasCT = false; + for (const auto& [k, v] : outHeaders) if (ToLower(k) == "content-type") hasCT = true; + + if (data.is()) { + outBody = LuaAPI::MP::JsonEncode(data.as()); + if (!hasCT) outHeaders["Content-Type"] = "application/json"; + } else { + outBody = data.is() ? data.as() : ""; + } +} + +void AsyncHttpProxy::Get(std::string ep, sol::object h, sol::function cb, sol::object prog) { + Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), MakeRef(prog)); +} + +void AsyncHttpProxy::Post(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +void AsyncHttpProxy::Put(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +void AsyncHttpProxy::Patch(std::string ep, sol::object data, sol::object h, sol::function cb) { + std::string body; std::map headers; + PreparePayload(data, h, body, headers); + Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +void AsyncHttpProxy::Delete(std::string ep, sol::object h, sol::function cb) { + Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +void AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) { + Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +} + +void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { + auto hMap = PrepareHeaders(headers); + std::string url = mBaseUrl + ep; + int timeout = mTimeoutSeconds; + + EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL,[=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + std::string path; + std::unique_ptr cli; + if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL"); + if (!fs::exists(filePath)) throw std::runtime_error("File not found"); + + auto file_stream = std::make_shared(filePath, std::ios::binary); + if (!file_stream || !file_stream->is_open()) throw std::runtime_error("Could not open file"); + + httplib::UploadFormDataItems regular_items; + httplib::FormDataProviderItems provider_items = { + { + fieldName, + [file_stream](size_t offset, httplib::DataSink &sink) { + if (static_cast(file_stream->tellg()) != offset) { + file_stream->clear(); + file_stream->seekg(offset, std::ios::beg); + } + + char buffer[8192]; + file_stream->read(buffer, sizeof(buffer)); + std::streamsize read_bytes = file_stream->gcount(); + if (read_bytes > 0) sink.write(buffer, static_cast(read_bytes)); + if (file_stream->eof()) sink.done(); + return true; + }, + fs::path(filePath).filename().string(), + "application/octet-stream" + } + }; + + httplib::Headers finalH; + for (auto const& [key, val] : hMap) { + if (ToLower(key) == "content-type") continue; + finalH.emplace(key, val); + } + + auto response = cli->Post(path.c_str(), finalH, regular_items, provider_items); + + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; + if (response) { + res.status = response->status; + res.body = std::move(response->body); + ExtractHeaders(response->headers, res.headers); + } else { + res.status = 0; + res.body = "Upload Failed: " + httplib::to_string(response.error()); + } + PushResult(std::move(res)); + }); +} + +void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { + std::string url = mBaseUrl + ep; + int timeout = mTimeoutSeconds; + auto hMap = PrepareHeaders(sol::lua_nil); + + EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog),[=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + std::string path; + std::unique_ptr cli; + if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL"); + + std::ofstream ofs(savePath, std::ios::binary); + if (!ofs) throw std::runtime_error("Could not open file for writing"); + + httplib::Headers finalH; + for (auto const&[key, val] : hMap) { + finalH.emplace(key, val); + } + + int status_code = 0; std::map resHeaders; + auto lastProg = std::chrono::steady_clock::now(); + + auto res = cli->Get(path.c_str(), finalH, + [&](const httplib::Response &r) { + status_code = r.status; + ExtractHeaders(r.headers, resHeaders); + return !g_ShuttingDown.load() && !info->abandoned.load(); + }, + [&](const char *b, size_t l) { + if (g_ShuttingDown.load() || info->abandoned.load()) return false; + ofs.write(b, static_cast(l)); + return true; + }, + [&](uint64_t len, uint64_t total) { + if (info->progressRef != LUA_REFNIL && !info->abandoned.load()) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { + HttpResult pres; pres.type = HttpResult::Type::PROGRESS; pres.requestId = reqId; + pres.current = static_cast(len); pres.total = static_cast(total); + PushResult(std::move(pres)); lastProg = now; + } + } + return true; + } + ); + ofs.close(); + + HttpResult fres; fres.type = HttpResult::Type::COMPLETE; fres.requestId = reqId; + fres.status = status_code; + fres.body = res ? "Success" : "Download Failed: " + httplib::to_string(res.error()); + fres.headers = resHeaders; + PushResult(std::move(fres)); + }); +} + +void RegisterBindings(sol::state_view& lua) { + lua.new_usertype("AsyncHttp", sol::no_constructor, + "SetTimeout", &AsyncHttpProxy::SetTimeout, + "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), + "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), + "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), + "Put", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Put(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Put), + "Patch", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Patch(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Patch), + "Delete", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { self.Delete(ep, sol::nil, cb); }, &AsyncHttpProxy::Delete), + "Head", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { self.Head(ep, sol::nil, cb); }, &AsyncHttpProxy::Head), + "Download", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string p, sol::function cb) { self.Download(ep, p, cb, sol::nil); }, &AsyncHttpProxy::Download) + ); + + lua["AsyncHttp"]["new"] = sol::overload([](std::string url) { return std::make_shared(url, sol::table(sol::lua_nil)); },[](std::string url, sol::table headers) { return std::make_shared(url, headers); } + ); +} + +void Update(sol::state_view& lua) { + lua_State* L = lua.lua_state(); + std::deque toProcess; + + { + std::lock_guard lock(g_Mutex); + auto it = g_Results.begin(); + while (it != g_Results.end()) { + auto reqIt = g_PendingRequests.find(it->requestId); + if (reqIt == g_PendingRequests.end()) { + it = g_Results.erase(it); + } else if (reqIt->second->L == L) { + toProcess.push_back(std::move(*it)); + it = g_Results.erase(it); + } else { + ++it; + } + } + } + + for (const auto& res : toProcess) { + std::shared_ptr info; + { + std::lock_guard lock(g_Mutex); + if (g_PendingRequests.count(res.requestId)) info = g_PendingRequests[res.requestId]; + } + + if (!info || info->abandoned.load()) continue; + + if (res.type == HttpResult::Type::PROGRESS) { + if (info->progressRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, info->progressRef); + sol::protected_function prog = sol::stack::pop(L); + if (prog.valid()) { + auto r = prog(res.current, res.total); + if (!r.valid()) beammp_lua_errorf("AsyncHttp Progress Error: {}", sol::error(r).what()); + } + } + } else { + if (info->callbackRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, info->callbackRef); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { + auto r = cb(res.status, res.body, res.headers); + if (!r.valid()) beammp_lua_errorf("AsyncHttp Callback Error: {}", sol::error(r).what()); + } + } + + if (info->callbackRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); + info->callbackRef = LUA_REFNIL; + } + if (info->progressRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); + info->progressRef = LUA_REFNIL; + } + + std::lock_guard lock(g_Mutex); + g_PendingRequests.erase(res.requestId); + } + } +} + +void CleanupState(lua_State* L) { + std::lock_guard lock(g_Mutex); + for (auto it = g_PendingRequests.begin(); it != g_PendingRequests.end(); ) { + if (it->second->L == L) { + it->second->abandoned.store(true); + + if (it->second->callbackRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, it->second->callbackRef); + it->second->callbackRef = LUA_REFNIL; + } + if (it->second->progressRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, it->second->progressRef); + it->second->progressRef = LUA_REFNIL; + } + + it = g_PendingRequests.erase(it); + } else { + ++it; + } + } +} + +void Init() { + g_ShuttingDown.store(false); + g_ThreadPool = std::make_unique(THREAD_POOL_SIZE); +} + +void Shutdown() { + g_ShuttingDown.store(true); + if (g_ThreadPool) g_ThreadPool->shutdown(); + g_ThreadPool.reset(); + std::lock_guard lock(g_Mutex); + g_PendingRequests.clear(); + g_Results.clear(); +} + +} // namespace HttpAsync \ No newline at end of file diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index d7bee3e6..2fd44b5a 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -21,6 +21,7 @@ #include "Common.h" #include "CustomAssert.h" #include "Http.h" +#include "HttpAsync.h" #include "LuaAPI.h" #include "Env.h" #include "Profiling.h" @@ -41,6 +42,7 @@ TLuaEngine::TLuaEngine() : mResourceServerPath(fs::path(Application::Settings.getAsString(Settings::Key::General_ResourceFolder)) / "Server") { Application::SetSubsystemStatus("LuaEngine", Application::Status::Starting); LuaAPI::MP::Engine = this; + HttpAsync::Init(); if (!fs::exists(Application::Settings.getAsString(Settings::Key::General_ResourceFolder))) { fs::create_directory(Application::Settings.getAsString(Settings::Key::General_ResourceFolder)); } @@ -49,6 +51,7 @@ TLuaEngine::TLuaEngine() } Application::RegisterShutdownHandler([&] { Application::SetSubsystemStatus("LuaEngine", Application::Status::ShuttingDown); + HttpAsync::Shutdown(); if (mThread.joinable()) { mThread.join(); } @@ -1070,9 +1073,21 @@ TLuaEngine::StateThreadData::StateThreadData(const std::string& Name, TLuaStateI FSTable.set_function("ListDirectories", [this](const std::string& Path) { return Lua_FS_ListDirectories(Path); }); + HttpAsync::RegisterBindings(mStateView); Start(); } +TLuaEngine::StateThreadData::~StateThreadData() noexcept { + HttpAsync::CleanupState(mState); + + beammp_debug("\"" + mStateId + "\" destroyed"); + + if (mState) { + lua_close(mState); + mState = nullptr; + } +} + std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLuaChunk& Script) { std::unique_lock Lock(mStateExecuteQueueMutex); auto Result = std::make_shared(); @@ -1119,6 +1134,7 @@ void TLuaEngine::StateThreadData::RegisterEvent(const std::string& EventName, co void TLuaEngine::StateThreadData::operator()() { RegisterThread("Lua:" + mStateId); while (!Application::IsShuttingDown()) { + HttpAsync::Update(mStateView); { // StateExecuteQueue Scope std::unique_lock Lock(mStateExecuteQueueMutex); if (!mStateExecuteQueue.empty()) { From 713417371cc93f000cd1066f1b746f93fefe5db4 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Mon, 27 Apr 2026 16:32:36 +0200 Subject: [PATCH 02/11] Add SSL verification option to AsyncHttpProxy and add default UA header to AsyncHttpProxy::Download and AsyncHttpProxy::PostFile Co-authored-by: Copilot --- include/HttpAsync.h | 2 ++ src/HttpAsync.cpp | 46 ++++++++++++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/include/HttpAsync.h b/include/HttpAsync.h index b281d67d..cd3e2610 100644 --- a/include/HttpAsync.h +++ b/include/HttpAsync.h @@ -44,6 +44,7 @@ namespace HttpAsync { ~AsyncHttpProxy() = default; void SetTimeout(int seconds); + void VerifySSL(bool verify) { mVerifySSL = verify; } void Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); void Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); @@ -62,6 +63,7 @@ namespace HttpAsync { std::string mBaseUrl; std::map mDefaultHeaders; int mTimeoutSeconds = 30; + bool mVerifySSL = true; }; void Init(); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 3ecdc41a..8ecdc129 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -76,7 +76,7 @@ static void ExtractHeaders(const httplib::Headers& source, std::map& outClient, std::string& outPath) { +static bool SetupClient(const std::string& url, int timeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { static const std::regex url_regex(R"(^(https?://[^/]+)(/.*)?$)", std::regex::extended); std::smatch match; if (!std::regex_match(url, match, url_regex)) return false; @@ -87,7 +87,7 @@ static bool SetupClient(const std::string& url, int timeout, std::unique_ptrset_connection_timeout(timeout, 0); outClient->set_read_timeout(timeout, 0); outClient->set_follow_location(true); - outClient->enable_server_certificate_verification(true); + outClient->enable_server_certificate_verification(verifySSL); return true; } @@ -124,13 +124,13 @@ static void EnqueueTask(lua_State* L, int cbRef, int progRef, Func&& task) { } static void Dispatch(std::string method, std::string url, std::map headers, - std::string body, int timeout, lua_State* L, int cbRef, int progRef) { + std::string body, int timeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { EnqueueTask(L, cbRef, progRef,[=, b = std::move(body), hMap = std::move(headers)] (uint64_t reqId, std::shared_ptr info) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL Format"); + if (!SetupClient(url, timeout, verifySSL, cli, path)) throw std::runtime_error("Invalid URL Format"); httplib::Headers h; bool hasUA = false; @@ -218,44 +218,45 @@ void AsyncHttpProxy::PreparePayload(sol::object data, sol::object overrides, std } void AsyncHttpProxy::Get(std::string ep, sol::object h, sol::function cb, sol::object prog) { - Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), MakeRef(prog)); + Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), MakeRef(prog)); } void AsyncHttpProxy::Post(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } void AsyncHttpProxy::Put(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } void AsyncHttpProxy::Patch(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } void AsyncHttpProxy::Delete(std::string ep, sol::object h, sol::function cb) { - Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } void AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) { - Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { auto hMap = PrepareHeaders(headers); std::string url = mBaseUrl + ep; int timeout = mTimeoutSeconds; + bool verify = mVerifySSL; - EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL,[=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL"); + if (!SetupClient(url, timeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); if (!fs::exists(filePath)) throw std::runtime_error("File not found"); auto file_stream = std::make_shared(filePath, std::ios::binary); @@ -265,10 +266,12 @@ void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string httplib::FormDataProviderItems provider_items = { { fieldName, - [file_stream](size_t offset, httplib::DataSink &sink) { + [file_stream, info](size_t offset, httplib::DataSink &sink) { + if (info->abandoned.load()) return false; + if (static_cast(file_stream->tellg()) != offset) { file_stream->clear(); - file_stream->seekg(offset, std::ios::beg); + file_stream->seekg(static_cast(offset), std::ios::beg); } char buffer[8192]; @@ -284,10 +287,14 @@ void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string }; httplib::Headers finalH; + bool hasUA = false; for (auto const& [key, val] : hMap) { - if (ToLower(key) == "content-type") continue; + std::string kLower = ToLower(key); + if (kLower == "user-agent") hasUA = true; + if (kLower == "content-type") continue; // httplib generates this for us in PostFile finalH.emplace(key, val); } + if (!hasUA) finalH.emplace("User-Agent", "BeamMP-Server/1.0"); auto response = cli->Post(path.c_str(), finalH, regular_items, provider_items); @@ -307,20 +314,24 @@ void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { std::string url = mBaseUrl + ep; int timeout = mTimeoutSeconds; + bool verify = mVerifySSL; auto hMap = PrepareHeaders(sol::lua_nil); - EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog),[=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, cli, path)) throw std::runtime_error("Invalid URL"); + if (!SetupClient(url, timeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); std::ofstream ofs(savePath, std::ios::binary); if (!ofs) throw std::runtime_error("Could not open file for writing"); httplib::Headers finalH; + bool hasUA = false; for (auto const&[key, val] : hMap) { + if (ToLower(key) == "user-agent") hasUA = true; finalH.emplace(key, val); } + if (!hasUA) finalH.emplace("User-Agent", "BeamMP-Server/1.0"); int status_code = 0; std::map resHeaders; auto lastProg = std::chrono::steady_clock::now(); @@ -361,6 +372,7 @@ void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::functio void RegisterBindings(sol::state_view& lua) { lua.new_usertype("AsyncHttp", sol::no_constructor, "SetTimeout", &AsyncHttpProxy::SetTimeout, + "VerifySSL", &AsyncHttpProxy::VerifySSL, "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), From 52c2ee405ad6c3b102593491eb3185a98edf6813 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Mon, 27 Apr 2026 19:10:21 +0200 Subject: [PATCH 03/11] Add method to set default headers in AsyncHttpProxy --- include/HttpAsync.h | 1 + src/HttpAsync.cpp | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/include/HttpAsync.h b/include/HttpAsync.h index cd3e2610..bb0e33ec 100644 --- a/include/HttpAsync.h +++ b/include/HttpAsync.h @@ -45,6 +45,7 @@ namespace HttpAsync { void SetTimeout(int seconds); void VerifySSL(bool verify) { mVerifySSL = verify; } + void SetDefaultHeaders(sol::table headers); void Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); void Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 8ecdc129..0a0d0be5 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -369,10 +369,22 @@ void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::functio }); } +void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { + mDefaultHeaders.clear(); + if (headers != sol::lua_nil && headers.valid()) { + for (auto const& pair : headers) { + if (pair.first.is() && pair.second.is()) { + mDefaultHeaders[pair.first.as()] = pair.second.as(); + } + } + } +} + void RegisterBindings(sol::state_view& lua) { lua.new_usertype("AsyncHttp", sol::no_constructor, "SetTimeout", &AsyncHttpProxy::SetTimeout, "VerifySSL", &AsyncHttpProxy::VerifySSL, + "SetDefaultHeaders", &AsyncHttpProxy::SetDefaultHeaders, "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), From 7209cf0122785b5a9624676cc44ca820afbfe78b Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Mon, 27 Apr 2026 21:28:55 +0200 Subject: [PATCH 04/11] Update vcpkg baseline to latest tag 2026.03.18 to update cpp-httplib that had a bug in it that was resolved in more recent version --- vcpkg | 2 +- vcpkg.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vcpkg b/vcpkg index 5bf0c552..c3867e71 160000 --- a/vcpkg +++ b/vcpkg @@ -1 +1 @@ -Subproject commit 5bf0c55239da398b8c6f450818c9e28d36bf9966 +Subproject commit c3867e714dd3a51c272826eea77267876517ed99 diff --git a/vcpkg.json b/vcpkg.json index 1814b289..b4d95e35 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -27,5 +27,5 @@ "version": "5.3.5#6" } ], - "builtin-baseline": "5bf0c55239da398b8c6f450818c9e28d36bf9966" + "builtin-baseline": "c3867e714dd3a51c272826eea77267876517ed99" } From 63944176e4789ff1105969bd670b3ad94e6aec44 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Mon, 27 Apr 2026 23:35:02 +0200 Subject: [PATCH 05/11] overhaul HttpAsync with cancellation, sandboxing, and improved header handling Co-authored-by: Copilot --- include/HttpAsync.h | 24 +++--- src/HttpAsync.cpp | 180 +++++++++++++++++++++++++++++--------------- 2 files changed, 134 insertions(+), 70 deletions(-) diff --git a/include/HttpAsync.h b/include/HttpAsync.h index bb0e33ec..fd84488c 100644 --- a/include/HttpAsync.h +++ b/include/HttpAsync.h @@ -32,7 +32,7 @@ namespace HttpAsync { int status; std::string body; - std::map headers; + std::map> headers; long long current; long long total; @@ -43,19 +43,20 @@ namespace HttpAsync { AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders); ~AsyncHttpProxy() = default; - void SetTimeout(int seconds); + void SetConnectTimeout(int seconds); + void SetReadTimeout(int seconds); void VerifySSL(bool verify) { mVerifySSL = verify; } void SetDefaultHeaders(sol::table headers); - void Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); - void Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); - void Put(std::string endpoint, sol::object data, sol::object headers, sol::function cb); - void Patch(std::string endpoint, sol::object data, sol::object headers, sol::function cb); - void Delete(std::string endpoint, sol::object headers, sol::function cb); - void Head(std::string endpoint, sol::object headers, sol::function cb); + sol::table Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); + sol::table Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Put(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Patch(std::string endpoint, sol::object data, sol::object headers, sol::function cb); + sol::table Delete(std::string endpoint, sol::object headers, sol::function cb); + sol::table Head(std::string endpoint, sol::object headers, sol::function cb); - void Download(std::string endpoint, std::string savePath, sol::function cb, sol::object prog); - void PostFile(std::string endpoint, std::string fieldName, std::string filePath, sol::object headers, sol::function cb); + sol::table Download(std::string endpoint, std::string savePath, sol::function cb, sol::object prog); + sol::table PostFile(std::string endpoint, std::string fieldName, std::string filePath, sol::object headers, sol::function cb); private: std::map PrepareHeaders(sol::object overrides); @@ -63,7 +64,8 @@ namespace HttpAsync { std::string mBaseUrl; std::map mDefaultHeaders; - int mTimeoutSeconds = 30; + int mConnectTimeoutSeconds = 5; + int mReadTimeoutSeconds = 30; bool mVerifySSL = true; }; diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 0a0d0be5..526bc77c 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -48,9 +48,35 @@ static std::atomic g_ShuttingDown{false}; static std::unique_ptr g_ThreadPool; static std::atomic g_NextRequestId{1}; static std::map> g_PendingRequests; +static std::map g_StateRequestCount; +static std::mutex g_LimitMutex; +const int MAX_REQUESTS_PER_STATE = 20; const int THREAD_POOL_SIZE = 8; +static sol::table CreateHandle(lua_State* L, std::shared_ptr info) { + sol::state_view lua(L); + sol::table handle = lua.create_table(); + if (info) { + handle["Cancel"] = [info]() { info->abandoned.store(true); }; + handle["IsActive"] = [info]() { return !info->abandoned.load(); }; + + // This allows the modder to attach a progress listener to the handle + handle["OnProgress"] = [L, info](sol::object func) { + if (func.is()) { + if (info->progressRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); + } + func.push(); + info->progressRef = luaL_ref(L, LUA_REGISTRYINDEX); + } + }; + } else { + handle["Error"] = "Rate limited or Shutdown"; + } + return handle; +} + static std::string ToLower(std::string s) { std::transform(s.begin(), s.end(), s.begin(),[](unsigned char c){ return static_cast(std::tolower(c)); }); return s; @@ -69,14 +95,13 @@ static int MakeRef(sol::object obj) { return luaL_ref(L, LUA_REGISTRYINDEX); } -static void ExtractHeaders(const httplib::Headers& source, std::map& dest) { +static void ExtractHeaders(const httplib::Headers& source, std::map>& dest) { for (const auto& [k, v] : source) { - if (dest.count(k)) dest[k] += ", " + v; - else dest[k] = v; + dest[k].push_back(v); } } -static bool SetupClient(const std::string& url, int timeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { +static bool SetupClient(const std::string& url, int connectTimeout, int readTimeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { static const std::regex url_regex(R"(^(https?://[^/]+)(/.*)?$)", std::regex::extended); std::smatch match; if (!std::regex_match(url, match, url_regex)) return false; @@ -84,19 +109,29 @@ static bool SetupClient(const std::string& url, int timeout, bool verifySSL, std outClient = std::make_unique(match[1].str()); outPath = match[2].length() == 0 ? "/" : match[2].str(); - outClient->set_connection_timeout(timeout, 0); - outClient->set_read_timeout(timeout, 0); + outClient->set_connection_timeout(connectTimeout, 0); + outClient->set_read_timeout(readTimeout, 0); + outClient->set_write_timeout(readTimeout, 0); outClient->set_follow_location(true); outClient->enable_server_certificate_verification(verifySSL); return true; } -template -static void EnqueueTask(lua_State* L, int cbRef, int progRef, Func&& task) { +static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int progRef, std::function)> task) { if (g_ShuttingDown.load() || !g_ThreadPool) { if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); - return; + return nullptr; + } + { + std::lock_guard lock(g_LimitMutex); + if (g_StateRequestCount[L] >= MAX_REQUESTS_PER_STATE) { + beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", MAX_REQUESTS_PER_STATE); + if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); + if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); + return nullptr; + } + g_StateRequestCount[L]++; } uint64_t reqId = g_NextRequestId++; @@ -110,27 +145,24 @@ static void EnqueueTask(lua_State* L, int cbRef, int progRef, Func&& task) { g_PendingRequests[reqId] = info; } - g_ThreadPool->enqueue([reqId, info, task = std::forward(task)]() { - try { - task(reqId, info); - } catch (const std::exception& e) { - HttpResult res; - res.type = HttpResult::Type::COMPLETE; - res.requestId = reqId; - res.body = std::string("Internal Error: ") + e.what(); - PushResult(std::move(res)); - } + g_ThreadPool->enqueue([L, reqId, info, task = std::move(task)]() { + task(reqId, info); + + std::lock_guard lock(g_LimitMutex); + g_StateRequestCount[L]--; }); + + return info; } -static void Dispatch(std::string method, std::string url, std::map headers, - std::string body, int timeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { +static sol::table Dispatch(std::string method, std::string url, std::map headers, + std::string body, int connectTimeout, int readTimeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { - EnqueueTask(L, cbRef, progRef,[=, b = std::move(body), hMap = std::move(headers)] - (uint64_t reqId, std::shared_ptr info) { + auto info = EnqueueTask(L, cbRef, progRef, [=, b = std::move(body), hMap = std::move(headers)] + (uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, verifySSL, cli, path)) throw std::runtime_error("Invalid URL Format"); + if (!SetupClient(url, connectTimeout, readTimeout, verifySSL, cli, path)) return; httplib::Headers h; bool hasUA = false; @@ -148,8 +180,8 @@ static void Dispatch(std::string method, std::string url, std::mapabandoned.load()) return false; - if (info->progressRef != LUA_REFNIL) { + if (g_ShuttingDown.load() || pReq->abandoned.load()) return false; + if (pReq->progressRef != LUA_REFNIL) { auto now = std::chrono::steady_clock::now(); if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { HttpResult res; res.type = HttpResult::Type::PROGRESS; res.requestId = reqId; @@ -161,13 +193,15 @@ static void Dispatch(std::string method, std::string url, std::mapPost(path.c_str(), h, b, cType.c_str()); - else if (method == "PUT") response = cli->Put(path.c_str(), h, b, cType.c_str()); - else if (method == "PATCH") response = cli->Patch(path.c_str(), h, b, cType.c_str()); - else if (method == "DELETE") response = cli->Delete(path.c_str(), h); + if (method == "POST") response = cli->Post(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "PUT") response = cli->Put(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "PATCH") response = cli->Patch(path.c_str(), h, b, cType.c_str(), prog_func); + else if (method == "DELETE") response = cli->Delete(path.c_str(), h, prog_func); else if (method == "HEAD") response = cli->Head(path.c_str(), h); else response = cli->Get(path.c_str(), h, prog_func); + if (pReq->abandoned.load()) return; + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; if (response) { @@ -180,6 +214,8 @@ static void Dispatch(std::string method, std::string url, std::map AsyncHttpProxy::PrepareHeaders(sol::object overrides) { auto finalHeaders = mDefaultHeaders; @@ -217,46 +259,47 @@ void AsyncHttpProxy::PreparePayload(sol::object data, sol::object overrides, std } } -void AsyncHttpProxy::Get(std::string ep, sol::object h, sol::function cb, sol::object prog) { - Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), MakeRef(prog)); +sol::table AsyncHttpProxy::Get(std::string ep, sol::object h, sol::function cb, sol::object prog) { + return Dispatch("GET", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), MakeRef(prog)); } -void AsyncHttpProxy::Post(std::string ep, sol::object data, sol::object h, sol::function cb) { +sol::table AsyncHttpProxy::Post(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + return Dispatch("POST", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } -void AsyncHttpProxy::Put(std::string ep, sol::object data, sol::object h, sol::function cb) { +sol::table AsyncHttpProxy::Put(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + return Dispatch("PUT", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } -void AsyncHttpProxy::Patch(std::string ep, sol::object data, sol::object h, sol::function cb) { +sol::table AsyncHttpProxy::Patch(std::string ep, sol::object data, sol::object h, sol::function cb) { std::string body; std::map headers; PreparePayload(data, h, body, headers); - Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); + return Dispatch("PATCH", mBaseUrl + ep, headers, std::move(body), mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } -void AsyncHttpProxy::Delete(std::string ep, sol::object h, sol::function cb) { - Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +sol::table AsyncHttpProxy::Delete(std::string ep, sol::object h, sol::function cb) { + return Dispatch("DELETE", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } -void AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) { - Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); +sol::table AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) { + return Dispatch("HEAD", mBaseUrl + ep, PrepareHeaders(h), "", mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cb.lua_state(), MakeRef(cb), LUA_REFNIL); } -void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { +sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { auto hMap = PrepareHeaders(headers); std::string url = mBaseUrl + ep; - int timeout = mTimeoutSeconds; + int connectTimeout = mConnectTimeoutSeconds; + int readTimeout = mReadTimeoutSeconds; bool verify = mVerifySSL; - EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); + if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); if (!fs::exists(filePath)) throw std::runtime_error("File not found"); auto file_stream = std::make_shared(filePath, std::ios::binary); @@ -266,8 +309,8 @@ void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string httplib::FormDataProviderItems provider_items = { { fieldName, - [file_stream, info](size_t offset, httplib::DataSink &sink) { - if (info->abandoned.load()) return false; + [file_stream, pReq](size_t offset, httplib::DataSink &sink) { + if (pReq->abandoned.load()) return false; if (static_cast(file_stream->tellg()) != offset) { file_stream->clear(); @@ -309,18 +352,22 @@ void AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string } PushResult(std::move(res)); }); + + return CreateHandle(cb.lua_state(), info); } -void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { +sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { std::string url = mBaseUrl + ep; - int timeout = mTimeoutSeconds; + int connectTimeout = mConnectTimeoutSeconds; + int readTimeout = mReadTimeoutSeconds; + bool verify = mVerifySSL; auto hMap = PrepareHeaders(sol::lua_nil); - EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr info) { + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, timeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); + if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); std::ofstream ofs(savePath, std::ios::binary); if (!ofs) throw std::runtime_error("Could not open file for writing"); @@ -333,22 +380,22 @@ void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::functio } if (!hasUA) finalH.emplace("User-Agent", "BeamMP-Server/1.0"); - int status_code = 0; std::map resHeaders; + int status_code = 0; std::map> resHeaders; auto lastProg = std::chrono::steady_clock::now(); auto res = cli->Get(path.c_str(), finalH, [&](const httplib::Response &r) { status_code = r.status; ExtractHeaders(r.headers, resHeaders); - return !g_ShuttingDown.load() && !info->abandoned.load(); + return !g_ShuttingDown.load() && !pReq->abandoned.load(); }, [&](const char *b, size_t l) { - if (g_ShuttingDown.load() || info->abandoned.load()) return false; + if (g_ShuttingDown.load() || pReq->abandoned.load()) return false; ofs.write(b, static_cast(l)); return true; }, [&](uint64_t len, uint64_t total) { - if (info->progressRef != LUA_REFNIL && !info->abandoned.load()) { + if (pReq->progressRef != LUA_REFNIL && !pReq->abandoned.load()) { auto now = std::chrono::steady_clock::now(); if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { HttpResult pres; pres.type = HttpResult::Type::PROGRESS; pres.requestId = reqId; @@ -367,6 +414,8 @@ void AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::functio fres.headers = resHeaders; PushResult(std::move(fres)); }); + + return CreateHandle(cb.lua_state(), info); } void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { @@ -382,7 +431,8 @@ void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { void RegisterBindings(sol::state_view& lua) { lua.new_usertype("AsyncHttp", sol::no_constructor, - "SetTimeout", &AsyncHttpProxy::SetTimeout, + "SetConnectTimeout", &AsyncHttpProxy::SetConnectTimeout, + "SetReadTimeout", &AsyncHttpProxy::SetReadTimeout, "VerifySSL", &AsyncHttpProxy::VerifySSL, "SetDefaultHeaders", &AsyncHttpProxy::SetDefaultHeaders, "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), @@ -442,7 +492,19 @@ void Update(sol::state_view& lua) { lua_rawgeti(L, LUA_REGISTRYINDEX, info->callbackRef); sol::protected_function cb = sol::stack::pop(L); if (cb.valid()) { - auto r = cb(res.status, res.body, res.headers); + sol::table luaHeaders = lua.create_table(); + for (auto const& [name, values] : res.headers) { + if (values.empty()) continue; + + std::string key = ToLower(name); + + if (values.size() > 1 || key == "set-cookie") { + luaHeaders[key] = sol::as_table(values); + } else { + luaHeaders[key] = values[0]; + } + } + auto r = cb(res.status, res.body, luaHeaders); if (!r.valid()) beammp_lua_errorf("AsyncHttp Callback Error: {}", sol::error(r).what()); } } From 1568e48dd2ef6dc1ef432032d8b0a9e11ab49eb4 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Tue, 28 Apr 2026 19:52:13 +0200 Subject: [PATCH 06/11] Add websocket support, unify User-Agent name and fix the http bindings that wasn't returning the handles. Co-authored-by: Copilot --- include/HttpAsync.h | 60 ++++++++++ src/HttpAsync.cpp | 269 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 313 insertions(+), 16 deletions(-) diff --git a/include/HttpAsync.h b/include/HttpAsync.h index fd84488c..a271164f 100644 --- a/include/HttpAsync.h +++ b/include/HttpAsync.h @@ -21,8 +21,16 @@ #include #include #include +#include +#include #include +namespace httplib { + namespace ws { + class WebSocketClient; + } +} + namespace HttpAsync { struct HttpResult { @@ -69,6 +77,58 @@ namespace HttpAsync { bool mVerifySSL = true; }; + enum class WSEventType { OPEN, MESSAGE, CLOSE, ERROR_EVENT }; + + struct WSEvent { + WSEventType type; + std::string payload; + int closeCode; + }; + + class AsyncWebSocket : public std::enable_shared_from_this { + public: + AsyncWebSocket(std::string url, sol::table headers, lua_State* state); + ~AsyncWebSocket(); + + void Connect(); + void Send(const std::string& data); + void Close(); + void VerifySSL(bool verify); + + void OnOpen(sol::object cb); + void OnMessage(sol::object cb); + void OnClose(sol::object cb); + void OnError(sol::object cb); + + void ProcessEvents(); + lua_State* GetLuaState() const { return L; } + void Abandon(); + + private: + std::string mUrl; + lua_State* L; + std::map mHeaders; + + bool mVerifySSL = true; + + std::thread mThread; + std::atomic mIsRunning{false}; + std::atomic mAbandoned{false}; + + httplib::ws::WebSocketClient* mClient = nullptr; + std::mutex mClientMutex; + + std::queue mEvents; + std::mutex mMutex; + + int mOnOpenRef = LUA_REFNIL; + int mOnMessageRef = LUA_REFNIL; + int mOnCloseRef = LUA_REFNIL; + int mOnErrorRef = LUA_REFNIL; + + void PushEvent(WSEvent ev); + }; + void Init(); void Shutdown(); void Update(sol::state_view& lua); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 526bc77c..1074a86f 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -50,10 +50,14 @@ static std::atomic g_NextRequestId{1}; static std::map> g_PendingRequests; static std::map g_StateRequestCount; static std::mutex g_LimitMutex; +static std::vector> g_WebSockets; +static std::mutex g_WsMutex; const int MAX_REQUESTS_PER_STATE = 20; const int THREAD_POOL_SIZE = 8; +static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; + static sol::table CreateHandle(lua_State* L, std::shared_ptr info) { sol::state_view lua(L); sol::table handle = lua.create_table(); @@ -176,7 +180,7 @@ static sol::table Dispatch(std::string method, std::string url, std::mapPost(path.c_str(), finalH, regular_items, provider_items); @@ -378,7 +382,7 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f if (ToLower(key) == "user-agent") hasUA = true; finalH.emplace(key, val); } - if (!hasUA) finalH.emplace("User-Agent", "BeamMP-Server/1.0"); + if (!hasUA) finalH.emplace("User-Agent", DEFAULT_USER_AGENT); int status_code = 0; std::map> resHeaders; auto lastProg = std::chrono::steady_clock::now(); @@ -429,24 +433,214 @@ void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { } } +AsyncWebSocket::AsyncWebSocket(std::string url, sol::table headers, lua_State* state) + : mUrl(std::move(url)), L(state) { + if (headers.valid()) { + for (auto const& pair : headers) { + if (pair.first.is() && pair.second.is()) { + mHeaders.emplace(pair.first.as(), pair.second.as()); + } + } + } +} + +AsyncWebSocket::~AsyncWebSocket() { + Abandon(); + if (mThread.joinable()) { + mThread.join(); + } +} + +void AsyncWebSocket::VerifySSL(bool verify) { + mVerifySSL = verify; +} + +void AsyncWebSocket::Connect() { + if (mIsRunning.exchange(true)) return; + + mThread = std::thread([this]() { + httplib::Headers h; + bool hasUA = false; + + for (const auto& [k, v] : mHeaders) { + if (ToLower(k) == "user-agent") hasUA = true; + h.emplace(k, v); + } + + if (!hasUA) h.emplace("User-Agent", DEFAULT_USER_AGENT); + + httplib::ws::WebSocketClient client(mUrl, h); + + client.enable_server_certificate_verification(mVerifySSL); + + { + std::lock_guard lock(mClientMutex); + if (mAbandoned) return; + mClient = &client; + } + + if (!client.connect()) { + PushEvent({WSEventType::ERROR_EVENT, "Failed to connect", 0}); + mIsRunning = false; + std::lock_guard lock(mClientMutex); + mClient = nullptr; + return; + } + + PushEvent({WSEventType::OPEN, "", 0}); + + std::string msg; + while (mIsRunning && !mAbandoned) { + auto res = client.read(msg); + if (res == httplib::ws::ReadResult::Fail) { + break; + } + PushEvent({WSEventType::MESSAGE, msg, 0}); + msg.clear(); + } + + PushEvent({WSEventType::CLOSE, "Connection closed", 1000}); + mIsRunning = false; + + std::lock_guard lock(mClientMutex); + mClient = nullptr; + }); +} + +void AsyncWebSocket::Send(const std::string& data) { + std::lock_guard lock(mClientMutex); + if (mClient && mClient->is_open()) { + mClient->send(data); + } +} + +void AsyncWebSocket::Close() { + mIsRunning = false; + std::lock_guard lock(mClientMutex); + if (mClient && mClient->is_open()) { + mClient->close(); + } +} + +void AsyncWebSocket::PushEvent(WSEvent ev) { + if (mAbandoned) return; + std::lock_guard lock(mMutex); + mEvents.push(std::move(ev)); +} + +void AsyncWebSocket::OnOpen(sol::object cb) { + if (mOnOpenRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnOpenRef); + mOnOpenRef = MakeRef(cb); +} +void AsyncWebSocket::OnMessage(sol::object cb) { + if (mOnMessageRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnMessageRef); + mOnMessageRef = MakeRef(cb); +} +void AsyncWebSocket::OnClose(sol::object cb) { + if (mOnCloseRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnCloseRef); + mOnCloseRef = MakeRef(cb); +} +void AsyncWebSocket::OnError(sol::object cb) { + if (mOnErrorRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnErrorRef); + mOnErrorRef = MakeRef(cb); +} + +void AsyncWebSocket::ProcessEvents() { + if (mAbandoned) return; + + std::queue events; + { + std::lock_guard lock(mMutex); + std::swap(events, mEvents); + } + + while (!events.empty()) { + auto ev = events.front(); + events.pop(); + + if (mAbandoned) break; + + try { + if (ev.type == WSEventType::OPEN && mOnOpenRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, mOnOpenRef); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { auto r = cb(); if (!r.valid()) beammp_lua_errorf("WS OnOpen Error: {}", sol::error(r).what()); } + } + else if (ev.type == WSEventType::MESSAGE && mOnMessageRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, mOnMessageRef); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { auto r = cb(ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnMessage Error: {}", sol::error(r).what()); } + } + else if (ev.type == WSEventType::CLOSE && mOnCloseRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, mOnCloseRef); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { auto r = cb(ev.closeCode, ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnClose Error: {}", sol::error(r).what()); } + } + else if (ev.type == WSEventType::ERROR_EVENT && mOnErrorRef != LUA_REFNIL) { + lua_rawgeti(L, LUA_REGISTRYINDEX, mOnErrorRef); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { auto r = cb(ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnError Error: {}", sol::error(r).what()); } + } + } catch (const std::exception& e) { + beammp_lua_errorf("WebSocket Exception: {}", e.what()); + } + } +} + +void AsyncWebSocket::Abandon() { + mAbandoned = true; + Close(); + + if (mOnOpenRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnOpenRef); mOnOpenRef = LUA_REFNIL; } + if (mOnMessageRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnMessageRef); mOnMessageRef = LUA_REFNIL; } + if (mOnCloseRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnCloseRef); mOnCloseRef = LUA_REFNIL; } + if (mOnErrorRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnErrorRef); mOnErrorRef = LUA_REFNIL; } +} + void RegisterBindings(sol::state_view& lua) { lua.new_usertype("AsyncHttp", sol::no_constructor, "SetConnectTimeout", &AsyncHttpProxy::SetConnectTimeout, "SetReadTimeout", &AsyncHttpProxy::SetReadTimeout, "VerifySSL", &AsyncHttpProxy::VerifySSL, "SetDefaultHeaders", &AsyncHttpProxy::SetDefaultHeaders, - "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), - "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), - "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), - "Put", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Put(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Put), - "Patch", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { self.Patch(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Patch), - "Delete", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { self.Delete(ep, sol::nil, cb); }, &AsyncHttpProxy::Delete), - "Head", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { self.Head(ep, sol::nil, cb); }, &AsyncHttpProxy::Head), - "Download", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string p, sol::function cb) { self.Download(ep, p, cb, sol::nil); }, &AsyncHttpProxy::Download) + "Get", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object h, sol::function cb) { return self.Get(ep, h, cb, sol::nil); }, &AsyncHttpProxy::Get), + "Post", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Post(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Post), + "PostFile", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string fn, std::string fp, sol::function cb) { return self.PostFile(ep, fn, fp, sol::nil, cb); }, &AsyncHttpProxy::PostFile), + "Put", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Put(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Put), + "Patch", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::object d, sol::function cb) { return self.Patch(ep, d, sol::nil, cb); }, &AsyncHttpProxy::Patch), + "Delete", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { return self.Delete(ep, sol::nil, cb); }, &AsyncHttpProxy::Delete), + "Head", sol::overload([](AsyncHttpProxy& self, std::string ep, sol::function cb) { return self.Head(ep, sol::nil, cb); }, &AsyncHttpProxy::Head), + "Download", sol::overload([](AsyncHttpProxy& self, std::string ep, std::string p, sol::function cb) { return self.Download(ep, p, cb, sol::nil); }, &AsyncHttpProxy::Download) ); lua["AsyncHttp"]["new"] = sol::overload([](std::string url) { return std::make_shared(url, sol::table(sol::lua_nil)); },[](std::string url, sol::table headers) { return std::make_shared(url, headers); } ); + + lua.new_usertype("AsyncWebSocket", sol::no_constructor, + "Connect", &AsyncWebSocket::Connect, + "Send", &AsyncWebSocket::Send, + "Close", &AsyncWebSocket::Close, + "VerifySSL", &AsyncWebSocket::VerifySSL, + "OnOpen", &AsyncWebSocket::OnOpen, + "OnMessage", &AsyncWebSocket::OnMessage, + "OnClose", &AsyncWebSocket::OnClose, + "OnError", &AsyncWebSocket::OnError + ); + + lua["AsyncWebSocket"]["new"] = [](sol::this_state s, std::string url, sol::object headers) { + sol::table h; + + if (headers.is()) { + h = headers.as(); + } + + auto ws = std::make_shared(url, h, s.lua_state()); + + std::lock_guard lock(g_WsMutex); + g_WebSockets.push_back(ws); + + return ws; + }; } void Update(sol::state_view& lua) { @@ -454,7 +648,7 @@ void Update(sol::state_view& lua) { std::deque toProcess; { - std::lock_guard lock(g_Mutex); + std::lock_guard httpLock(g_Mutex); auto it = g_Results.begin(); while (it != g_Results.end()) { auto reqIt = g_PendingRequests.find(it->requestId); @@ -472,7 +666,7 @@ void Update(sol::state_view& lua) { for (const auto& res : toProcess) { std::shared_ptr info; { - std::lock_guard lock(g_Mutex); + std::lock_guard httpLock(g_Mutex); if (g_PendingRequests.count(res.requestId)) info = g_PendingRequests[res.requestId]; } @@ -518,14 +712,34 @@ void Update(sol::state_view& lua) { info->progressRef = LUA_REFNIL; } - std::lock_guard lock(g_Mutex); + std::lock_guard httpLock(g_Mutex); g_PendingRequests.erase(res.requestId); } } + + std::vector> websocketsToUpdate; + { + std::lock_guard wsLock(g_WsMutex); + auto wsIt = g_WebSockets.begin(); + while (wsIt != g_WebSockets.end()) { + if (auto ws = wsIt->lock()) { + if (ws->GetLuaState() == L) { + websocketsToUpdate.push_back(ws); + } + ++wsIt; + } else { + wsIt = g_WebSockets.erase(wsIt); + } + } + } + + for (auto& ws : websocketsToUpdate) { + ws->ProcessEvents(); + } } void CleanupState(lua_State* L) { - std::lock_guard lock(g_Mutex); + std::lock_guard httpLock(g_Mutex); for (auto it = g_PendingRequests.begin(); it != g_PendingRequests.end(); ) { if (it->second->L == L) { it->second->abandoned.store(true); @@ -544,6 +758,21 @@ void CleanupState(lua_State* L) { ++it; } } + + std::lock_guard wsLock(g_WsMutex); + auto wsIt = g_WebSockets.begin(); + while (wsIt != g_WebSockets.end()) { + if (auto ws = wsIt->lock()) { + if (ws->GetLuaState() == L) { + ws->Abandon(); + wsIt = g_WebSockets.erase(wsIt); + } else { + ++wsIt; + } + } else { + wsIt = g_WebSockets.erase(wsIt); + } + } } void Init() { @@ -555,9 +784,17 @@ void Shutdown() { g_ShuttingDown.store(true); if (g_ThreadPool) g_ThreadPool->shutdown(); g_ThreadPool.reset(); - std::lock_guard lock(g_Mutex); + std::lock_guard httpLock(g_Mutex); g_PendingRequests.clear(); g_Results.clear(); + + std::lock_guard wsLock(g_WsMutex); + for (auto& weak_ws : g_WebSockets) { + if (auto ws = weak_ws.lock()) { + ws->Abandon(); + } + } + g_WebSockets.clear(); } } // namespace HttpAsync \ No newline at end of file From f33f4c0e98aa09499e5e19bea5183aa31ea44241 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Tue, 28 Apr 2026 20:15:06 +0200 Subject: [PATCH 07/11] Improve error handling to return detailed responses for invalid URLs and file access issues --- src/HttpAsync.cpp | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 1074a86f..6ff57018 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -82,7 +82,11 @@ static sol::table CreateHandle(lua_State* L, std::shared_ptr inf } static std::string ToLower(std::string s) { - std::transform(s.begin(), s.end(), s.begin(),[](unsigned char c){ return static_cast(std::tolower(c)); }); + for (char &c : s) { + if (c >= 'A' && c <= 'Z') { + c += 32; + } + } return s; } @@ -166,7 +170,9 @@ static sol::table Dispatch(std::string method, std::string url, std::map pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, connectTimeout, readTimeout, verifySSL, cli, path)) return; + if (!SetupClient(url, connectTimeout, readTimeout, verifySSL, cli, path)) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + } httplib::Headers h; bool hasUA = false; @@ -303,11 +309,17 @@ sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std:: auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); - if (!fs::exists(filePath)) throw std::runtime_error("File not found"); + if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + } + if (!fs::exists(filePath)) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "File not found"; PushResult(std::move(res)); return; + } auto file_stream = std::make_shared(filePath, std::ios::binary); - if (!file_stream || !file_stream->is_open()) throw std::runtime_error("Could not open file"); + if (!file_stream || !file_stream->is_open()) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Could not open file"; PushResult(std::move(res)); return; + } httplib::UploadFormDataItems regular_items; httplib::FormDataProviderItems provider_items = { @@ -371,10 +383,14 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; - if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) throw std::runtime_error("Invalid URL"); + if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + } std::ofstream ofs(savePath, std::ios::binary); - if (!ofs) throw std::runtime_error("Could not open file for writing"); + if (!ofs) { + HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Could not open file for writing"; PushResult(std::move(res)); return; + } httplib::Headers finalH; bool hasUA = false; From 4e4cad815d35e99c4f59aad2ad4a5c129d21807a Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Tue, 28 Apr 2026 23:00:21 +0200 Subject: [PATCH 08/11] Enhance WebSocket and HTTP request limits with dynamic configuration and improved cancellation handling --- src/HttpAsync.cpp | 95 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 83 insertions(+), 12 deletions(-) diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 6ff57018..cb6796bf 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -52,9 +52,13 @@ static std::map g_StateRequestCount; static std::mutex g_LimitMutex; static std::vector> g_WebSockets; static std::mutex g_WsMutex; +static std::map g_StateWsCount; -const int MAX_REQUESTS_PER_STATE = 20; -const int THREAD_POOL_SIZE = 8; +static int g_ActualPoolSize = 16; +static int g_MaxRequestsPerPlugin = 8; +static int g_MaxWsPerPlugin = 4; +static int g_MaxWsGlobal = 32; +static int g_CurrentWsGlobal = 0; static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; @@ -62,7 +66,18 @@ static sol::table CreateHandle(lua_State* L, std::shared_ptr inf sol::state_view lua(L); sol::table handle = lua.create_table(); if (info) { - handle["Cancel"] = [info]() { info->abandoned.store(true); }; + handle["Cancel"] = [info, L]() { + info->abandoned.store(true); + + if (info->callbackRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); + info->callbackRef = LUA_REFNIL; + } + if (info->progressRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); + info->progressRef = LUA_REFNIL; + } + }; handle["IsActive"] = [info]() { return !info->abandoned.load(); }; // This allows the modder to attach a progress listener to the handle @@ -133,8 +148,8 @@ static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int } { std::lock_guard lock(g_LimitMutex); - if (g_StateRequestCount[L] >= MAX_REQUESTS_PER_STATE) { - beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", MAX_REQUESTS_PER_STATE); + if (g_StateRequestCount[L] >= g_MaxRequestsPerPlugin) { + beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", g_MaxRequestsPerPlugin); if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); return nullptr; @@ -604,6 +619,16 @@ void AsyncWebSocket::ProcessEvents() { } void AsyncWebSocket::Abandon() { + if (mAbandoned.exchange(true)) return; + + { + std::lock_guard lock(g_LimitMutex); + g_StateWsCount[L]--; + if (g_StateWsCount[L] <= 0) g_StateWsCount.erase(L); + + g_CurrentWsGlobal--; + } + mAbandoned = true; Close(); @@ -643,19 +668,39 @@ void RegisterBindings(sol::state_view& lua) { "OnError", &AsyncWebSocket::OnError ); - lua["AsyncWebSocket"]["new"] = [](sol::this_state s, std::string url, sol::object headers) { + lua["AsyncWebSocket"]["new"] = [](sol::this_state s, std::string url, sol::object headers) -> sol::object { + lua_State* L = s.lua_state(); + { + std::lock_guard lock(g_LimitMutex); + + if (g_CurrentWsGlobal >= g_MaxWsGlobal) { + beammp_lua_warnf("Server-wide WebSocket limit reached ({}).", g_MaxWsGlobal); + return sol::make_object(s, sol::lua_nil); + } + + if (g_StateWsCount[L] >= g_MaxWsPerPlugin) { + beammp_lua_warnf("Plugin reached WebSocket limit ({}).", g_MaxWsPerPlugin); + return sol::make_object(s, sol::lua_nil); + } + + g_StateWsCount[L]++; + g_CurrentWsGlobal++; + } + sol::table h; if (headers.is()) { h = headers.as(); } - auto ws = std::make_shared(url, h, s.lua_state()); + auto ws = std::make_shared(url, h, L); - std::lock_guard lock(g_WsMutex); - g_WebSockets.push_back(ws); + { + std::lock_guard lock(g_WsMutex); + g_WebSockets.push_back(ws); + } - return ws; + return sol::make_object(s, ws); }; } @@ -686,7 +731,21 @@ void Update(sol::state_view& lua) { if (g_PendingRequests.count(res.requestId)) info = g_PendingRequests[res.requestId]; } - if (!info || info->abandoned.load()) continue; + if (!info || info->abandoned.load()) { + if (info) { + if (info->callbackRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); + info->callbackRef = LUA_REFNIL; + } + if (info->progressRef != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); + info->progressRef = LUA_REFNIL; + } + } + std::lock_guard httpLock(g_Mutex); + g_PendingRequests.erase(res.requestId); + continue; + } if (res.type == HttpResult::Type::PROGRESS) { if (info->progressRef != LUA_REFNIL) { @@ -793,7 +852,19 @@ void CleanupState(lua_State* L) { void Init() { g_ShuttingDown.store(false); - g_ThreadPool = std::make_unique(THREAD_POOL_SIZE); + int cores = static_cast(std::thread::hardware_concurrency()); + if (cores <= 0) cores = 4; + + + g_ActualPoolSize = std::clamp(cores * 4, 16, 128); + g_MaxRequestsPerPlugin = std::max(g_ActualPoolSize / 2, 5); + g_ThreadPool = std::make_unique(g_ActualPoolSize); + + g_MaxWsGlobal = std::max(cores * 8, 32); + g_MaxWsPerPlugin = std::max(g_MaxWsGlobal / 4, 4); + + beammp_infof("AsyncHttp initialized. HTTP Pool: {} ({} per plugin). WS Quota: {} ({} per plugin).", + g_ActualPoolSize, g_MaxRequestsPerPlugin, g_MaxWsGlobal, g_MaxWsPerPlugin); } void Shutdown() { From d230c4f8e41bfe4914cc61cf6509c82d460e389b Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Wed, 29 Apr 2026 14:38:15 +0200 Subject: [PATCH 09/11] Refactor code structure for improved readability and maintainability --- include/HttpAsync.h | 39 ++- src/HttpAsync.cpp | 702 ++++++++++++++++++++++---------------------- 2 files changed, 372 insertions(+), 369 deletions(-) diff --git a/include/HttpAsync.h b/include/HttpAsync.h index a271164f..72932325 100644 --- a/include/HttpAsync.h +++ b/include/HttpAsync.h @@ -20,11 +20,16 @@ #include #include +#include #include #include #include +#include +#include +#include #include +// Forward declaration for WebSocket client to keep header lean namespace httplib { namespace ws { class WebSocketClient; @@ -38,24 +43,28 @@ namespace HttpAsync { uint64_t requestId; - int status; + // Progress data + long long current = 0; + long long total = 0; + + // Response data + int status = 0; std::string body; std::map> headers; - - long long current; - long long total; }; - class AsyncHttpProxy { + class AsyncHttpProxy : public std::enable_shared_from_this { public: AsyncHttpProxy(std::string baseUrl, sol::table defaultHeaders); ~AsyncHttpProxy() = default; + // Configuration void SetConnectTimeout(int seconds); void SetReadTimeout(int seconds); - void VerifySSL(bool verify) { mVerifySSL = verify; } + void VerifySSL(bool verify); void SetDefaultHeaders(sol::table headers); + // HTTP Methods sol::table Get(std::string endpoint, sol::object headers, sol::function cb, sol::object prog); sol::table Post(std::string endpoint, sol::object data, sol::object headers, sol::function cb); sol::table Put(std::string endpoint, sol::object data, sol::object headers, sol::function cb); @@ -63,6 +72,7 @@ namespace HttpAsync { sol::table Delete(std::string endpoint, sol::object headers, sol::function cb); sol::table Head(std::string endpoint, sol::object headers, sol::function cb); + // File Operations sol::table Download(std::string endpoint, std::string savePath, sol::function cb, sol::object prog); sol::table PostFile(std::string endpoint, std::string fieldName, std::string filePath, sol::object headers, sol::function cb); @@ -79,7 +89,7 @@ namespace HttpAsync { enum class WSEventType { OPEN, MESSAGE, CLOSE, ERROR_EVENT }; - struct WSEvent { + struct WSEvent { WSEventType type; std::string payload; int closeCode; @@ -87,6 +97,8 @@ namespace HttpAsync { class AsyncWebSocket : public std::enable_shared_from_this { public: + static sol::object Create(sol::this_state s, std::string url, sol::object headers); + AsyncWebSocket(std::string url, sol::table headers, lua_State* state); ~AsyncWebSocket(); @@ -95,40 +107,45 @@ namespace HttpAsync { void Close(); void VerifySSL(bool verify); + // Lua Callback Registration void OnOpen(sol::object cb); void OnMessage(sol::object cb); void OnClose(sol::object cb); void OnError(sol::object cb); void ProcessEvents(); - lua_State* GetLuaState() const { return L; } void Abandon(); + + [[nodiscard]] lua_State* GetLuaState() const { return L; } private: + void PushEvent(WSEvent ev); + std::string mUrl; lua_State* L; std::map mHeaders; - bool mVerifySSL = true; std::thread mThread; std::atomic mIsRunning{false}; std::atomic mAbandoned{false}; + // Internal httplib pointer and sync httplib::ws::WebSocketClient* mClient = nullptr; std::mutex mClientMutex; + // Event Queue std::queue mEvents; std::mutex mMutex; + // Lua Registry References int mOnOpenRef = LUA_REFNIL; int mOnMessageRef = LUA_REFNIL; int mOnCloseRef = LUA_REFNIL; int mOnErrorRef = LUA_REFNIL; - - void PushEvent(WSEvent ev); }; + // Module Lifecycle void Init(); void Shutdown(); void Update(sol::state_view& lua); diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index cb6796bf..9f35247f 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -25,11 +25,11 @@ #include #include #include -#include #include #include #include #include +#include namespace fs = std::filesystem; @@ -42,73 +42,56 @@ struct PendingRequest { std::atomic abandoned{false}; }; -static std::deque g_Results; -static std::mutex g_Mutex; -static std::atomic g_ShuttingDown{false}; -static std::unique_ptr g_ThreadPool; -static std::atomic g_NextRequestId{1}; -static std::map> g_PendingRequests; -static std::map g_StateRequestCount; -static std::mutex g_LimitMutex; -static std::vector> g_WebSockets; -static std::mutex g_WsMutex; -static std::map g_StateWsCount; - -static int g_ActualPoolSize = 16; -static int g_MaxRequestsPerPlugin = 8; -static int g_MaxWsPerPlugin = 4; -static int g_MaxWsGlobal = 32; -static int g_CurrentWsGlobal = 0; +// --- Centralized Global Context --- +static struct GlobalContext { + std::deque results; + std::mutex resultsMutex; + + std::atomic shuttingDown{false}; + std::unique_ptr threadPool; + std::atomic nextRequestId{1}; + + std::map> pendingRequests; + std::map stateRequestCount; + std::mutex limitMutex; + + std::vector> webSockets; + std::mutex wsMutex; + std::map stateWsCount; + + int actualPoolSize = 16; + int maxRequestsPerPlugin = 8; + int maxWsPerPlugin = 4; + int maxWsGlobal = 32; + int currentWsGlobal = 0; +} ctx; static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; -static sol::table CreateHandle(lua_State* L, std::shared_ptr info) { - sol::state_view lua(L); - sol::table handle = lua.create_table(); - if (info) { - handle["Cancel"] = [info, L]() { - info->abandoned.store(true); +// --- Utilities --- - if (info->callbackRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); - info->callbackRef = LUA_REFNIL; - } - if (info->progressRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); - info->progressRef = LUA_REFNIL; - } - }; - handle["IsActive"] = [info]() { return !info->abandoned.load(); }; - - // This allows the modder to attach a progress listener to the handle - handle["OnProgress"] = [L, info](sol::object func) { - if (func.is()) { - if (info->progressRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); - } - func.push(); - info->progressRef = luaL_ref(L, LUA_REGISTRYINDEX); - } - }; - } else { - handle["Error"] = "Rate limited or Shutdown"; +static void ToLowerInPlace(std::string& s) { + for (char& c : s) { + if (c >= 'A' && c <= 'Z') c += 32; } - return handle; } static std::string ToLower(std::string s) { - for (char &c : s) { - if (c >= 'A' && c <= 'Z') { - c += 32; - } - } + ToLowerInPlace(s); return s; } -static void PushResult(HttpResult res) { - if (g_ShuttingDown.load()) return; - std::lock_guard lock(g_Mutex); - g_Results.push_back(std::move(res)); +static void UnrefCallback(lua_State* L, int& ref) { + if (ref != LUA_REFNIL) { + luaL_unref(L, LUA_REGISTRYINDEX, ref); + ref = LUA_REFNIL; + } +} + +static void ReleasePendingRequest(const std::shared_ptr& info) { + if (!info) return; + UnrefCallback(info->L, info->callbackRef); + UnrefCallback(info->L, info->progressRef); } static int MakeRef(sol::object obj) { @@ -118,20 +101,82 @@ static int MakeRef(sol::object obj) { return luaL_ref(L, LUA_REGISTRYINDEX); } +template +static void InvokeLuaCallback(lua_State* L, int ref, const char* errorContext, Args&&... args) { + if (ref == LUA_REFNIL) return; + lua_rawgeti(L, LUA_REGISTRYINDEX, ref); + sol::protected_function cb = sol::stack::pop(L); + if (cb.valid()) { + auto r = cb(std::forward(args)...); + if (!r.valid()) beammp_lua_errorf("%s: %s", errorContext, sol::error(r).what()); + } +} + +static void PushResult(HttpResult res) { + if (ctx.shuttingDown.load()) return; + std::lock_guard lock(ctx.resultsMutex); + ctx.results.push_back(std::move(res)); +} + static void ExtractHeaders(const httplib::Headers& source, std::map>& dest) { - for (const auto& [k, v] : source) { - dest[k].push_back(v); + for (const auto& [k, v] : source) dest[k].push_back(v); +} + +static bool ParseUrl(const std::string& url, std::string& base, std::string& path) { + if (url.rfind("http://", 0) != 0 && url.rfind("https://", 0) != 0) return false; + + auto pos = url.find("://"); + if (pos == std::string::npos) return false; + + auto pathPos = url.find_first_of("/?#", pos + 3); + if (pathPos == std::string::npos) { + base = url; + path = "/"; + } else { + base = url.substr(0, pathPos); + path = url.substr(pathPos); } + + if (base.length() <= pos + 3) return false; + + return true; } -static bool SetupClient(const std::string& url, int connectTimeout, int readTimeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { - static const std::regex url_regex(R"(^(https?://[^/]+)(/.*)?$)", std::regex::extended); - std::smatch match; - if (!std::regex_match(url, match, url_regex)) return false; +static bool IsValidWsUrl(const std::string& url) { + return url.rfind("ws://", 0) == 0 || url.rfind("wss://", 0) == 0; +} - outClient = std::make_unique(match[1].str()); - outPath = match[2].length() == 0 ? "/" : match[2].str(); +// --- HTTP Implementation --- +static sol::table CreateHandle(lua_State* L, std::shared_ptr info) { + sol::state_view lua(L); + sol::table handle = lua.create_table(); + + if (info) { + handle["Cancel"] = [info]() { + info->abandoned.store(true); + ReleasePendingRequest(info); + }; + handle["IsActive"] = [info]() { + return !info->abandoned.load(); + }; + handle["OnProgress"] = [info](sol::object func) { + if (func.is()) { + UnrefCallback(info->L, info->progressRef); + info->progressRef = MakeRef(func); + } + }; + } else { + handle["Error"] = "Rate limited or Shutdown"; + } + return handle; +} + +static bool SetupClient(const std::string& url, int connectTimeout, int readTimeout, bool verifySSL, std::unique_ptr& outClient, std::string& outPath) { + std::string base; + if (!ParseUrl(url, base, outPath)) return false; + + outClient = std::make_unique(base); outClient->set_connection_timeout(connectTimeout, 0); outClient->set_read_timeout(readTimeout, 0); outClient->set_write_timeout(readTimeout, 0); @@ -141,61 +186,63 @@ static bool SetupClient(const std::string& url, int connectTimeout, int readTime } static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int progRef, std::function)> task) { - if (g_ShuttingDown.load() || !g_ThreadPool) { - if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); - if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); + if (ctx.shuttingDown.load() || !ctx.threadPool) { + UnrefCallback(L, cbRef); + UnrefCallback(L, progRef); return nullptr; } + { - std::lock_guard lock(g_LimitMutex); - if (g_StateRequestCount[L] >= g_MaxRequestsPerPlugin) { - beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", g_MaxRequestsPerPlugin); - if (cbRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, cbRef); - if (progRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, progRef); + std::lock_guard lock(ctx.limitMutex); + if (ctx.stateRequestCount[L] >= ctx.maxRequestsPerPlugin) { + beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", ctx.maxRequestsPerPlugin); + UnrefCallback(L, cbRef); + UnrefCallback(L, progRef); return nullptr; } - g_StateRequestCount[L]++; + ctx.stateRequestCount[L]++; } - uint64_t reqId = g_NextRequestId++; + uint64_t reqId = ctx.nextRequestId++; auto info = std::make_shared(); info->L = L; info->callbackRef = cbRef; info->progressRef = progRef; { - std::lock_guard lock(g_Mutex); - g_PendingRequests[reqId] = info; + std::lock_guard lock(ctx.resultsMutex); + ctx.pendingRequests[reqId] = info; } - g_ThreadPool->enqueue([L, reqId, info, task = std::move(task)]() { + ctx.threadPool->enqueue([reqId, info, task = std::move(task)]() { task(reqId, info); - - std::lock_guard lock(g_LimitMutex); - g_StateRequestCount[L]--; }); return info; } static sol::table Dispatch(std::string method, std::string url, std::map headers, - std::string body, int connectTimeout, int readTimeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { + std::string body, int connectTimeout, int readTimeout, bool verifySSL, lua_State* L, int cbRef, int progRef) { - auto info = EnqueueTask(L, cbRef, progRef, [=, b = std::move(body), hMap = std::move(headers)] - (uint64_t reqId, std::shared_ptr pReq) { + auto info = EnqueueTask(L, cbRef, progRef,[=, b = std::move(body), hMap = std::move(headers)] + (uint64_t reqId, std::shared_ptr pReq) { std::string path; std::unique_ptr cli; + if (!SetupClient(url, connectTimeout, readTimeout, verifySSL, cli, path)) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}; + PushResult(std::move(res)); + return; } httplib::Headers h; bool hasUA = false; std::string cType = "application/json"; - for (auto const&[key, val] : hMap) { - if (ToLower(key) == "user-agent") hasUA = true; - if (ToLower(key) == "content-type") { + for (const auto&[key, val] : hMap) { + std::string kLower = ToLower(key); + if (kLower == "user-agent") hasUA = true; + if (kLower == "content-type") { cType = val; if (method == "POST" || method == "PUT" || method == "PATCH") continue; } @@ -205,20 +252,21 @@ static sol::table Dispatch(std::string method, std::string url, std::mapabandoned.load()) return false; + if (ctx.shuttingDown.load() || pReq->abandoned.load()) return false; + if (pReq->progressRef != LUA_REFNIL) { auto now = std::chrono::steady_clock::now(); if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { - HttpResult res; res.type = HttpResult::Type::PROGRESS; res.requestId = reqId; - res.current = static_cast(len); res.total = static_cast(total); - PushResult(std::move(res)); lastProg = now; + HttpResult res{HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}; + PushResult(std::move(res)); + lastProg = now; } } return true; }; httplib::Result response; - if (method == "POST") response = cli->Post(path.c_str(), h, b, cType.c_str(), prog_func); + if (method == "POST") response = cli->Post(path.c_str(), h, b, cType.c_str(), prog_func); else if (method == "PUT") response = cli->Put(path.c_str(), h, b, cType.c_str(), prog_func); else if (method == "PATCH") response = cli->Patch(path.c_str(), h, b, cType.c_str(), prog_func); else if (method == "DELETE") response = cli->Delete(path.c_str(), h, prog_func); @@ -227,14 +275,12 @@ static sol::table Dispatch(std::string method, std::string url, std::mapabandoned.load()) return; - HttpResult res; - res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "", {}}; if (response) { res.status = response->status; res.body = std::move(response->body); ExtractHeaders(response->headers, res.headers); } else { - res.status = 0; res.body = "Network Error: " + httplib::to_string(response.error()); } PushResult(std::move(res)); @@ -243,23 +289,26 @@ static sol::table Dispatch(std::string method, std::string url, std::map() && pair.second.is()) mDefaultHeaders[pair.first.as()] = pair.second.as(); } } } -void AsyncHttpProxy::SetConnectTimeout(int seconds) { - mConnectTimeoutSeconds = seconds; -} - -void AsyncHttpProxy::SetReadTimeout(int seconds) { - mReadTimeoutSeconds = seconds; -} - std::map AsyncHttpProxy::PrepareHeaders(sol::object overrides) { auto finalHeaders = mDefaultHeaders; if (overrides.is()) { @@ -273,8 +322,11 @@ std::map AsyncHttpProxy::PrepareHeaders(sol::object ov void AsyncHttpProxy::PreparePayload(sol::object data, sol::object overrides, std::string& outBody, std::map& outHeaders) { outHeaders = PrepareHeaders(overrides); + bool hasCT = false; - for (const auto& [k, v] : outHeaders) if (ToLower(k) == "content-type") hasCT = true; + for (const auto& [k, v] : outHeaders) { + if (ToLower(k) == "content-type") hasCT = true; + } if (data.is()) { outBody = LuaAPI::MP::JsonEncode(data.as()); @@ -316,24 +368,25 @@ sol::table AsyncHttpProxy::Head(std::string ep, sol::object h, sol::function cb) sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std::string filePath, sol::object headers, sol::function cb) { auto hMap = PrepareHeaders(headers); - std::string url = mBaseUrl + ep; - int connectTimeout = mConnectTimeoutSeconds; - int readTimeout = mReadTimeoutSeconds; - bool verify = mVerifySSL; - - auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), LUA_REFNIL, + [this, ep, fieldName, filePath, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) mutable { std::string path; std::unique_ptr cli; - if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + + if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { + PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + return; } + if (!fs::exists(filePath)) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "File not found"; PushResult(std::move(res)); return; + PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "File not found", {}}); + return; } auto file_stream = std::make_shared(filePath, std::ios::binary); if (!file_stream || !file_stream->is_open()) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Could not open file"; PushResult(std::move(res)); return; + PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file", {}}); + return; } httplib::UploadFormDataItems regular_items; @@ -362,23 +415,22 @@ sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std:: httplib::Headers finalH; bool hasUA = false; - for (auto const& [key, val] : hMap) { + for (const auto& [key, val] : hMap) { std::string kLower = ToLower(key); if (kLower == "user-agent") hasUA = true; - if (kLower == "content-type") continue; // httplib generates this for us in PostFile + if (kLower == "content-type") continue; finalH.emplace(key, val); } if (!hasUA) finalH.emplace("User-Agent", DEFAULT_USER_AGENT); auto response = cli->Post(path.c_str(), finalH, regular_items, provider_items); - - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; + HttpResult res{HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "", {}}; + if (response) { res.status = response->status; res.body = std::move(response->body); ExtractHeaders(response->headers, res.headers); } else { - res.status = 0; res.body = "Upload Failed: " + httplib::to_string(response.error()); } PushResult(std::move(res)); @@ -388,44 +440,43 @@ sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std:: } sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::function cb, sol::object prog) { - std::string url = mBaseUrl + ep; - int connectTimeout = mConnectTimeoutSeconds; - int readTimeout = mReadTimeoutSeconds; - - bool verify = mVerifySSL; auto hMap = PrepareHeaders(sol::lua_nil); - - auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), [=, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) { + auto info = EnqueueTask(cb.lua_state(), MakeRef(cb), MakeRef(prog), + [this, ep, savePath, hMap = std::move(hMap)](uint64_t reqId, std::shared_ptr pReq) mutable { std::string path; std::unique_ptr cli; - if (!SetupClient(url, connectTimeout, readTimeout, verify, cli, path)) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Invalid URL"; PushResult(std::move(res)); return; + + if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { + PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + return; } std::ofstream ofs(savePath, std::ios::binary); if (!ofs) { - HttpResult res; res.type = HttpResult::Type::COMPLETE; res.requestId = reqId; res.status = 0; res.body = "Could not open file for writing"; PushResult(std::move(res)); return; + PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file for writing", {}}); + return; } httplib::Headers finalH; bool hasUA = false; - for (auto const&[key, val] : hMap) { + for (const auto&[key, val] : hMap) { if (ToLower(key) == "user-agent") hasUA = true; finalH.emplace(key, val); } if (!hasUA) finalH.emplace("User-Agent", DEFAULT_USER_AGENT); - int status_code = 0; std::map> resHeaders; + int status_code = 0; + std::map> resHeaders; auto lastProg = std::chrono::steady_clock::now(); auto res = cli->Get(path.c_str(), finalH, [&](const httplib::Response &r) { status_code = r.status; ExtractHeaders(r.headers, resHeaders); - return !g_ShuttingDown.load() && !pReq->abandoned.load(); + return !ctx.shuttingDown.load() && !pReq->abandoned.load(); }, [&](const char *b, size_t l) { - if (g_ShuttingDown.load() || pReq->abandoned.load()) return false; + if (ctx.shuttingDown.load() || pReq->abandoned.load()) return false; ofs.write(b, static_cast(l)); return true; }, @@ -433,9 +484,8 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f if (pReq->progressRef != LUA_REFNIL && !pReq->abandoned.load()) { auto now = std::chrono::steady_clock::now(); if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { - HttpResult pres; pres.type = HttpResult::Type::PROGRESS; pres.requestId = reqId; - pres.current = static_cast(len); pres.total = static_cast(total); - PushResult(std::move(pres)); lastProg = now; + PushResult({HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}); + lastProg = now; } } return true; @@ -443,26 +493,16 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f ); ofs.close(); - HttpResult fres; fres.type = HttpResult::Type::COMPLETE; fres.requestId = reqId; - fres.status = status_code; + HttpResult fres{HttpResult::Type::COMPLETE, reqId, 0, 0, status_code, "", std::move(resHeaders)}; fres.body = res ? "Success" : "Download Failed: " + httplib::to_string(res.error()); - fres.headers = resHeaders; PushResult(std::move(fres)); }); return CreateHandle(cb.lua_state(), info); } -void AsyncHttpProxy::SetDefaultHeaders(sol::table headers) { - mDefaultHeaders.clear(); - if (headers != sol::lua_nil && headers.valid()) { - for (auto const& pair : headers) { - if (pair.first.is() && pair.second.is()) { - mDefaultHeaders[pair.first.as()] = pair.second.as(); - } - } - } -} + +// --- AsyncWebSocket --- AsyncWebSocket::AsyncWebSocket(std::string url, sol::table headers, lua_State* state) : mUrl(std::move(url)), L(state) { @@ -477,15 +517,40 @@ AsyncWebSocket::AsyncWebSocket(std::string url, sol::table headers, lua_State* s AsyncWebSocket::~AsyncWebSocket() { Abandon(); - if (mThread.joinable()) { - mThread.join(); - } + if (mThread.joinable()) mThread.join(); } -void AsyncWebSocket::VerifySSL(bool verify) { - mVerifySSL = verify; +sol::object AsyncWebSocket::Create(sol::this_state s, std::string url, sol::object headers) { + lua_State* L = s.lua_state(); + + { + std::lock_guard lock(ctx.limitMutex); + if (ctx.currentWsGlobal >= ctx.maxWsGlobal || ctx.stateWsCount[L] >= ctx.maxWsPerPlugin) { + beammp_lua_warnf("WebSocket limit reached (Global: {}/{}, Plugin: {}/{}).", + ctx.currentWsGlobal, ctx.maxWsGlobal, ctx.stateWsCount[L], ctx.maxWsPerPlugin); + return sol::make_object(s, sol::lua_nil); + } + ctx.stateWsCount[L]++; + ctx.currentWsGlobal++; + } + + if (!IsValidWsUrl(url)) { + beammp_lua_warnf("Invalid WebSocket URL: {}. Use 'ws://' or 'wss://'.", url); + std::lock_guard lock(ctx.limitMutex); + ctx.stateWsCount[L]--; + ctx.currentWsGlobal--; + return sol::make_object(s, sol::lua_nil); + } + + auto ws = std::make_shared(url, headers.is() ? headers.as() : sol::table(), L); + + std::lock_guard lock(ctx.wsMutex); + ctx.webSockets.push_back(ws); + return sol::make_object(s, ws); } +void AsyncWebSocket::VerifySSL(bool verify) { mVerifySSL = verify; } + void AsyncWebSocket::Connect() { if (mIsRunning.exchange(true)) return; @@ -497,11 +562,9 @@ void AsyncWebSocket::Connect() { if (ToLower(k) == "user-agent") hasUA = true; h.emplace(k, v); } - if (!hasUA) h.emplace("User-Agent", DEFAULT_USER_AGENT); httplib::ws::WebSocketClient client(mUrl, h); - client.enable_server_certificate_verification(mVerifySSL); { @@ -522,10 +585,7 @@ void AsyncWebSocket::Connect() { std::string msg; while (mIsRunning && !mAbandoned) { - auto res = client.read(msg); - if (res == httplib::ws::ReadResult::Fail) { - break; - } + if (client.read(msg) == httplib::ws::ReadResult::Fail) break; PushEvent({WSEventType::MESSAGE, msg, 0}); msg.clear(); } @@ -540,17 +600,13 @@ void AsyncWebSocket::Connect() { void AsyncWebSocket::Send(const std::string& data) { std::lock_guard lock(mClientMutex); - if (mClient && mClient->is_open()) { - mClient->send(data); - } + if (mClient && mClient->is_open()) mClient->send(data); } void AsyncWebSocket::Close() { mIsRunning = false; std::lock_guard lock(mClientMutex); - if (mClient && mClient->is_open()) { - mClient->close(); - } + if (mClient && mClient->is_open()) mClient->close(); } void AsyncWebSocket::PushEvent(WSEvent ev) { @@ -559,22 +615,10 @@ void AsyncWebSocket::PushEvent(WSEvent ev) { mEvents.push(std::move(ev)); } -void AsyncWebSocket::OnOpen(sol::object cb) { - if (mOnOpenRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnOpenRef); - mOnOpenRef = MakeRef(cb); -} -void AsyncWebSocket::OnMessage(sol::object cb) { - if (mOnMessageRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnMessageRef); - mOnMessageRef = MakeRef(cb); -} -void AsyncWebSocket::OnClose(sol::object cb) { - if (mOnCloseRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnCloseRef); - mOnCloseRef = MakeRef(cb); -} -void AsyncWebSocket::OnError(sol::object cb) { - if (mOnErrorRef != LUA_REFNIL) luaL_unref(L, LUA_REGISTRYINDEX, mOnErrorRef); - mOnErrorRef = MakeRef(cb); -} +void AsyncWebSocket::OnOpen(sol::object cb) { UnrefCallback(L, mOnOpenRef); mOnOpenRef = MakeRef(cb); } +void AsyncWebSocket::OnMessage(sol::object cb) { UnrefCallback(L, mOnMessageRef); mOnMessageRef = MakeRef(cb); } +void AsyncWebSocket::OnClose(sol::object cb) { UnrefCallback(L, mOnCloseRef); mOnCloseRef = MakeRef(cb); } +void AsyncWebSocket::OnError(sol::object cb) { UnrefCallback(L, mOnErrorRef); mOnErrorRef = MakeRef(cb); } void AsyncWebSocket::ProcessEvents() { if (mAbandoned) return; @@ -585,32 +629,26 @@ void AsyncWebSocket::ProcessEvents() { std::swap(events, mEvents); } - while (!events.empty()) { + while (!events.empty() && !mAbandoned) { auto ev = events.front(); events.pop(); - if (mAbandoned) break; - try { - if (ev.type == WSEventType::OPEN && mOnOpenRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, mOnOpenRef); - sol::protected_function cb = sol::stack::pop(L); - if (cb.valid()) { auto r = cb(); if (!r.valid()) beammp_lua_errorf("WS OnOpen Error: {}", sol::error(r).what()); } - } - else if (ev.type == WSEventType::MESSAGE && mOnMessageRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, mOnMessageRef); - sol::protected_function cb = sol::stack::pop(L); - if (cb.valid()) { auto r = cb(ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnMessage Error: {}", sol::error(r).what()); } - } - else if (ev.type == WSEventType::CLOSE && mOnCloseRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, mOnCloseRef); - sol::protected_function cb = sol::stack::pop(L); - if (cb.valid()) { auto r = cb(ev.closeCode, ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnClose Error: {}", sol::error(r).what()); } - } - else if (ev.type == WSEventType::ERROR_EVENT && mOnErrorRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, mOnErrorRef); - sol::protected_function cb = sol::stack::pop(L); - if (cb.valid()) { auto r = cb(ev.payload); if (!r.valid()) beammp_lua_errorf("WS OnError Error: {}", sol::error(r).what()); } + switch (ev.type) { + case WSEventType::OPEN: + InvokeLuaCallback(L, mOnOpenRef, "WS OnOpen Error"); + break; + case WSEventType::MESSAGE: + InvokeLuaCallback(L, mOnMessageRef, "WS OnMessage Error", ev.payload); + break; + case WSEventType::CLOSE: + InvokeLuaCallback(L, mOnCloseRef, "WS OnClose Error", ev.closeCode, ev.payload); + break; + case WSEventType::ERROR_EVENT: + InvokeLuaCallback(L, mOnErrorRef, "WS OnError Error", ev.payload); + break; + default: + break; } } catch (const std::exception& e) { beammp_lua_errorf("WebSocket Exception: {}", e.what()); @@ -622,22 +660,22 @@ void AsyncWebSocket::Abandon() { if (mAbandoned.exchange(true)) return; { - std::lock_guard lock(g_LimitMutex); - g_StateWsCount[L]--; - if (g_StateWsCount[L] <= 0) g_StateWsCount.erase(L); - - g_CurrentWsGlobal--; + std::lock_guard lock(ctx.limitMutex); + ctx.stateWsCount[L]--; + if (ctx.stateWsCount[L] <= 0) ctx.stateWsCount.erase(L); + ctx.currentWsGlobal--; } - mAbandoned = true; Close(); - - if (mOnOpenRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnOpenRef); mOnOpenRef = LUA_REFNIL; } - if (mOnMessageRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnMessageRef); mOnMessageRef = LUA_REFNIL; } - if (mOnCloseRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnCloseRef); mOnCloseRef = LUA_REFNIL; } - if (mOnErrorRef != LUA_REFNIL) { luaL_unref(L, LUA_REGISTRYINDEX, mOnErrorRef); mOnErrorRef = LUA_REFNIL; } + + UnrefCallback(L, mOnOpenRef); + UnrefCallback(L, mOnMessageRef); + UnrefCallback(L, mOnCloseRef); + UnrefCallback(L, mOnErrorRef); } +// --- Lifecycle & Lua Bindings --- + void RegisterBindings(sol::state_view& lua) { lua.new_usertype("AsyncHttp", sol::no_constructor, "SetConnectTimeout", &AsyncHttpProxy::SetConnectTimeout, @@ -668,142 +706,88 @@ void RegisterBindings(sol::state_view& lua) { "OnError", &AsyncWebSocket::OnError ); - lua["AsyncWebSocket"]["new"] = [](sol::this_state s, std::string url, sol::object headers) -> sol::object { - lua_State* L = s.lua_state(); - { - std::lock_guard lock(g_LimitMutex); - - if (g_CurrentWsGlobal >= g_MaxWsGlobal) { - beammp_lua_warnf("Server-wide WebSocket limit reached ({}).", g_MaxWsGlobal); - return sol::make_object(s, sol::lua_nil); - } - - if (g_StateWsCount[L] >= g_MaxWsPerPlugin) { - beammp_lua_warnf("Plugin reached WebSocket limit ({}).", g_MaxWsPerPlugin); - return sol::make_object(s, sol::lua_nil); - } - - g_StateWsCount[L]++; - g_CurrentWsGlobal++; - } - - sol::table h; - - if (headers.is()) { - h = headers.as(); - } - - auto ws = std::make_shared(url, h, L); - - { - std::lock_guard lock(g_WsMutex); - g_WebSockets.push_back(ws); - } - - return sol::make_object(s, ws); - }; + lua["AsyncWebSocket"]["new"] = &AsyncWebSocket::Create; } void Update(sol::state_view& lua) { lua_State* L = lua.lua_state(); std::deque toProcess; + // 1. Gather relevant results safely { - std::lock_guard httpLock(g_Mutex); - auto it = g_Results.begin(); - while (it != g_Results.end()) { - auto reqIt = g_PendingRequests.find(it->requestId); - if (reqIt == g_PendingRequests.end()) { - it = g_Results.erase(it); + std::lock_guard lock(ctx.resultsMutex); + auto it = ctx.results.begin(); + while (it != ctx.results.end()) { + auto reqIt = ctx.pendingRequests.find(it->requestId); + if (reqIt == ctx.pendingRequests.end()) { + it = ctx.results.erase(it); } else if (reqIt->second->L == L) { toProcess.push_back(std::move(*it)); - it = g_Results.erase(it); + it = ctx.results.erase(it); } else { ++it; } } } + // 2. Dispatch Lua callbacks for (const auto& res : toProcess) { std::shared_ptr info; { - std::lock_guard httpLock(g_Mutex); - if (g_PendingRequests.count(res.requestId)) info = g_PendingRequests[res.requestId]; + std::lock_guard lock(ctx.resultsMutex); + auto it = ctx.pendingRequests.find(res.requestId); + if (it != ctx.pendingRequests.end()) info = it->second; } if (!info || info->abandoned.load()) { if (info) { - if (info->callbackRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); - info->callbackRef = LUA_REFNIL; - } - if (info->progressRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); - info->progressRef = LUA_REFNIL; - } + ReleasePendingRequest(info); + + std::lock_guard lock(ctx.limitMutex); + ctx.stateRequestCount[L]--; } - std::lock_guard httpLock(g_Mutex); - g_PendingRequests.erase(res.requestId); + std::lock_guard lock(ctx.resultsMutex); + ctx.pendingRequests.erase(res.requestId); continue; } if (res.type == HttpResult::Type::PROGRESS) { - if (info->progressRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, info->progressRef); - sol::protected_function prog = sol::stack::pop(L); - if (prog.valid()) { - auto r = prog(res.current, res.total); - if (!r.valid()) beammp_lua_errorf("AsyncHttp Progress Error: {}", sol::error(r).what()); - } - } + InvokeLuaCallback(L, info->progressRef, "AsyncHttp Progress Error", res.current, res.total); } else { if (info->callbackRef != LUA_REFNIL) { - lua_rawgeti(L, LUA_REGISTRYINDEX, info->callbackRef); - sol::protected_function cb = sol::stack::pop(L); - if (cb.valid()) { - sol::table luaHeaders = lua.create_table(); - for (auto const& [name, values] : res.headers) { - if (values.empty()) continue; - - std::string key = ToLower(name); - - if (values.size() > 1 || key == "set-cookie") { - luaHeaders[key] = sol::as_table(values); - } else { - luaHeaders[key] = values[0]; - } - } - auto r = cb(res.status, res.body, luaHeaders); - if (!r.valid()) beammp_lua_errorf("AsyncHttp Callback Error: {}", sol::error(r).what()); + sol::table luaHeaders = lua.create_table(); + for (auto const& [name, values] : res.headers) { + if (values.empty()) continue; + std::string key = ToLower(name); + if (values.size() > 1 || key == "set-cookie") luaHeaders[key] = sol::as_table(values); + else luaHeaders[key] = values[0]; } + InvokeLuaCallback(L, info->callbackRef, "AsyncHttp Callback Error", res.status, res.body, luaHeaders); } - if (info->callbackRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->callbackRef); - info->callbackRef = LUA_REFNIL; - } - if (info->progressRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, info->progressRef); - info->progressRef = LUA_REFNIL; + ReleasePendingRequest(info); + + { + std::lock_guard lock(ctx.limitMutex); + ctx.stateRequestCount[L]--; } - std::lock_guard httpLock(g_Mutex); - g_PendingRequests.erase(res.requestId); + std::lock_guard lock(ctx.resultsMutex); + ctx.pendingRequests.erase(res.requestId); } } + // 3. Update WebSockets std::vector> websocketsToUpdate; { - std::lock_guard wsLock(g_WsMutex); - auto wsIt = g_WebSockets.begin(); - while (wsIt != g_WebSockets.end()) { + std::lock_guard wsLock(ctx.wsMutex); + auto wsIt = ctx.webSockets.begin(); + while (wsIt != ctx.webSockets.end()) { if (auto ws = wsIt->lock()) { - if (ws->GetLuaState() == L) { - websocketsToUpdate.push_back(ws); - } + if (ws->GetLuaState() == L) websocketsToUpdate.push_back(ws); ++wsIt; } else { - wsIt = g_WebSockets.erase(wsIt); + wsIt = ctx.webSockets.erase(wsIt); } } } @@ -814,74 +798,76 @@ void Update(sol::state_view& lua) { } void CleanupState(lua_State* L) { - std::lock_guard httpLock(g_Mutex); - for (auto it = g_PendingRequests.begin(); it != g_PendingRequests.end(); ) { - if (it->second->L == L) { - it->second->abandoned.store(true); - - if (it->second->callbackRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, it->second->callbackRef); - it->second->callbackRef = LUA_REFNIL; - } - if (it->second->progressRef != LUA_REFNIL) { - luaL_unref(L, LUA_REGISTRYINDEX, it->second->progressRef); - it->second->progressRef = LUA_REFNIL; + // Purge pending HTTP requests + { + std::lock_guard lock(ctx.resultsMutex); + for (auto it = ctx.pendingRequests.begin(); it != ctx.pendingRequests.end(); ) { + if (it->second->L == L) { + it->second->abandoned.store(true); + ReleasePendingRequest(it->second); + it = ctx.pendingRequests.erase(it); + } else { + ++it; } - - it = g_PendingRequests.erase(it); - } else { - ++it; } } - std::lock_guard wsLock(g_WsMutex); - auto wsIt = g_WebSockets.begin(); - while (wsIt != g_WebSockets.end()) { - if (auto ws = wsIt->lock()) { - if (ws->GetLuaState() == L) { - ws->Abandon(); - wsIt = g_WebSockets.erase(wsIt); + // Purge attached WebSockets + { + std::lock_guard wsLock(ctx.wsMutex); + auto wsIt = ctx.webSockets.begin(); + while (wsIt != ctx.webSockets.end()) { + if (auto ws = wsIt->lock()) { + if (ws->GetLuaState() == L) { + ws->Abandon(); + wsIt = ctx.webSockets.erase(wsIt); + } else { + ++wsIt; + } } else { - ++wsIt; + wsIt = ctx.webSockets.erase(wsIt); } - } else { - wsIt = g_WebSockets.erase(wsIt); } } } void Init() { - g_ShuttingDown.store(false); + ctx.shuttingDown.store(false); int cores = static_cast(std::thread::hardware_concurrency()); if (cores <= 0) cores = 4; + ctx.actualPoolSize = std::clamp(cores * 4, 16, 128); + ctx.maxRequestsPerPlugin = std::max(ctx.actualPoolSize / 2, 5); + ctx.threadPool = std::make_unique(ctx.actualPoolSize); - g_ActualPoolSize = std::clamp(cores * 4, 16, 128); - g_MaxRequestsPerPlugin = std::max(g_ActualPoolSize / 2, 5); - g_ThreadPool = std::make_unique(g_ActualPoolSize); - - g_MaxWsGlobal = std::max(cores * 8, 32); - g_MaxWsPerPlugin = std::max(g_MaxWsGlobal / 4, 4); + ctx.maxWsGlobal = std::max(cores * 8, 32); + ctx.maxWsPerPlugin = std::max(ctx.maxWsGlobal / 4, 4); beammp_infof("AsyncHttp initialized. HTTP Pool: {} ({} per plugin). WS Quota: {} ({} per plugin).", - g_ActualPoolSize, g_MaxRequestsPerPlugin, g_MaxWsGlobal, g_MaxWsPerPlugin); + ctx.actualPoolSize, ctx.maxRequestsPerPlugin, ctx.maxWsGlobal, ctx.maxWsPerPlugin); } void Shutdown() { - g_ShuttingDown.store(true); - if (g_ThreadPool) g_ThreadPool->shutdown(); - g_ThreadPool.reset(); - std::lock_guard httpLock(g_Mutex); - g_PendingRequests.clear(); - g_Results.clear(); - - std::lock_guard wsLock(g_WsMutex); - for (auto& weak_ws : g_WebSockets) { - if (auto ws = weak_ws.lock()) { - ws->Abandon(); + ctx.shuttingDown.store(true); + + if (ctx.threadPool) { + ctx.threadPool->shutdown(); + ctx.threadPool.reset(); + } + + { + std::lock_guard lock(ctx.resultsMutex); + ctx.pendingRequests.clear(); + ctx.results.clear(); + } + + { + std::lock_guard wsLock(ctx.wsMutex); + for (auto& weak_ws : ctx.webSockets) { + if (auto ws = weak_ws.lock()) ws->Abandon(); } + ctx.webSockets.clear(); } - g_WebSockets.clear(); } } // namespace HttpAsync \ No newline at end of file From 0d717fde7eecd3afc9303ff91eb7c506c2691d72 Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Wed, 29 Apr 2026 15:45:01 +0200 Subject: [PATCH 10/11] Refactor HttpAsync to improve plugin context management and thread safety Co-authored-by: Copilot --- src/HttpAsync.cpp | 145 ++++++++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 63 deletions(-) diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index 9f35247f..fbb27852 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -42,23 +42,31 @@ struct PendingRequest { std::atomic abandoned{false}; }; -// --- Centralized Global Context --- -static struct GlobalContext { +struct PluginContext { std::deque results; std::mutex resultsMutex; + std::vector> webSockets; + std::mutex wsMutex; +}; + +// --- Centralized Global Context --- +static struct GlobalContext { std::atomic shuttingDown{false}; std::unique_ptr threadPool; std::atomic nextRequestId{1}; std::map> pendingRequests; + std::mutex pendingRequestsMutex; + std::map stateRequestCount; std::mutex limitMutex; - std::vector> webSockets; - std::mutex wsMutex; std::map stateWsCount; + std::map> pluginContexts; + std::mutex pluginContextsMutex; + int actualPoolSize = 16; int maxRequestsPerPlugin = 8; int maxWsPerPlugin = 4; @@ -70,6 +78,18 @@ static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; // --- Utilities --- +static std::shared_ptr GetPluginContext(lua_State* L) { + std::lock_guard lock(ctx.pluginContextsMutex); + auto it = ctx.pluginContexts.find(L); + if (it == ctx.pluginContexts.end()) { + auto newCtx = std::make_shared(); + ctx.pluginContexts[L] = newCtx; + return newCtx; + } + return it->second; +} + + static void ToLowerInPlace(std::string& s) { for (char& c : s) { if (c >= 'A' && c <= 'Z') c += 32; @@ -112,10 +132,12 @@ static void InvokeLuaCallback(lua_State* L, int ref, const char* errorContext, A } } -static void PushResult(HttpResult res) { +static void PushResult(lua_State* L, HttpResult res) { if (ctx.shuttingDown.load()) return; - std::lock_guard lock(ctx.resultsMutex); - ctx.results.push_back(std::move(res)); + auto pCtx = GetPluginContext(L); + + std::lock_guard lock(pCtx->resultsMutex); + pCtx->results.push_back(std::move(res)); } static void ExtractHeaders(const httplib::Headers& source, std::map>& dest) { @@ -210,7 +232,7 @@ static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int info->progressRef = progRef; { - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); ctx.pendingRequests[reqId] = info; } @@ -231,7 +253,7 @@ static sol::table Dispatch(std::string method, std::string url, std::mapL, std::move(res)); return; } @@ -258,7 +280,7 @@ static sol::table Dispatch(std::string method, std::string url, std::map(now - lastProg).count() > 100 || len == total) { HttpResult res{HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}; - PushResult(std::move(res)); + PushResult(pReq->L, std::move(res)); lastProg = now; } } @@ -283,7 +305,7 @@ static sol::table Dispatch(std::string method, std::string url, std::mapL, std::move(res)); }); return CreateHandle(L, info); @@ -374,18 +396,18 @@ sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std:: std::unique_ptr cli; if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { - PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); return; } if (!fs::exists(filePath)) { - PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "File not found", {}}); + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "File not found", {}}); return; } auto file_stream = std::make_shared(filePath, std::ios::binary); if (!file_stream || !file_stream->is_open()) { - PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file", {}}); + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file", {}}); return; } @@ -433,7 +455,7 @@ sol::table AsyncHttpProxy::PostFile(std::string ep, std::string fieldName, std:: } else { res.body = "Upload Failed: " + httplib::to_string(response.error()); } - PushResult(std::move(res)); + PushResult(pReq->L, std::move(res)); }); return CreateHandle(cb.lua_state(), info); @@ -447,13 +469,13 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f std::unique_ptr cli; if (!SetupClient(mBaseUrl + ep, mConnectTimeoutSeconds, mReadTimeoutSeconds, mVerifySSL, cli, path)) { - PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Invalid URL", {}}); return; } std::ofstream ofs(savePath, std::ios::binary); if (!ofs) { - PushResult({HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file for writing", {}}); + PushResult(pReq->L, {HttpResult::Type::COMPLETE, reqId, 0, 0, 0, "Could not open file for writing", {}}); return; } @@ -484,7 +506,7 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f if (pReq->progressRef != LUA_REFNIL && !pReq->abandoned.load()) { auto now = std::chrono::steady_clock::now(); if (std::chrono::duration_cast(now - lastProg).count() > 100 || len == total) { - PushResult({HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}); + PushResult(pReq->L, {HttpResult::Type::PROGRESS, reqId, static_cast(len), static_cast(total), 0, "", {}}); lastProg = now; } } @@ -495,7 +517,7 @@ sol::table AsyncHttpProxy::Download(std::string ep, std::string savePath, sol::f HttpResult fres{HttpResult::Type::COMPLETE, reqId, 0, 0, status_code, "", std::move(resHeaders)}; fres.body = res ? "Success" : "Download Failed: " + httplib::to_string(res.error()); - PushResult(std::move(fres)); + PushResult(pReq->L, std::move(fres)); }); return CreateHandle(cb.lua_state(), info); @@ -544,8 +566,9 @@ sol::object AsyncWebSocket::Create(sol::this_state s, std::string url, sol::obje auto ws = std::make_shared(url, headers.is() ? headers.as() : sol::table(), L); - std::lock_guard lock(ctx.wsMutex); - ctx.webSockets.push_back(ws); + auto pCtx = GetPluginContext(L); + std::lock_guard lock(pCtx->wsMutex); + pCtx->webSockets.push_back(ws); return sol::make_object(s, ws); } @@ -711,30 +734,21 @@ void RegisterBindings(sol::state_view& lua) { void Update(sol::state_view& lua) { lua_State* L = lua.lua_state(); + auto pCtx = GetPluginContext(L); + std::deque toProcess; // 1. Gather relevant results safely { - std::lock_guard lock(ctx.resultsMutex); - auto it = ctx.results.begin(); - while (it != ctx.results.end()) { - auto reqIt = ctx.pendingRequests.find(it->requestId); - if (reqIt == ctx.pendingRequests.end()) { - it = ctx.results.erase(it); - } else if (reqIt->second->L == L) { - toProcess.push_back(std::move(*it)); - it = ctx.results.erase(it); - } else { - ++it; - } - } + std::lock_guard lock(pCtx->resultsMutex); + toProcess.swap(pCtx->results); } // 2. Dispatch Lua callbacks for (const auto& res : toProcess) { std::shared_ptr info; { - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); auto it = ctx.pendingRequests.find(res.requestId); if (it != ctx.pendingRequests.end()) info = it->second; } @@ -746,7 +760,7 @@ void Update(sol::state_view& lua) { std::lock_guard lock(ctx.limitMutex); ctx.stateRequestCount[L]--; } - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); ctx.pendingRequests.erase(res.requestId); continue; } @@ -772,22 +786,22 @@ void Update(sol::state_view& lua) { ctx.stateRequestCount[L]--; } - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); ctx.pendingRequests.erase(res.requestId); } } - // 3. Update WebSockets + // 3. Update WebSockets cleanly without cross-plugin locking std::vector> websocketsToUpdate; { - std::lock_guard wsLock(ctx.wsMutex); - auto wsIt = ctx.webSockets.begin(); - while (wsIt != ctx.webSockets.end()) { + std::lock_guard wsLock(pCtx->wsMutex); + auto wsIt = pCtx->webSockets.begin(); + while (wsIt != pCtx->webSockets.end()) { if (auto ws = wsIt->lock()) { - if (ws->GetLuaState() == L) websocketsToUpdate.push_back(ws); + websocketsToUpdate.push_back(ws); ++wsIt; } else { - wsIt = ctx.webSockets.erase(wsIt); + wsIt = pCtx->webSockets.erase(wsIt); } } } @@ -798,9 +812,11 @@ void Update(sol::state_view& lua) { } void CleanupState(lua_State* L) { + auto pCtx = GetPluginContext(L); + // Purge pending HTTP requests { - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); for (auto it = ctx.pendingRequests.begin(); it != ctx.pendingRequests.end(); ) { if (it->second->L == L) { it->second->abandoned.store(true); @@ -811,24 +827,24 @@ void CleanupState(lua_State* L) { } } } + + // Clear out un-polled results + { + std::lock_guard lock(pCtx->resultsMutex); + pCtx->results.clear(); + } // Purge attached WebSockets { - std::lock_guard wsLock(ctx.wsMutex); - auto wsIt = ctx.webSockets.begin(); - while (wsIt != ctx.webSockets.end()) { - if (auto ws = wsIt->lock()) { - if (ws->GetLuaState() == L) { - ws->Abandon(); - wsIt = ctx.webSockets.erase(wsIt); - } else { - ++wsIt; - } - } else { - wsIt = ctx.webSockets.erase(wsIt); - } + std::lock_guard wsLock(pCtx->wsMutex); + for (auto& weak_ws : pCtx->webSockets) { + if (auto ws = weak_ws.lock()) ws->Abandon(); } + pCtx->webSockets.clear(); } + + std::lock_guard lock(ctx.pluginContextsMutex); + ctx.pluginContexts.erase(L); } void Init() { @@ -856,17 +872,20 @@ void Shutdown() { } { - std::lock_guard lock(ctx.resultsMutex); + std::lock_guard lock(ctx.pendingRequestsMutex); ctx.pendingRequests.clear(); - ctx.results.clear(); } { - std::lock_guard wsLock(ctx.wsMutex); - for (auto& weak_ws : ctx.webSockets) { - if (auto ws = weak_ws.lock()) ws->Abandon(); + std::lock_guard lock(ctx.pluginContextsMutex); + for (auto& [L, pCtx] : ctx.pluginContexts) { + std::lock_guard wsLock(pCtx->wsMutex); + for (auto& weak_ws : pCtx->webSockets) { + if (auto ws = weak_ws.lock()) ws->Abandon(); + } + pCtx->webSockets.clear(); } - ctx.webSockets.clear(); + ctx.pluginContexts.clear(); } } From e8cf098e91e5612877890ff7abe7d8d2fe65498d Mon Sep 17 00:00:00 2001 From: cocorico8 Date: Sat, 2 May 2026 20:44:47 +0200 Subject: [PATCH 11/11] Enhance PluginContext with lastLimitWarning and improve request limit warnings for HTTP and WebSocket connections. Also fixes a race condition in the websockets part --- src/HttpAsync.cpp | 98 ++++++++++++++++++++++++++++++++-------------- src/TLuaEngine.cpp | 5 --- 2 files changed, 69 insertions(+), 34 deletions(-) diff --git a/src/HttpAsync.cpp b/src/HttpAsync.cpp index fbb27852..2f0471b9 100644 --- a/src/HttpAsync.cpp +++ b/src/HttpAsync.cpp @@ -48,6 +48,8 @@ struct PluginContext { std::vector> webSockets; std::mutex wsMutex; + + std::chrono::steady_clock::time_point lastLimitWarning{}; }; // --- Centralized Global Context --- @@ -78,6 +80,15 @@ static const char* DEFAULT_USER_AGENT = "BeamMP-Server/1.0"; // --- Utilities --- +static bool ShouldWarn(std::shared_ptr pCtx) { + auto now = std::chrono::steady_clock::now(); + if (std::chrono::duration_cast(now - pCtx->lastLimitWarning).count() >= 10) { + pCtx->lastLimitWarning = now; + return true; + } + return false; +} + static std::shared_ptr GetPluginContext(lua_State* L) { std::lock_guard lock(ctx.pluginContextsMutex); auto it = ctx.pluginContexts.find(L); @@ -217,7 +228,10 @@ static std::shared_ptr EnqueueTask(lua_State* L, int cbRef, int { std::lock_guard lock(ctx.limitMutex); if (ctx.stateRequestCount[L] >= ctx.maxRequestsPerPlugin) { - beammp_lua_warnf("Plugin reached HTTP request limit ({}). Request rejected.", ctx.maxRequestsPerPlugin); + auto pCtx = GetPluginContext(L); + if (ShouldWarn(pCtx)) { + beammp_lua_warnf("Plugin reached HTTP request limit ({}). Further requests silenced for 10s.", ctx.maxRequestsPerPlugin); + } UnrefCallback(L, cbRef); UnrefCallback(L, progRef); return nullptr; @@ -539,7 +553,7 @@ AsyncWebSocket::AsyncWebSocket(std::string url, sol::table headers, lua_State* s AsyncWebSocket::~AsyncWebSocket() { Abandon(); - if (mThread.joinable()) mThread.join(); + if (mThread.joinable()) mThread.detach(); } sol::object AsyncWebSocket::Create(sol::this_state s, std::string url, sol::object headers) { @@ -548,8 +562,11 @@ sol::object AsyncWebSocket::Create(sol::this_state s, std::string url, sol::obje { std::lock_guard lock(ctx.limitMutex); if (ctx.currentWsGlobal >= ctx.maxWsGlobal || ctx.stateWsCount[L] >= ctx.maxWsPerPlugin) { - beammp_lua_warnf("WebSocket limit reached (Global: {}/{}, Plugin: {}/{}).", - ctx.currentWsGlobal, ctx.maxWsGlobal, ctx.stateWsCount[L], ctx.maxWsPerPlugin); + auto pCtx = GetPluginContext(L); + if (ShouldWarn(pCtx)) { + beammp_lua_warnf("WebSocket limit reached (Global: {}/{}, Plugin: {}/{}). Silencing for 10s.", + ctx.currentWsGlobal, ctx.maxWsGlobal, ctx.stateWsCount[L], ctx.maxWsPerPlugin); + } return sol::make_object(s, sol::lua_nil); } ctx.stateWsCount[L]++; @@ -577,47 +594,70 @@ void AsyncWebSocket::VerifySSL(bool verify) { mVerifySSL = verify; } void AsyncWebSocket::Connect() { if (mIsRunning.exchange(true)) return; - mThread = std::thread([this]() { + std::weak_ptr weakSelf = shared_from_this(); + std::string url = mUrl; + std::map headers = mHeaders; + bool verifySSL = mVerifySSL; + + mThread = std::thread([weakSelf, url = std::move(url), headers = std::move(headers), verifySSL]() { httplib::Headers h; bool hasUA = false; - for (const auto& [k, v] : mHeaders) { + for (const auto&[k, v] : headers) { if (ToLower(k) == "user-agent") hasUA = true; h.emplace(k, v); } if (!hasUA) h.emplace("User-Agent", DEFAULT_USER_AGENT); - httplib::ws::WebSocketClient client(mUrl, h); - client.enable_server_certificate_verification(mVerifySSL); + httplib::ws::WebSocketClient client(url, h); + client.enable_server_certificate_verification(verifySSL); - { - std::lock_guard lock(mClientMutex); - if (mAbandoned) return; - mClient = &client; - } - if (!client.connect()) { - PushEvent({WSEventType::ERROR_EVENT, "Failed to connect", 0}); - mIsRunning = false; - std::lock_guard lock(mClientMutex); - mClient = nullptr; + if (auto self = weakSelf.lock()) { + if (self->mIsRunning.exchange(false)) { + self->PushEvent({WSEventType::ERROR_EVENT, "Failed to connect", 0}); + } + } return; } - PushEvent({WSEventType::OPEN, "", 0}); + if (auto self = weakSelf.lock()) { + std::lock_guard lock(self->mClientMutex); + if (self->mAbandoned) return; + + if (!self->mIsRunning) { + return; + } + + self->mClient = &client; + self->PushEvent({WSEventType::OPEN, "", 0}); + } else { + return; + } std::string msg; - while (mIsRunning && !mAbandoned) { + while (true) { + if (auto self = weakSelf.lock()) { + if (!self->mIsRunning || self->mAbandoned) break; + } else break; + if (client.read(msg) == httplib::ws::ReadResult::Fail) break; - PushEvent({WSEventType::MESSAGE, msg, 0}); + + if (auto self = weakSelf.lock()) { + if (!self->mIsRunning || self->mAbandoned) break; + self->PushEvent({WSEventType::MESSAGE, msg, 0}); + } else break; + msg.clear(); } - PushEvent({WSEventType::CLOSE, "Connection closed", 1000}); - mIsRunning = false; - - std::lock_guard lock(mClientMutex); - mClient = nullptr; + if (auto self = weakSelf.lock()) { + if (self->mIsRunning.exchange(false)) { + self->PushEvent({WSEventType::CLOSE, "Connection closed by peer", 1000}); + } + std::lock_guard lock(self->mClientMutex); + self->mClient = nullptr; + } }); } @@ -627,9 +667,9 @@ void AsyncWebSocket::Send(const std::string& data) { } void AsyncWebSocket::Close() { - mIsRunning = false; - std::lock_guard lock(mClientMutex); - if (mClient && mClient->is_open()) mClient->close(); + if (mIsRunning.exchange(false)) { + PushEvent({WSEventType::CLOSE, "Connection closed locally", 1000}); + } } void AsyncWebSocket::PushEvent(WSEvent ev) { diff --git a/src/TLuaEngine.cpp b/src/TLuaEngine.cpp index 2fd44b5a..7b998685 100644 --- a/src/TLuaEngine.cpp +++ b/src/TLuaEngine.cpp @@ -1081,11 +1081,6 @@ TLuaEngine::StateThreadData::~StateThreadData() noexcept { HttpAsync::CleanupState(mState); beammp_debug("\"" + mStateId + "\" destroyed"); - - if (mState) { - lua_close(mState); - mState = nullptr; - } } std::shared_ptr TLuaEngine::StateThreadData::EnqueueScript(const TLuaChunk& Script) {