namespace tguf {

/**
 * @class AsyncDataLoader
 * @brief A generic producer-consumer pipeline for asynchronous data fetching.
 *
 * A background worker walks the half-open range [start_idx, end_idx) in steps
 * of batch_size and keeps at most `prefetch_factor` batches in flight;
 * next() hands them out in submission order.
 *
 * NOTE(review): the original patch placed these member definitions in
 * src/tguf/data_loader.cpp. Template definitions in a .cpp are invisible to
 * other translation units, and start() is itself a member template, so no
 * explicit instantiation can cover arbitrary producers — every caller would
 * fail to link. The definitions therefore live in this header.
 *
 * @tparam T The type of the data batch being produced.
 */
template <typename T>
class AsyncDataLoader {
 public:
  /**
   * @brief Construct a loader with a bounded prefetch buffer.
   * @param prefetch_factor Maximum number of in-flight batches. A value of 0
   *        would deadlock the producer (`q_.size() < 0` is never true for an
   *        unsigned size), so it is clamped to 1.
   */
  explicit AsyncDataLoader(std::size_t prefetch_factor)
      : prefetch_factor_(std::max<std::size_t>(prefetch_factor, 1)) {}

  /// Stops the producer (if running) and joins the worker thread.
  ~AsyncDataLoader() { stop(); }

  // The mutex/condition-variable members already make this type
  // non-copyable and non-movable; spell it out for clearer diagnostics.
  AsyncDataLoader(const AsyncDataLoader&) = delete;
  auto operator=(const AsyncDataLoader&) -> AsyncDataLoader& = delete;

  /**
   * @brief Start the background producer thread.
   *
   * Any previous run is stopped and joined first: the original code assigned
   * to a possibly-joinable std::thread (which calls std::terminate) and reset
   * stop_ without holding the mutex.
   *
   * @tparam Producer A callable invoked as fn(idx, count) returning T.
   * @param start_idx The start edge index to iterate from.
   * @param end_idx The end edge index to iterate to (exclusive).
   * @param batch_size The step size to use for each batch.
   */
  // TODO(kuba): use concepts here
  template <typename Producer>
  auto start(std::size_t start_idx, std::size_t end_idx, std::size_t batch_size,
             Producer&& producer) -> void {
    stop();  // Join any previous worker; harmless no-op on a fresh loader.
    {
      std::lock_guard<std::mutex> lock(mtx_);
      stop_ = false;
      done_ = false;
    }
    worker_ = std::thread([this, start_idx, end_idx, batch_size,
                           fn = std::forward<Producer>(producer)]() mutable {
      for (auto i = start_idx; i < end_idx; i += batch_size) {
        const auto current_batch_size = std::min(batch_size, end_idx - i);

        // Wait for space in the prefetch buffer (or a stop request).
        std::unique_lock<std::mutex> lock(mtx_);
        cv_full_.wait(lock,
                      [this] { return q_.size() < prefetch_factor_ || stop_; });
        if (stop_) {
          break;
        }

        // Launch the task. 'fn' is copied into the async lambda so each
        // in-flight batch owns its callable.
        auto task = std::async(std::launch::async, [fn, i, current_batch_size] {
          return fn(i, current_batch_size);
        });
        q_.push(std::move(task));

        // Signal the consumer.
        lock.unlock();
        cv_empty_.notify_one();
      }

      // Publish end-of-stream: without this, consumers blocked in next()
      // would wait forever once the range is exhausted.
      {
        std::lock_guard<std::mutex> lock(mtx_);
        done_ = true;
      }
      cv_empty_.notify_all();
    });
  }

  /**
   * @brief Stop the background producer thread.
   *
   * Idempotent. Joins the worker and discards pending batches. Destroying a
   * std::future obtained from std::async blocks until its task finishes, so
   * already-launched batches are completed (and then dropped) here.
   */
  auto stop() -> void {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      if (stop_) {
        return;
      }
      stop_ = true;
    }

    cv_full_.notify_all();
    cv_empty_.notify_all();
    if (worker_.joinable()) {
      worker_.join();
    }

    std::lock_guard<std::mutex> lock(mtx_);
    while (!q_.empty()) {
      q_.pop();
    }
  }

  /**
   * @brief Retrieve the next materialized batch, in submission order.
   *        Blocks while the buffer is empty and the producer is running.
   * @return The next batch, or std::nullopt once the range is exhausted or
   *         stop() was called.
   * @throws Whatever the producer callable throws (rethrown by fut.get()).
   */
  auto next() -> std::optional<T> {
    std::unique_lock<std::mutex> lock(mtx_);

    // Wake on data, normal completion, or an explicit stop.
    cv_empty_.wait(lock, [this] { return !q_.empty() || done_ || stop_; });
    if (q_.empty()) {
      return std::nullopt;
    }

    // Move the future out of the queue.
    auto fut = std::move(q_.front());
    q_.pop();

    // Unlock before blocking on .get() so the producer can refill.
    lock.unlock();
    cv_full_.notify_one();
    return fut.get();
  }

 private:
  std::size_t prefetch_factor_{};
  std::queue<std::future<T>> q_{};
  std::thread worker_{};
  std::mutex mtx_{};
  std::condition_variable cv_empty_{}, cv_full_{};
  bool stop_{false};  // set by stop(); producer and consumers bail out
  bool done_{false};  // set by the worker once the range is exhausted
};

}  // namespace tguf