Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions include/tguf.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#pragma once

#include <torch/types.h>

Check warning on line 3 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:3:1 [misc-include-cleaner]

included header types.h is not used directly

#include <condition_variable>
#include <cstddef>
#include <cstdint>

Check warning on line 7 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:7:1 [misc-include-cleaner]

included header cstdint is not used directly
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <thread>

/** @namespace tguf
* @brief Temporal Graph Unified Format: A Temporal Graph Stream Format.
Expand All @@ -16,7 +22,7 @@
* @brief Container for temporal edge data.
*/
struct Batch {
torch::Tensor src; ///< Source node IDs [B]

Check warning on line 25 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:25:10 [misc-include-cleaner]

no header providing "at::Tensor" is directly included
torch::Tensor dst; ///< Destination node IDs [B]
torch::Tensor time; ///< Timestamps [B]
torch::Tensor msg; ///< Edge features [B, msg_dim]
Expand Down Expand Up @@ -100,7 +106,7 @@
/** @enum NegStrategy
* @brief Determines how negative samples are generated during get_batch().
*/
enum class NegStrategy {

Check warning on line 109 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:109:14 [performance-enum-size]

enum 'NegStrategy' uses a larger base type ('int', size: 4 bytes) than necessary for its value set, consider using 'std::uint8_t' (1 byte) as the base type to reduce its size
None, ///< No negatives (inference or node-level tasks).
Random, ///< Samples one random negative node per edge.
PreComputed, ///< Uses the fixed negatives stored in TGUF (for eval).
Expand All @@ -110,17 +116,17 @@
*/
/** @struct IndexRange
 *  @brief A half-open index interval [start, end) over the edge stream.
 *
 *  The invariant start <= end is enforced at construction time, so a
 *  successfully constructed range always has a well-defined size.
 *
 *  NOTE(review): the two-argument constructor takes two adjacent
 *  std::size_t parameters (easily swapped at call sites), and
 *  std::out_of_range requires <stdexcept> to be directly included in this
 *  header -- confirm the include list.
 */
struct IndexRange {
  IndexRange() = default;

  /// @param s Inclusive start index.
  /// @param e Exclusive end index; must satisfy e >= s.
  /// @throws std::out_of_range if the interval is inverted.
  IndexRange(std::size_t s, std::size_t e) : start_(s), end_(e) {
    // Reject inverted intervals eagerly so size() can never underflow.
    if (end_ < start_) {
      throw std::out_of_range("Invalid range");
    }
  }

  /// Inclusive lower bound of the range.
  [[nodiscard]] auto start() const -> std::size_t { return start_; }
  /// Exclusive upper bound of the range.
  [[nodiscard]] auto end() const -> std::size_t { return end_; }
  /// Number of indices covered: end() - start().
  [[nodiscard]] auto size() const -> std::size_t { return end_ - start_; }

  std::size_t start_{0};  ///< Inclusive start (public for aggregate-style use).
  std::size_t end_{0};    ///< Exclusive end (public for aggregate-style use).
};

virtual ~TGStore() = default;
Expand Down Expand Up @@ -168,7 +174,7 @@
*/
[[nodiscard]] virtual auto get_batch(std::size_t start, std::size_t size,
NegStrategy strategy = NegStrategy::None,
torch::Device device = torch::kCPU) const

Check warning on line 177 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:177:70 [misc-include-cleaner]

no header providing "c10::kCPU" is directly included

Check warning on line 177 in include/tguf.h

View workflow job for this annotation

GitHub Actions / cpp-lint

include/tguf.h:177:47 [misc-include-cleaner]

no header providing "c10::Device" is directly included
-> Batch = 0;

/** * @brief Performs a vectorized random-access gather of edge timestamps.
Expand Down Expand Up @@ -215,4 +221,46 @@
-> LabelEvent = 0;
};

/**
 * @class AsyncDataLoader
 * @brief A generic producer-consumer pipeline for asynchronous data fetching.
 *
 * A single background worker thread walks the index range in batch_size
 * steps, launching one std::async task per batch and keeping at most
 * prefetch_factor futures queued; consumers pull results with next().
 *
 * NOTE(review): the member templates are defined out of line in a .cpp
 * translation unit, so every T (and Producer) used by callers must be
 * explicitly instantiated there -- TODO confirm.
 *
 * @tparam T The type of the data batch being produced.
 */
template <typename T>
class AsyncDataLoader {
public:
// @param prefetch_factor Maximum number of batches queued ahead of the consumer.
explicit AsyncDataLoader(std::size_t prefetch_factor);
// Calls stop(): joins the worker and discards queued batches.
~AsyncDataLoader();

/**
 * @brief Start the background producer thread.
 * @tparam Producer A callable invoked as producer(index, batch_size) -> T.
 * @param start_idx The start edge index to iterate from.
 * @param end_idx The end edge index to iterate to (exclusive).
 * @param batch_size The step size to use for each batch.
 */
// TODO(kuba): use concepts here
template <typename Producer>
auto start(std::size_t start_idx, std::size_t end_idx, std::size_t batch_size,
Producer&& producer) -> void;

/**
 * @brief Stop the background producer thread.
 *
 * Joins the worker and drops any batches still queued.
 */
auto stop() -> void;

/**
 * @brief Retrieve next materialized batch by the producer. Blocks if empty.
 *
 * @return The next batch, or std::nullopt once stop() was requested and the
 *         queue is drained.
 *
 * NOTE(review): if the producer exhausts its range naturally (without
 * stop()), a consumer blocked here on an empty queue is only woken by a
 * later stop() -- confirm callers track the expected batch count.
 */
auto next() -> std::optional<T>;

private:
std::size_t prefetch_factor_{};  // Max number of in-flight batches.
std::queue<std::future<T>> q_{};  // Prefetch buffer; guarded by mtx_.
std::thread worker_{};  // Producer thread; joined in stop().
std::mutex mtx_{};  // Protects q_ and stop_.
std::condition_variable cv_empty_{}, cv_full_{};  // Consumer / producer waits.
bool stop_{false};  // Shutdown flag; guarded by mtx_.
};

} // namespace tguf
99 changes: 99 additions & 0 deletions src/tguf/data_loader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#include <algorithm>
#include <condition_variable>
#include <cstdint>
#include <future>
#include <mutex>
#include <optional>
#include <queue>
#include <thread>
#include <utility>

#include "logging.h"
#include "tguf.h"

namespace tguf {

// Constructs an idle loader. No thread is started until start() is called;
// prefetch_factor bounds how many batches may be queued ahead of next().
template <typename T>
AsyncDataLoader<T>::AsyncDataLoader(std::size_t prefetch_factor)
    : prefetch_factor_(prefetch_factor) {}

// RAII shutdown: stop() signals the worker, joins it, and discards any
// batches still sitting in the prefetch queue.
template <typename T>
AsyncDataLoader<T>::~AsyncDataLoader() {
  stop();
}

/**
 * @brief Launch the background producer thread over [start_idx, end_idx).
 *
 * Safe to call again after a previous run: any earlier worker is stopped and
 * joined first. (Assigning to a still-joinable std::thread would otherwise
 * call std::terminate().)
 */
template <typename T>
template <typename Producer>
auto AsyncDataLoader<T>::start(std::size_t start_idx, std::size_t end_idx,
                               std::size_t batch_size, Producer&& producer)
    -> void {
  // Join any previous worker and drain stale batches before restarting.
  stop();
  {
    // Reset the flag under the lock so the new worker observes it safely.
    std::lock_guard<std::mutex> lock(mtx_);
    stop_ = false;
  }

  worker_ = std::thread([this, start_idx, end_idx, batch_size,
                         fn = std::forward<Producer>(producer)]() mutable {
    for (auto i = start_idx; i < end_idx; i += batch_size) {
      // The final batch may be shorter than batch_size.
      auto current_batch_size = std::min(batch_size, end_idx - i);

      // Wait for space in the prefetch buffer (or a stop request).
      std::unique_lock<std::mutex> lock(mtx_);
      cv_full_.wait(lock,
                    [this] { return q_.size() < prefetch_factor_ || stop_; });
      if (stop_) {
        break;
      }

      // Launch the task. 'fn' is copied by value into the async lambda so the
      // task can never dangle, even if it outlives this worker's closure.
      auto task = std::async(std::launch::async, [fn, i, current_batch_size] {
        return fn(i, current_batch_size);
      });
      q_.push(std::move(task));

      // Signal the consumer after releasing the lock to avoid a
      // wake-then-immediately-block handoff.
      lock.unlock();
      cv_empty_.notify_one();
    }
    // NOTE(review): on natural exhaustion of the range nothing wakes a
    // consumer blocked in next() on an empty queue until stop() is called --
    // consider an explicit end-of-stream signal. TODO confirm the intended
    // contract with callers.
  });
}

/**
 * @brief Stop the background producer thread and drain the prefetch queue.
 *
 * Idempotent: a second call while already stopped returns immediately.
 */
template <typename T>
auto AsyncDataLoader<T>::stop() -> void {
  // Flip the stop flag under the lock; bail out if a prior call already did.
  {
    std::lock_guard<std::mutex> guard(mtx_);
    if (stop_) {
      return;
    }
    stop_ = true;
  }

  // Wake both sides so neither stays parked on its condition variable.
  cv_full_.notify_all();
  cv_empty_.notify_all();

  // Wait for the producer thread to wind down.
  if (worker_.joinable()) {
    worker_.join();
  }

  // Discard every queued batch. Destroying a std::async future blocks until
  // its task has completed, so no task is left running after this returns.
  std::lock_guard<std::mutex> guard(mtx_);
  std::queue<std::future<T>> drained;
  q_.swap(drained);
}

/**
 * @brief Retrieve the next materialized batch, blocking while the queue is
 *        empty.
 * @return The batch's value, or std::nullopt once stop() was requested and
 *         nothing is left in the queue.
 */
template <typename T>
auto AsyncDataLoader<T>::next() -> std::optional<T> {
  std::unique_lock<std::mutex> lock(mtx_);

  // Sleep until the producer has queued work or the loader is shutting down.
  cv_empty_.wait(lock, [this] { return stop_ || !q_.empty(); });
  if (q_.empty()) {
    // Stopped with nothing pending: report end-of-stream.
    return std::nullopt;
  }

  // Take ownership of the oldest pending task.
  auto pending = std::move(q_.front());
  q_.pop();

  // Release the lock before blocking on the result so the producer can
  // immediately refill the freed slot.
  lock.unlock();
  cv_full_.notify_one();
  return pending.get();
}

} // namespace tguf
Loading