From 6a5e4aed69aafc72163bff74ea972c52ed9376b4 Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Mon, 23 Mar 2026 10:04:08 -0400 Subject: [PATCH 1/5] WIP --- CMakeLists.txt | 1 + Makefile | 4 ++-- README.md | 2 +- examples/link_pred.cpp | 2 ++ examples/node_pred.cpp | 2 ++ examples/util.h | 34 ++++++++++++++++++++++++++++++++-- 6 files changed, 40 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2615769..9a6a315 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,7 @@ else() # Target Linux x86_64 set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_TAG}/libtorch-shared-with-deps-2.10.0%2Bcu${CUDA_TAG}.zip") message(STATUS "TGUF: Target System is LINUX (CUDA ${CUDA_VERSION}).") enable_language(CUDA) + add_compile_definitions(TGN_WITH_CUDA) endif() endif() diff --git a/Makefile b/Makefile index 8d4858c..4114a0b 100644 --- a/Makefile +++ b/Makefile @@ -147,11 +147,11 @@ data/%.tguf: python .PHONY: run-link-% run-link-%: examples data/%.tguf - @$(EXAMPLE_LINK) data/$*.tguf + @$(EXAMPLE_LINK) data/$*.tguf $(ARGS) .PHONY: run-node-% run-node-%: examples data/%.tguf - @$(EXAMPLE_NODE) data/$*.tguf + @$(EXAMPLE_NODE) data/$*.tguf $(ARGS) .PHONY: profile-build profile-build: $(PROFILE_DIR)/CMakeCache.txt diff --git a/README.md b/README.md index c020d1e..a18a336 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ make run-node-tgbn-trade ```sh # Example: Cuda 12.6 on an A100 (Arch 80) -CUDA_VERSION=12.6 GPU_ARCH=80 make run-link-tgbl-wiki +CUDA_VERSION=12.6 GPU_ARCH=80 make run-link-tgbl-wiki ARGS="--device cuda:0" ``` > \[!TIP\] diff --git a/examples/link_pred.cpp b/examples/link_pred.cpp index e7592f1..8e938ab 100644 --- a/examples/link_pred.cpp +++ b/examples/link_pred.cpp @@ -157,6 +157,8 @@ auto main(int argc, char** argv) -> int { .dropout = args.dropout}; tgn::TGN encoder(cfg, store); LinkPredictor decoder{cfg.embedding_dim}; + encoder->to(args.device); + decoder->to(args.device); auto params = encoder->parameters(); auto dec_params = decoder->parameters(); diff --git a/examples/node_pred.cpp b/examples/node_pred.cpp index d6cf9b6..6ecc8d7 100644 --- a/examples/node_pred.cpp +++ b/examples/node_pred.cpp @@ -166,6 +166,8 @@ auto main(int argc, char** argv) -> int { tgn::TGN encoder(cfg, store); NodePredictor decoder{cfg.embedding_dim, store->label_dim() /* num_classes */}; + encoder->to(args.device); + decoder->to(args.device); auto params = encoder->parameters(); auto dec_params = decoder->parameters(); diff --git a/examples/util.h b/examples/util.h index b513989..cd1f80a 100644 --- a/examples/util.h +++ b/examples/util.h @@ -8,11 +8,17 @@ #include #include +#ifdef TGN_WITH_CUDA +#include +#endif + namespace util { struct TGNArgs { std::string tguf_path; + torch::Device device = torch::kCPU; + std::size_t epochs = 10; std::size_t batch_size = 200; double lr = 1e-4; @@ -29,6 +35,7 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs { auto print_usage = [argv]() { std::cerr << "Usage: " << argv[0] << " [options]\n" << "Options:\n" + << " --device (default: cpu)\n" << " --epochs (default: 10)\n" << " --batch-size (default: 200)\n" << " --lr (default: 1e-4)\n" @@ -73,7 +80,9 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs { std::string_view arg{argv[i]}; std::string_view val{argv[i + 1]}; - if (arg == "--epochs") { + if (arg == "--device") { + args.device = torch::Device(std::string(val)); + } else if (arg == "--epochs") { args.epochs = to_type.template operator()(val); } else if (arg == "--batch-size") { args.batch_size = to_type.template operator()(val); @@ -98,6 +107,7 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs { } TGUF_LOG_INFO(" TGUF Path: {}", args.tguf_path); + TGUF_LOG_INFO(" Device: {}", args.device.str()); TGUF_LOG_INFO(" Epochs: {}", args.epochs); TGUF_LOG_INFO(" Batch Size: {}", args.batch_size); TGUF_LOG_INFO(" Learning Rate:{:.2e}", args.lr); @@ -182,6 +192,26 @@ inline auto log_torch_backend_info() -> void { #else TGUF_LOG_WARN("LibTorch Backend | BLAS: Intel MKL not found"); #endif -}; + +#ifdef TGN_WITH_CUDA + if (torch::cuda::is_available()) { + const auto device_count = torch::cuda::device_count(); + TGUF_LOG_INFO("LibTorch Backend | CUDA: Enabled ({} device(s) found)", + device_count); + + for (auto i = 0; i < device_count; ++i) { + const auto prop = at::cuda::getDeviceProperties(i); + TGUF_LOG_INFO( + "LibTorch Backend | Device {}: {} | Compute Capability: {}.{}", i, + prop->name, prop->major, prop->minor); + } + } else { + TGUF_LOG_WARN( + "LibTorch Backend | CUDA backend linked but no GPU devices found"); + } +#else + TGUF_LOG_INFO("LibTorch Backend | CUDA: Not linked with CUDA backend"); +#endif +}; // namespace util } // namespace util From 7617fcb13a2359c26914e35fee28ed3ef0172ae0 Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Mon, 23 Mar 2026 15:46:48 -0400 Subject: [PATCH 2/5] wip --- CMakeLists.txt | 2 + Makefile | 9 ++- examples/link_pred.cpp | 16 ++--- examples/node_pred.cpp | 25 ++++--- include/tgn.h | 4 ++ include/tguf.h | 6 +- python/bind.cpp | 1 + python/test/test_csv_tguf_roundtrip.py | 1 + python/test/test_store.py | 2 + python/tguf/_tguf_py.pyi | 97 ++++++++++++++------------ src/tguf/store.cpp | 36 ++++++---- src/tguf_models/common/sampler.cpp | 12 ++-- src/tguf_models/common/sampler.h | 3 +- src/tguf_models/common/scatter_ops.cpp | 2 +- src/tguf_models/tgn.cpp | 66 ++++++++++-------- 15 files changed, 159 insertions(+), 123 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9a6a315..3d0c8b4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,12 +45,14 @@ else() # Target Linux x86_64 endif() endif() +message(STATUS "Downloading Libtorch (${LIBTORCH_URL})") FetchContent_Declare( libtorch URL ${LIBTORCH_URL} DOWNLOAD_EXTRACT_TIMESTAMP ON ) FetchContent_MakeAvailable(libtorch) +message(STATUS "Downloading Libtorch (${LIBTORCH_URL}) - done") set(CMAKE_PREFIX_PATH ${libtorch_SOURCE_DIR}) find_package(Torch REQUIRED) diff --git a/Makefile b/Makefile index 4114a0b..3166b97 100644 --- a/Makefile +++ b/Makefile @@ -7,8 +7,6 @@ GPU_ARCH ?= native CMAKE_FLAGS := -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUDA_VERSION=$(CUDA_VERSION) ifneq ($(CUDA_VERSION), cpu) - CMAKE_FLAGS += -DCMAKE_CUDA_ARCHITECTURES=$(GPU_ARCH) - # Add Torch-specific Arch list (converts 80 -> 8.0) ifneq ($(GPU_ARCH), native) TORCH_ARCH := $(shell echo $(GPU_ARCH) | sed 's/\([0-9]\)\([0-9]\)/\1.\2/') @@ -22,6 +20,11 @@ ifneq ($(CUDA_VERSION), cpu) endif endif +PYTHON_BUILD_FLAGS := -DTGN_BUILD_PYTHON=ON +ifneq ($(CUDA_VERSION), cpu) + PYTHON_BUILD_FLAGS += -DCMAKE_CUDA_ARCHITECTURES=$(GPU_ARCH) +endif + NPROCS := $(shell nproc 2>/dev/null || sysctl -n hw.logicalcpu) EXAMPLE_LINK := $(BUILD_DIR)/examples/tgn_link_pred @@ -173,7 +176,7 @@ download-%: data/%.tguf python: @(cd python && \ uv sync --group dev --no-install-project && \ - SKBUILD_CMAKE_ARGS="-DTGN_BUILD_PYTHON=ON" \ + SKBUILD_CMAKE_ARGS="$(PYTHON_BUILD_FLAGS)" \ uv pip install -e . --no-build-isolation) .PHONY: test-python diff --git a/examples/link_pred.cpp b/examples/link_pred.cpp index 8e938ab..247c6c8 100644 --- a/examples/link_pred.cpp +++ b/examples/link_pred.cpp @@ -20,11 +20,12 @@ util::TGNArgs args{}; std::size_t current_epoch = 1; struct LinkPredictorImpl : torch::nn::Module { - explicit LinkPredictorImpl(std::size_t in_dim) { + explicit LinkPredictorImpl(std::size_t in_dim, torch::Device device = torch::kCPU) { w_src_ = register_module("w_src_", torch::nn::Linear(in_dim, in_dim)); w_dst_ = register_module("w_dst_", torch::nn::Linear(in_dim, in_dim)); w_final_ = register_module("w_final_", torch::nn::Linear(in_dim, 1)); - TGUF_LOG_INFO("LinkDecoder: Initialized (in_channels={})", in_dim); + this->to(device); + TGUF_LOG_INFO("LinkDecoder: Initialized on {} (in_channels={})", device.str(), in_dim); } auto forward(const torch::Tensor& z_src, const torch::Tensor& z_dst) @@ -68,7 +69,7 @@ auto train(tgn::TGN& encoder, LinkPredictor& decoder, torch::optim::Adam& opt, opt.zero_grad(); const auto batch = store->get_batch(e_id, args.batch_size, - tguf::TGStore::NegStrategy::Random); + tguf::TGStore::NegStrategy::Random, encoder->device()); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); @@ -111,7 +112,7 @@ auto eval(tgn::TGN& encoder, LinkPredictor& decoder, for (auto e_id = e_range.start(); e_id < e_range.end(); e_id += args.batch_size) { const auto batch = store->get_batch( - e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed); + e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, encoder->device()); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); @@ -149,16 +150,15 @@ auto main(int argc, char** argv) -> int { const std::shared_ptr store = tguf::TGStore::from_tguf(args.tguf_path); - const auto cfg = tgn::TGNConfig{.embedding_dim = args.embedding_dim, + const auto cfg = tgn::TGNConfig{.device = args.device, + .embedding_dim = args.embedding_dim, .memory_dim = args.memory_dim, .time_dim = args.time_dim, .num_heads = args.num_heads, .num_nbrs = args.num_nbrs, .dropout = args.dropout}; tgn::TGN encoder(cfg, store); - LinkPredictor decoder{cfg.embedding_dim}; - encoder->to(args.device); - decoder->to(args.device); + LinkPredictor decoder{cfg.embedding_dim, cfg.device}; auto params = encoder->parameters(); auto dec_params = decoder->parameters(); diff --git a/examples/node_pred.cpp b/examples/node_pred.cpp index 6ecc8d7..4ced9b3 100644 --- a/examples/node_pred.cpp +++ b/examples/node_pred.cpp @@ -21,16 +21,16 @@ util::TGNArgs args{}; std::size_t current_epoch = 1; struct NodePredictorImpl : torch::nn::Module { - explicit NodePredictorImpl(std::size_t in_dim, std::size_t out_dim, - std::size_t hidden_dim = 64) { + explicit NodePredictorImpl(std::size_t in_dim, std::size_t out_dim, torch::Device device = torch::kCPU, std::size_t hidden_dim=64) { model_ = torch::nn::Sequential(torch::nn::Linear(in_dim, hidden_dim), torch::nn::ReLU(), torch::nn::Linear(hidden_dim, out_dim)); register_module("model_", model_); + this->to(device); TGUF_LOG_INFO( - "NodeDecoder: Initialized (in_channels={}, hidden_dim={}, " + "NodeDecoder: Initialized on {} (in_channels={}, hidden_dim={}, " "out_channels={})", - in_dim, hidden_dim, out_dim); + device.str(), in_dim, hidden_dim, out_dim); } auto forward(const torch::Tensor& z_node) -> torch::Tensor { @@ -45,7 +45,7 @@ TORCH_MODULE(NodePredictor); auto compute_ndcg(const torch::Tensor& y_pred, const torch::Tensor& y_true, std::int64_t k = 10) -> float { k = std::min(k, y_pred.size(-1)); - const auto ranks = torch::arange(1, k + 1).to(torch::kFloat32); + const auto ranks = torch::arange(1, k + 1, y_pred.options().dtype(torch::kFloat32)); const auto discounts = torch::log2(ranks + 1.0); const auto [pred_labels, pred_indices] = y_pred.topk(k, -1); @@ -78,7 +78,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = store->get_batch(e_id, num_edges_to_process); + const auto batch = store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, encoder->device()); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; @@ -86,7 +86,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, opt.zero_grad(); - const auto label_event = store->get_label_event(l_id++); + const auto label_event = store->get_label_event(l_id++, encoder->device()); const auto [z] = encoder->forward(label_event.n_id); const auto y_pred = decoder->forward(z); @@ -127,13 +127,13 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = store->get_batch(e_id, num_edges_to_process); + const auto batch = store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, encoder->device()); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; } - const auto label_event = store->get_label_event(l_id++); + const auto label_event = store->get_label_event(l_id++, encoder->device()); const auto [z] = encoder->forward(label_event.n_id); const auto y_pred = decoder->forward(z); perf_list.push_back(compute_ndcg(y_pred, label_event.target)); @@ -157,7 +157,8 @@ auto main(int argc, char** argv) -> int { const std::shared_ptr store = tguf::TGStore::from_tguf(args.tguf_path); - const auto cfg = tgn::TGNConfig{.embedding_dim = args.embedding_dim, + const auto cfg = tgn::TGNConfig{.device = args.device, + .embedding_dim = args.embedding_dim, .memory_dim = args.memory_dim, .time_dim = args.time_dim, .num_heads = args.num_heads, @@ -165,9 +166,7 @@ auto main(int argc, char** argv) -> int { .dropout = args.dropout}; tgn::TGN encoder(cfg, store); NodePredictor decoder{cfg.embedding_dim, - store->label_dim() /* num_classes */}; - encoder->to(args.device); - decoder->to(args.device); + store->label_dim() /* num_classes */, cfg.device}; auto params = encoder->parameters(); auto dec_params = decoder->parameters(); diff --git a/include/tgn.h b/include/tgn.h index 182ee3d..692c00c 100644 --- a/include/tgn.h +++ b/include/tgn.h @@ -20,6 +20,7 @@ namespace tgn { * @brief Configuration parameters for the TGN model architecture. */ struct TGNConfig { + torch::Device device = torch::kCPU; ///< Default to CPU. std::size_t embedding_dim = 100; ///< TransformerConv embedding size. std::size_t memory_dim = 100; ///< TGNMemory embedding size. std::size_t time_dim = 100; ///< TimeEncoder embedding size. @@ -49,6 +50,9 @@ class TGNImpl : public torch::nn::Module { const torch::Tensor& time, const torch::Tensor& msg) -> void; + /** @brief Get the torch::Device used by the module */ + auto device() const -> torch::Device; + /** * @brief Variadic forward pass. * @param inputs Tensors of node IDs to compute embeddings for. diff --git a/include/tguf.h b/include/tguf.h index 3344ede..bbde6e8 100644 --- a/include/tguf.h +++ b/include/tguf.h @@ -164,10 +164,11 @@ class TGStore { * @param start The starting edge ID. * @param size The number of edges to include. * @param strategy The negative sampling strategy to apply. + * @param device The torch::Device to materialize the batch on. */ [[nodiscard]] virtual auto get_batch( std::size_t start, std::size_t size, - NegStrategy strategy = NegStrategy::None) const -> Batch = 0; + NegStrategy strategy = NegStrategy::None, torch::Device device = torch::kCPU) const -> Batch = 0; /** * @brief Performs a vectorized random-access gather of edge timestamps. * @param e_id Tensor of edge indices [num_edges]. @@ -205,9 +206,10 @@ class TGStore { /** * @brief Retrieves the metadata and target for a specific label event. * @param l_id The index of the label event. + * @param device The torch::Device to materialize data on. * @return A @ref LabelEvent containing affected node IDs and target values. */ - [[nodiscard]] virtual auto get_label_event(std::size_t l_id) const + [[nodiscard]] virtual auto get_label_event(std::size_t l_id, torch::Device device = torch::kCPU) const -> LabelEvent = 0; }; diff --git a/python/bind.cpp b/python/bind.cpp index 4c05fc3..a33b0db 100644 --- a/python/bind.cpp +++ b/python/bind.cpp @@ -342,6 +342,7 @@ Use :meth:`from_memory` or :meth:`from_tguf` to instantiate. .def_prop_ro("edge_count", &tguf::TGStore::edge_count) .def_prop_ro("node_count", &tguf::TGStore::node_count) + .def_prop_ro("label_count", &tguf::TGStore::label_count) .def_prop_ro("msg_dim", &tguf::TGStore::msg_dim) .def_prop_ro("label_dim", &tguf::TGStore::label_dim) .def_prop_ro("node_feat_dim", &tguf::TGStore::node_feat_dim) diff --git a/python/test/test_csv_tguf_roundtrip.py b/python/test/test_csv_tguf_roundtrip.py index 79da005..c2ba50e 100644 --- a/python/test/test_csv_tguf_roundtrip.py +++ b/python/test/test_csv_tguf_roundtrip.py @@ -94,6 +94,7 @@ def test_csv_tguf_roundtrip(resource_dir, output_tguf): assert store.edge_count == 3 assert store.node_count == 31 + assert store.label_count == 2 assert store.msg_dim == 2 assert store.label_dim == 2 assert store.node_feat_dim == 3 diff --git a/python/test/test_store.py b/python/test/test_store.py index 593ab8a..ac573c1 100644 --- a/python/test/test_store.py +++ b/python/test/test_store.py @@ -16,6 +16,7 @@ def test_tgstore_from_memory(get_data): assert store.edge_count == num_edges assert store.node_count == num_edges + 1 + assert store.label_count == 0 assert store.msg_dim == 8 assert store.label_dim == 0 assert store.train_split.end == 15 @@ -46,6 +47,7 @@ def test_tgstore_from_tguf(schema, get_data): assert store.edge_count == num_edges assert store.node_count == num_edges + 1 + assert store.label_count == 0 assert store.msg_dim == 8 assert store.label_dim == 0 assert store.train_split.end == 15 diff --git a/python/tguf/_tguf_py.pyi b/python/tguf/_tguf_py.pyi index 3a88663..af43da1 100644 --- a/python/tguf/_tguf_py.pyi +++ b/python/tguf/_tguf_py.pyi @@ -11,6 +11,7 @@ from typing import Annotated import numpy from numpy.typing import NDArray + class TGUFSchema: """ Metadata defining the layout of a TGUF dataset. @@ -57,62 +58,71 @@ class TGUFSchema: as fully training unless overridden during loading. """ - def __init__( - self, - path: str, - edge_capacity: int | None = None, - msg_dim: int | None = None, - label_dim: int | None = None, - node_feat_capacity: int | None = None, - node_feat_dim: int | None = None, - label_capacity: int | None = None, - negatives_start_e_id: int | None = None, - negatives_per_edge: int | None = None, - val_start: int | None = None, - test_start: int | None = None, - ) -> None: ... + def __init__(self, path: str, edge_capacity: int | None = None, msg_dim: int | None = None, label_dim: int | None = None, node_feat_capacity: int | None = None, node_feat_dim: int | None = None, label_capacity: int | None = None, negatives_start_e_id: int | None = None, negatives_per_edge: int | None = None, val_start: int | None = None, test_start: int | None = None) -> None: ... + @property def path(self) -> str: ... + @path.setter def path(self, arg: str, /) -> None: ... + @property def edge_capacity(self) -> int: ... + @edge_capacity.setter def edge_capacity(self, arg: int, /) -> None: ... + @property def msg_dim(self) -> int: ... + @msg_dim.setter def msg_dim(self, arg: int, /) -> None: ... + @property def label_dim(self) -> int: ... + @label_dim.setter def label_dim(self, arg: int, /) -> None: ... + @property def node_feat_capacity(self) -> int: ... + @node_feat_capacity.setter def node_feat_capacity(self, arg: int, /) -> None: ... + @property def node_feat_dim(self) -> int: ... + @node_feat_dim.setter def node_feat_dim(self, arg: int, /) -> None: ... + @property def label_capacity(self) -> int: ... + @label_capacity.setter def label_capacity(self, arg: int, /) -> None: ... + @property def negatives_start_e_id(self) -> int: ... + @negatives_start_e_id.setter def negatives_start_e_id(self, arg: int, /) -> None: ... + @property def negatives_per_edge(self) -> int: ... + @negatives_per_edge.setter def negatives_per_edge(self, arg: int, /) -> None: ... + @property def val_start(self) -> int | None: ... + @val_start.setter def val_start(self, arg: int | None) -> None: ... + @property def test_start(self) -> int | None: ... + @test_start.setter def test_start(self, arg: int | None) -> None: ... @@ -147,14 +157,8 @@ class Batch: - :class:`TGUFBuilder` """ - def __init__( - self, - src: NDArray, - dst: NDArray, - time: NDArray, - msg: NDArray, - neg_dst: NDArray | None = None, - ) -> None: ... + def __init__(self, src: NDArray, dst: NDArray, time: NDArray, msg: NDArray, neg_dst: NDArray | None = None) -> None: ... + @property def src(self) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Source node IDs""" @@ -190,6 +194,7 @@ class LabelEvent: """ def __init__(self, n_id: NDArray, target: NDArray) -> None: ... + @property def n_id(self) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Node IDs associated with this label event.""" @@ -215,10 +220,13 @@ class IndexRange: """A contiguous slice of the graph data.""" def __init__(self, arg0: int, arg1: int, /) -> None: ... + @property def start(self) -> int: ... + @property def end(self) -> int: ... + @property def size(self) -> int: ... @@ -231,63 +239,59 @@ class TGStore: """ @staticmethod - def from_memory( - edges: Batch, - node_feats: NDArray | None = None, - label_n_id: NDArray | None = None, - label_time: NDArray | None = None, - label_target: NDArray | None = None, - val_start: int | None = None, - test_start: int | None = None, - ) -> TGStore: + def from_memory(edges: Batch, node_feats: NDArray | None = None, label_n_id: NDArray | None = None, label_time: NDArray | None = None, label_target: NDArray | None = None, val_start: int | None = None, test_start: int | None = None) -> TGStore: """Create a high-speed, purely RAM-based store.""" @staticmethod - def from_tguf( - path: str, val_start: int | None = None, test_start: int | None = None - ) -> TGStore: + def from_tguf(path: str, val_start: int | None = None, test_start: int | None = None) -> TGStore: """Create a memory-mapped store from a TGUF file.""" @property def edge_count(self) -> int: ... + @property def node_count(self) -> int: ... + + @property + def label_count(self) -> int: ... + @property def msg_dim(self) -> int: ... + @property def label_dim(self) -> int: ... + @property def node_feat_dim(self) -> int: ... + @property def train_split(self) -> IndexRange: ... + @property def val_split(self) -> IndexRange: ... + @property def test_split(self) -> IndexRange: ... + @property def train_label_split(self) -> IndexRange: ... + @property def val_label_split(self) -> IndexRange: ... + @property def test_label_split(self) -> IndexRange: ... - def get_batch( - self, start: int, size: int, strategy: NegStrategy = NegStrategy.None_ - ) -> Batch: + + def get_batch(self, start: int, size: int, strategy: NegStrategy = NegStrategy.None_) -> Batch: """Retrieve a zero-copy slice of the graph interaction data.""" - def gather_timestamps( - self, e_id: NDArray - ) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: + def gather_timestamps(self, e_id: NDArray) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Vectorized gather of edge timestamps.""" - def gather_msgs( - self, e_id: NDArray - ) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: + def gather_msgs(self, e_id: NDArray) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: """Vectorized gather of edge features (messages).""" - def gather_node_feats( - self, n_id: NDArray - ) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: + def gather_node_feats(self, n_id: NDArray) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: """Vectorized gather of static node features.""" def get_edge_cutoff_for_label_event(self, l_id: int) -> int: @@ -314,6 +318,7 @@ class TGUFBuilder: """ def __init__(self, schema: TGUFSchema) -> None: ... + def append_edges(self, batch: Batch) -> None: """ Append a batch of temporal edges to the dataset. diff --git a/src/tguf/store.cpp b/src/tguf/store.cpp index 7d27375..65259ac 100644 --- a/src/tguf/store.cpp +++ b/src/tguf/store.cpp @@ -326,7 +326,7 @@ class TGStoreImpl final : public TGStore { } [[nodiscard]] auto get_batch(std::size_t start, std::size_t batch_size, - NegStrategy strategy = NegStrategy::None) const + NegStrategy strategy, torch::Device device) const -> Batch override { const auto end = std::min(start + batch_size, num_edges_); const auto s = static_cast(start); @@ -340,7 +340,7 @@ class TGStoreImpl final : public TGStore { TORCH_CHECK(sampler_.has_value(), "Random sampling requested but sampler not initialized " "(train split is empty)"); - batch_neg = sampler_->sample(e - s); + batch_neg = sampler_->sample(e - s).to(device, true); } else if (strategy == NegStrategy::PreComputed) { TGUF_LOG_DEBUG("TGStore: get_batch [{}:{}] (NegStrategy::PreComputed)", start, end); @@ -354,27 +354,31 @@ class TGStoreImpl final : public TGStore { std::to_string(negatives_start_e_id_)); } batch_neg = neg_dst_->slice(0, s - negatives_start_e_id_, - e - negatives_start_e_id_); + e - negatives_start_e_id_).to(device, true); } else { TGUF_LOG_DEBUG("TGStore: get_batch [{}:{}] (NegStrategy::None)", start, end); } - return Batch{.src = src_.slice(0, s, e), - .dst = dst_.slice(0, s, e), - .time = t_.slice(0, s, e), - .msg = msg_.slice(0, s, e), + return Batch{.src = src_.slice(0, s, e).to(device, true), + .dst = dst_.slice(0, s, e).to(device, true), + .time = t_.slice(0, s, e).to(device, true), + .msg = msg_.slice(0, s, e).to(device, true), .neg_dst = batch_neg}; } [[nodiscard]] auto gather_timestamps(const torch::Tensor& e_id) const -> torch::Tensor override { - return t_.index_select(0, e_id.flatten()); + const auto e_id_cpu = e_id.device().is_cpu() ? e_id.flatten() : e_id.to(torch::kCPU).flatten(); + auto out = torch::empty({e_id_cpu.size(0)}, t_.options()); + return at::index_select_out(out, t_, 0, e_id_cpu).to(e_id.device(), true); } [[nodiscard]] auto gather_msgs(const torch::Tensor& e_id) const -> torch::Tensor override { - return msg_.index_select(0, e_id.flatten()); + const auto e_id_cpu = e_id.device().is_cpu() ? e_id.flatten() : e_id.to(torch::kCPU).flatten(); + auto out = torch::empty({e_id_cpu.size(0), msg_.size(1)}, msg_.options()); + return at::index_select_out(out, msg_, 0, e_id_cpu).to(e_id.device(), true); } [[nodiscard]] auto gather_node_feats(const torch::Tensor& n_id) const @@ -384,8 +388,10 @@ class TGStoreImpl final : public TGStore { } // Every ID outside [0, num_nodes-1] hits the padded row (all zeros) - const auto safe_ids = n_id.clamp(0, node_feats_->size(0) - 1); - return node_feats_->index_select(0, safe_ids.flatten()); + const auto n_id_cpu = n_id.device().is_cpu() ? n_id : n_id.to(torch::kCPU); + const auto safe_ids = n_id.clamp(0, node_feats_->size(0) - 1).flatten(); + auto out = torch::empty({safe_ids.size(0), node_feats_->size(1)}, node_feats_->options()); + return at::index_select_out(out, node_feats_.value(), 0, safe_ids).to(n_id.device(), true); } [[nodiscard]] auto get_edge_cutoff_for_label_event(std::size_t l_id) const @@ -396,12 +402,16 @@ class TGStoreImpl final : public TGStore { return stop_e_ids_.at(l_id); } - [[nodiscard]] auto get_label_event(std::size_t l_id) const + [[nodiscard]] auto get_label_event(std::size_t l_id, torch::Device device) const -> LabelEvent override { TORCH_CHECK(l_id < label_events_.size(), "TGStore: Requested LabelEvent at index ", l_id, " but only ", label_events_.size(), " events exist."); - return label_events_.at(l_id); + const auto le = label_events_.at(l_id); + return LabelEvent { + .n_id = le.n_id.to(device, true), + .target= le.target.to(device, true), + }; } private: diff --git a/src/tguf_models/common/sampler.cpp b/src/tguf_models/common/sampler.cpp index 08e3298..be1c69b 100644 --- a/src/tguf_models/common/sampler.cpp +++ b/src/tguf_models/common/sampler.cpp @@ -10,20 +10,20 @@ namespace tgn { LastNeighborLoader::LastNeighborLoader(std::size_t num_nbrs, - std::size_t num_nodes) + std::size_t num_nodes, torch::Device device) : buffer_size_(static_cast(num_nbrs)), buffer_nbrs_(torch::empty({static_cast(num_nodes), static_cast(num_nbrs)}, - torch::TensorOptions().dtype(torch::kLong))), + torch::device(device).dtype(torch::kLong))), buffer_e_id_(torch::empty({static_cast(num_nodes), static_cast(num_nbrs)}, - torch::TensorOptions().dtype(torch::kLong))), + torch::device(device).dtype(torch::kLong))), assoc_(torch::empty({static_cast(num_nodes)}, - torch::TensorOptions().dtype(torch::kLong))) { + torch::device(device).dtype(torch::kLong))) { const auto bytes = buffer_nbrs_.nbytes() + buffer_e_id_.nbytes() + assoc_.nbytes(); - TGUF_LOG_INFO("Sampler: ~{:.2f} MiB allocated ({} nodes, {} nbrs/node)", - bytes / (1024.0 * 1024.0), num_nodes, num_nbrs); + TGUF_LOG_INFO("Sampler: ~{:.2f} MiB allocated on {} ({} nodes, {} nbrs/node)", + bytes / (1024.0 * 1024.0), device.str(), num_nodes, num_nbrs); reset_state(); } diff --git a/src/tguf_models/common/sampler.h b/src/tguf_models/common/sampler.h index a92776a..a9891af 100644 --- a/src/tguf_models/common/sampler.h +++ b/src/tguf_models/common/sampler.h @@ -20,8 +20,9 @@ class LastNeighborLoader { * @brief Constructs the sampler and allocates persistent buffers. * @param num_nbrs The number of most neighbors ($K$) to track per node. * @param num_nodes The total capacity of the node index space ($N$). + * @param device The torch::Device to run on. */ - LastNeighborLoader(std::size_t num_nbrs, std::size_t num_nodes); + LastNeighborLoader(std::size_t num_nbrs, std::size_t num_nodes, torch::Device device = torch::kCPU); /** * @brief Samples the temporal neighborhood and performs local relabeling. diff --git a/src/tguf_models/common/scatter_ops.cpp b/src/tguf_models/common/scatter_ops.cpp index a2cdd41..ad073be 100644 --- a/src/tguf_models/common/scatter_ops.cpp +++ b/src/tguf_models/common/scatter_ops.cpp @@ -50,7 +50,7 @@ auto scatter_softmax(const torch::Tensor& src, const torch::Tensor& index, auto scatter_argmax(const torch::Tensor& src, const torch::Tensor& index, std::int64_t dim_size) -> torch::Tensor { auto res = scatter_max(src, index, dim_size); - auto out = torch::full({dim_size}, /*fill_value*/ dim_size - 1); + auto out = torch::full({dim_size}, /*fill_value*/ dim_size - 1, src.options().dtype(torch::kLong)); // Find where edge values match the winning max for each node const auto mask = src == res.index_select(0, index); diff --git a/src/tguf_models/tgn.cpp b/src/tguf_models/tgn.cpp index 4a2afee..b986178 100644 --- a/src/tguf_models/tgn.cpp +++ b/src/tguf_models/tgn.cpp @@ -20,10 +20,11 @@ namespace tgn { namespace detail { struct TimeEncoderImpl : torch::nn::Module { - explicit TimeEncoderImpl(std::size_t out_channels) { + explicit TimeEncoderImpl(std::size_t out_channels, torch::Device device = torch::kCPU) { lin_ = register_module("lin_", torch::nn::Linear(1, out_channels)); - TGUF_LOG_INFO("TimeEncoder: Initialized (time_embedding_dim={})", - out_channels); + this->to(device); + TGUF_LOG_INFO("TimeEncoder: Initialized on {} (time_embedding_dim={})", + device.str(), out_channels); } auto forward(const torch::Tensor& t) -> torch::Tensor { @@ -38,7 +39,7 @@ TORCH_MODULE(TimeEncoder); struct TransformerConvImpl : torch::nn::Module { TransformerConvImpl(std::size_t in_channels, std::size_t out_channels, std::size_t edge_dim, std::size_t heads, - float dropout = 0.0) + float dropout = 0.0, torch::Device device = torch::kCPU) : dropout_(dropout), H_(static_cast(heads)), C_(static_cast(out_channels)), @@ -50,10 +51,11 @@ struct TransformerConvImpl : torch::nn::Module { "w_e_", torch::nn::Linear(torch::nn::LinearOptions( static_cast(edge_dim), O_) .bias(false))); + this->to(device); TGUF_LOG_INFO( - "TransformerConv: Initialized (in_channels={}, out_channels={}, " + "TransformerConv: Initialized on {} (in_channels={}, out_channels={}, " "heads={}, edge_dim={}, dropout={:.2f})", - in_channels, out_channels, heads, edge_dim, dropout); + device.str(), in_channels, out_channels, heads, edge_dim, dropout); } auto forward(const torch::Tensor& x, const torch::Tensor& edge_index, @@ -101,11 +103,11 @@ struct TGNMemoryImpl : torch::nn::Module { struct MsgStore { torch::Tensor src_, dst_, time_, msg_; - MsgStore(std::int64_t num_nodes, std::int64_t msg_dim) { - src_ = torch::zeros({num_nodes}, torch::kLong); - dst_ = torch::zeros({num_nodes}, torch::kLong); - time_ = torch::zeros({num_nodes}, torch::kLong); - msg_ = torch::zeros({num_nodes, msg_dim}, torch::kFloat); + MsgStore(std::int64_t num_nodes, std::int64_t msg_dim, torch::Device device = torch::kCPU) { + src_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + dst_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + time_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + msg_ = torch::zeros({num_nodes, msg_dim}, torch::device(device).dtype(torch::kFloat)); } auto reset() -> void { @@ -118,12 +120,12 @@ struct TGNMemoryImpl : torch::nn::Module { auto update(const torch::Tensor& src, const torch::Tensor& dst, const torch::Tensor& time, const torch::Tensor msg) -> void { // Find the index of the last (max time) interaction for each source node - auto argmax = scatter_argmax(time, src, src_.size(0)); + const auto argmax = scatter_argmax(time, src, src_.size(0)); // mask out nodes that didn't appear in this batch - auto mask = argmax < src.size(0); - auto active_node_ids = torch::nonzero(mask).view(-1); - auto batch_indices = argmax.index({mask}); + const auto mask = argmax < src.size(0); + const auto active_node_ids = torch::nonzero(mask).view(-1); + const auto batch_indices = argmax.index({mask}); src_.index_put_({active_node_ids}, src.index_select(0, batch_indices)); dst_.index_put_({active_node_ids}, dst.index_select(0, batch_indices)); @@ -137,14 +139,14 @@ struct TGNMemoryImpl : torch::nn::Module { : msg_dim_(msg_dim), num_nodes_(num_nodes), memory_(torch::empty( - {num_nodes, static_cast(cfg.memory_dim)})), + {num_nodes, static_cast(cfg.memory_dim)}, torch::device(cfg.device))), last_update_(torch::empty({num_nodes}, - torch::TensorOptions().dtype(torch::kLong))), + torch::device(cfg.device).dtype(torch::kLong))), assoc_(torch::empty({num_nodes}, - torch::TensorOptions().dtype(torch::kLong))), + torch::device(cfg.device).dtype(torch::kLong))), time_encoder_(time_encoder), - src_store_(num_nodes, msg_dim), - dst_store_(num_nodes, msg_dim) { + src_store_(num_nodes, msg_dim, cfg.device), + dst_store_(num_nodes, msg_dim, cfg.device) { register_buffer("memory_", memory_); register_buffer("last_update_", last_update_); register_buffer("assoc_", assoc_); @@ -155,6 +157,7 @@ struct TGNMemoryImpl : torch::nn::Module { gru_ = register_module("gru_", torch::nn::GRUCell(cell_dim, cfg.memory_dim)); + this->to(cfg.device); reset_state(); auto get_store_bytes = [](const MsgStore& s) { @@ -166,9 +169,9 @@ struct TGNMemoryImpl : torch::nn::Module { assoc_.nbytes() + get_store_bytes(src_store_) + get_store_bytes(dst_store_); TGUF_LOG_INFO( - "TGNMemory: ~{:.2f} MiB allocated ({} nodes, memory_dim: {}, msg_dim: " + "TGNMemory: ~{:.2f} MiB allocated on {} ({} nodes, memory_dim: {}, msg_dim: " "{}, gru_cell_dim: {})", - bytes / (1024.0 * 1024.0), num_nodes_, cfg.memory_dim, msg_dim_, + bytes / (1024.0 * 1024.0), cfg.device.str(), num_nodes_, cfg.memory_dim, msg_dim_, cell_dim); } @@ -211,7 +214,7 @@ struct TGNMemoryImpl : torch::nn::Module { TGUF_LOG_DEBUG( "TGNMemory: Switching to Eval. Flushing memory for all {} nodes", num_nodes_); - update_memory(torch::arange(static_cast(num_nodes_))); + update_memory(torch::arange(static_cast(num_nodes_), memory_.options().dtype(torch::kLong))); src_store_.reset(); dst_store_.reset(); } @@ -285,13 +288,13 @@ struct TGNImpl::Impl { Impl(const TGNConfig& cfg, const std::shared_ptr& store) : cfg_(cfg), store_(store), - nbr_loader_(cfg.num_nbrs, store->node_count()), + nbr_loader_(cfg.num_nbrs, store->node_count(), cfg.device), assoc_(torch::full({static_cast(store->node_count())}, -1, - torch::dtype(torch::kLong))) { - time_encoder_ = detail::TimeEncoder(cfg.time_dim); + torch::device(cfg.device).dtype(torch::kLong))) { + time_encoder_ = detail::TimeEncoder(cfg.time_dim, cfg.device); conv_ = detail::TransformerConv( cfg.memory_dim + store_->node_feat_dim(), cfg.embedding_dim / 2, - store->msg_dim() + cfg.time_dim, cfg.num_heads, cfg.dropout); + store->msg_dim() + cfg.time_dim, cfg.num_heads, cfg.dropout, cfg.device); memory_ = detail::TGNMemory(cfg, time_encoder_, store->msg_dim(), store->node_count()); } @@ -313,6 +316,7 @@ TGNImpl::TGNImpl(const TGNConfig& cfg, register_module("conv", impl_->conv_); impl_->assoc_ = register_buffer("assoc", impl_->assoc_); + this->to(impl_->cfg_.device); } TGNImpl::~TGNImpl() = default; @@ -324,6 +328,8 @@ auto TGNImpl::reset_state() -> void { impl_->nbr_loader_.reset_state(); } +auto TGNImpl::device() const -> torch::Device {return this->parameters()[0].device();} + auto TGNImpl::update_state(const torch::Tensor& src, const torch::Tensor& dst, const torch::Tensor& time, const torch::Tensor& msg) -> void { @@ -345,16 +351,16 @@ auto TGNImpl::forward_internal(const std::vector& input_list) {n_id}, torch::arange(n_id.size(0), impl_->assoc_.options())); // Transformer conv with relative time encoding - const auto t_edges = impl_->store_->gather_timestamps(e_id); + const auto t_edges = impl_->store_->gather_timestamps(e_id).to(impl_->cfg_.device, true); const auto rel_t = last_update.index_select(0, edge_index[0]) - t_edges; const auto rel_t_z = impl_->time_encoder_->forward(rel_t.to(torch::kFloat32)); const auto edge_feat = impl_->store_->msg_dim() > 0 - ? torch::cat({rel_t_z, impl_->store_->gather_msgs(e_id)}, -1) + ? torch::cat({rel_t_z, impl_->store_->gather_msgs(e_id).to(impl_->cfg_.device, true)}, -1) : rel_t_z; const auto node_feat = impl_->store_->node_feat_dim() > 0 - ? torch::cat({memory, impl_->store_->gather_node_feats(n_id)}, -1) + ? torch::cat({memory, impl_->store_->gather_node_feats(n_id).to(impl_->cfg_.device, true)}, -1) : memory; const auto z = impl_->conv_->forward(node_feat, edge_index, edge_feat); From b7c23ddee6ef9d818a8886a7fe1c06ba439565ab Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Mon, 23 Mar 2026 15:47:52 -0400 Subject: [PATCH 3/5] WIP --- examples/link_pred.cpp | 16 +++-- examples/node_pred.cpp | 19 ++++-- include/tgn.h | 8 +-- include/tguf.h | 10 +-- python/tguf/_tguf_py.pyi | 95 +++++++++++++------------- src/tguf/store.cpp | 29 +++++--- src/tguf_models/common/sampler.cpp | 3 +- src/tguf_models/common/sampler.h | 3 +- src/tguf_models/common/scatter_ops.cpp | 3 +- src/tguf_models/tgn.cpp | 59 ++++++++++------ 10 files changed, 142 insertions(+), 103 deletions(-) diff --git a/examples/link_pred.cpp b/examples/link_pred.cpp index 247c6c8..6807e7f 100644 --- a/examples/link_pred.cpp +++ b/examples/link_pred.cpp @@ -20,12 +20,14 @@ util::TGNArgs args{}; std::size_t current_epoch = 1; struct LinkPredictorImpl : torch::nn::Module { - explicit LinkPredictorImpl(std::size_t in_dim, torch::Device device = torch::kCPU) { + explicit LinkPredictorImpl(std::size_t in_dim, + torch::Device device = torch::kCPU) { w_src_ = register_module("w_src_", torch::nn::Linear(in_dim, in_dim)); w_dst_ = register_module("w_dst_", torch::nn::Linear(in_dim, in_dim)); w_final_ = register_module("w_final_", torch::nn::Linear(in_dim, 1)); this->to(device); - TGUF_LOG_INFO("LinkDecoder: Initialized on {} (in_channels={})", device.str(), in_dim); + TGUF_LOG_INFO("LinkDecoder: Initialized on {} (in_channels={})", + device.str(), in_dim); } auto forward(const torch::Tensor& z_src, const torch::Tensor& z_dst) @@ -68,8 +70,9 @@ auto train(tgn::TGN& encoder, LinkPredictor& decoder, torch::optim::Adam& opt, e_id += args.batch_size) { opt.zero_grad(); - const auto batch = store->get_batch(e_id, args.batch_size, - tguf::TGStore::NegStrategy::Random, encoder->device()); + const auto batch = + store->get_batch(e_id, args.batch_size, + tguf::TGStore::NegStrategy::Random, encoder->device()); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); @@ -111,8 +114,9 @@ auto eval(tgn::TGN& encoder, LinkPredictor& decoder, for (auto e_id = e_range.start(); e_id < e_range.end(); e_id += args.batch_size) { - const auto batch = store->get_batch( - e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, encoder->device()); + const auto batch = store->get_batch(e_id, args.batch_size, + tguf::TGStore::NegStrategy::PreComputed, + encoder->device()); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); diff --git a/examples/node_pred.cpp b/examples/node_pred.cpp index 4ced9b3..2f0f319 100644 --- a/examples/node_pred.cpp +++ b/examples/node_pred.cpp @@ -21,7 +21,9 @@ util::TGNArgs args{}; std::size_t current_epoch = 1; struct NodePredictorImpl : torch::nn::Module { - explicit NodePredictorImpl(std::size_t in_dim, std::size_t out_dim, torch::Device device = torch::kCPU, std::size_t hidden_dim=64) { + explicit NodePredictorImpl(std::size_t in_dim, std::size_t out_dim, + torch::Device device = torch::kCPU, + std::size_t hidden_dim = 64) { model_ = torch::nn::Sequential(torch::nn::Linear(in_dim, hidden_dim), torch::nn::ReLU(), torch::nn::Linear(hidden_dim, out_dim)); @@ -45,7 +47,8 @@ TORCH_MODULE(NodePredictor); auto compute_ndcg(const torch::Tensor& y_pred, const torch::Tensor& y_true, std::int64_t k = 10) -> float { k = std::min(k, y_pred.size(-1)); - const auto ranks = torch::arange(1, k + 1, y_pred.options().dtype(torch::kFloat32)); + const auto ranks = + torch::arange(1, k + 1, y_pred.options().dtype(torch::kFloat32)); const auto discounts = torch::log2(ranks + 1.0); const auto [pred_labels, pred_indices] = y_pred.topk(k, -1); @@ -78,7 +81,9 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, encoder->device()); + const auto batch = + store->get_batch(e_id, num_edges_to_process, + tguf::TGStore::NegStrategy::None, encoder->device()); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; @@ -127,7 +132,9 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, encoder->device()); + const auto batch = + store->get_batch(e_id, num_edges_to_process, + tguf::TGStore::NegStrategy::None, encoder->device()); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; @@ -165,8 +172,8 @@ auto main(int argc, char** argv) -> int { .num_nbrs = args.num_nbrs, .dropout = args.dropout}; tgn::TGN encoder(cfg, store); - NodePredictor decoder{cfg.embedding_dim, - store->label_dim() /* num_classes */, cfg.device}; + NodePredictor decoder{cfg.embedding_dim, store->label_dim() /* num_classes */, + cfg.device}; auto params = encoder->parameters(); auto dec_params = decoder->parameters(); diff --git a/include/tgn.h b/include/tgn.h index 692c00c..308a64a 100644 --- a/include/tgn.h +++ b/include/tgn.h @@ -20,10 +20,10 @@ namespace tgn { * @brief Configuration parameters for the TGN model architecture. */ struct TGNConfig { - torch::Device device = torch::kCPU; ///< Default to CPU. - std::size_t embedding_dim = 100; ///< TransformerConv embedding size. - std::size_t memory_dim = 100; ///< TGNMemory embedding size. - std::size_t time_dim = 100; ///< TimeEncoder embedding size. + torch::Device device = torch::kCPU; ///< Default to CPU. + std::size_t embedding_dim = 100; ///< TransformerConv embedding size. + std::size_t memory_dim = 100; ///< TGNMemory embedding size. + std::size_t time_dim = 100; ///< TimeEncoder embedding size. std::size_t num_heads = 2; ///< TransformerConv multi-head attention heads. std::size_t num_nbrs = 10; ///< RecencySampler neighbor buffer size. float dropout = 0.1; ///< TransformerConv dropout. diff --git a/include/tguf.h b/include/tguf.h index bbde6e8..f54ec36 100644 --- a/include/tguf.h +++ b/include/tguf.h @@ -166,9 +166,10 @@ class TGStore { * @param strategy The negative sampling strategy to apply. * @param device The torch::Device to materialize the batch on. */ - [[nodiscard]] virtual auto get_batch( - std::size_t start, std::size_t size, - NegStrategy strategy = NegStrategy::None, torch::Device device = torch::kCPU) const -> Batch = 0; + [[nodiscard]] virtual auto get_batch(std::size_t start, std::size_t size, + NegStrategy strategy = NegStrategy::None, + torch::Device device = torch::kCPU) const + -> Batch = 0; /** * @brief Performs a vectorized random-access gather of edge timestamps. * @param e_id Tensor of edge indices [num_edges]. @@ -209,7 +210,8 @@ class TGStore { * @param device The torch::Device to materialize data on. * @return A @ref LabelEvent containing affected node IDs and target values. */ - [[nodiscard]] virtual auto get_label_event(std::size_t l_id, torch::Device device = torch::kCPU) const + [[nodiscard]] virtual auto get_label_event( + std::size_t l_id, torch::Device device = torch::kCPU) const -> LabelEvent = 0; }; diff --git a/python/tguf/_tguf_py.pyi b/python/tguf/_tguf_py.pyi index af43da1..b5bdee1 100644 --- a/python/tguf/_tguf_py.pyi +++ b/python/tguf/_tguf_py.pyi @@ -11,7 +11,6 @@ from typing import Annotated import numpy from numpy.typing import NDArray - class TGUFSchema: """ Metadata defining the layout of a TGUF dataset. @@ -58,71 +57,62 @@ class TGUFSchema: as fully training unless overridden during loading. """ - def __init__(self, path: str, edge_capacity: int | None = None, msg_dim: int | None = None, label_dim: int | None = None, node_feat_capacity: int | None = None, node_feat_dim: int | None = None, label_capacity: int | None = None, negatives_start_e_id: int | None = None, negatives_per_edge: int | None = None, val_start: int | None = None, test_start: int | None = None) -> None: ... - + def __init__( + self, + path: str, + edge_capacity: int | None = None, + msg_dim: int | None = None, + label_dim: int | None = None, + node_feat_capacity: int | None = None, + node_feat_dim: int | None = None, + label_capacity: int | None = None, + negatives_start_e_id: int | None = None, + negatives_per_edge: int | None = None, + val_start: int | None = None, + test_start: int | None = None, + ) -> None: ... @property def path(self) -> str: ... - @path.setter def path(self, arg: str, /) -> None: ... - @property def edge_capacity(self) -> int: ... - @edge_capacity.setter def edge_capacity(self, arg: int, /) -> None: ... - @property def msg_dim(self) -> int: ... - @msg_dim.setter def msg_dim(self, arg: int, /) -> None: ... - @property def label_dim(self) -> int: ... - @label_dim.setter def label_dim(self, arg: int, /) -> None: ... - @property def node_feat_capacity(self) -> int: ... - @node_feat_capacity.setter def node_feat_capacity(self, arg: int, /) -> None: ... - @property def node_feat_dim(self) -> int: ... - @node_feat_dim.setter def node_feat_dim(self, arg: int, /) -> None: ... - @property def label_capacity(self) -> int: ... - @label_capacity.setter def label_capacity(self, arg: int, /) -> None: ... - @property def negatives_start_e_id(self) -> int: ... - @negatives_start_e_id.setter def negatives_start_e_id(self, arg: int, /) -> None: ... - @property def negatives_per_edge(self) -> int: ... - @negatives_per_edge.setter def negatives_per_edge(self, arg: int, /) -> None: ... - @property def val_start(self) -> int | None: ... - @val_start.setter def val_start(self, arg: int | None) -> None: ... - @property def test_start(self) -> int | None: ... - @test_start.setter def test_start(self, arg: int | None) -> None: ... @@ -157,8 +147,14 @@ class Batch: - :class:`TGUFBuilder` """ - def __init__(self, src: NDArray, dst: NDArray, time: NDArray, msg: NDArray, neg_dst: NDArray | None = None) -> None: ... - + def __init__( + self, + src: NDArray, + dst: NDArray, + time: NDArray, + msg: NDArray, + neg_dst: NDArray | None = None, + ) -> None: ... @property def src(self) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Source node IDs""" @@ -194,7 +190,6 @@ class LabelEvent: """ def __init__(self, n_id: NDArray, target: NDArray) -> None: ... - @property def n_id(self) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Node IDs associated with this label event.""" @@ -220,13 +215,10 @@ class IndexRange: """A contiguous slice of the graph data.""" def __init__(self, arg0: int, arg1: int, /) -> None: ... - @property def start(self) -> int: ... - @property def end(self) -> int: ... - @property def size(self) -> int: ... @@ -239,59 +231,65 @@ class TGStore: """ @staticmethod - def from_memory(edges: Batch, node_feats: NDArray | None = None, label_n_id: NDArray | None = None, label_time: NDArray | None = None, label_target: NDArray | None = None, val_start: int | None = None, test_start: int | None = None) -> TGStore: + def from_memory( + edges: Batch, + node_feats: NDArray | None = None, + label_n_id: NDArray | None = None, + label_time: NDArray | None = None, + label_target: NDArray | None = None, + val_start: int | None = None, + test_start: int | None = None, + ) -> TGStore: """Create a high-speed, purely RAM-based store.""" @staticmethod - def from_tguf(path: str, val_start: int | None = None, test_start: int | None = None) -> TGStore: + def from_tguf( + path: str, val_start: int | None = None, test_start: int | None = None + ) -> TGStore: """Create a memory-mapped store from a TGUF file.""" @property def edge_count(self) -> int: ... - @property def node_count(self) -> int: ... - @property def label_count(self) -> int: ... - @property def msg_dim(self) -> int: ... - @property def label_dim(self) -> int: ... - @property def node_feat_dim(self) -> int: ... - @property def train_split(self) -> IndexRange: ... - @property def val_split(self) -> IndexRange: ... - @property def test_split(self) -> IndexRange: ... - @property def train_label_split(self) -> IndexRange: ... - @property def val_label_split(self) -> IndexRange: ... - @property def test_label_split(self) -> IndexRange: ... - - def get_batch(self, start: int, size: int, strategy: NegStrategy = NegStrategy.None_) -> Batch: + def get_batch( + self, start: int, size: int, strategy: NegStrategy = NegStrategy.None_ + ) -> Batch: """Retrieve a zero-copy slice of the graph interaction data.""" - def gather_timestamps(self, e_id: NDArray) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: + def gather_timestamps( + self, e_id: NDArray + ) -> Annotated[NDArray[numpy.int64], dict(shape=(1))]: """Vectorized gather of edge timestamps.""" - def gather_msgs(self, e_id: NDArray) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: + def gather_msgs( + self, e_id: NDArray + ) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: """Vectorized gather of edge features (messages).""" - def gather_node_feats(self, n_id: NDArray) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: + def gather_node_feats( + self, n_id: NDArray + ) -> Annotated[NDArray[numpy.float32], dict(shape=(2))]: """Vectorized gather of static node features.""" def get_edge_cutoff_for_label_event(self, l_id: int) -> int: @@ -318,7 +316,6 @@ class TGUFBuilder: """ def __init__(self, schema: TGUFSchema) -> None: ... - def append_edges(self, batch: Batch) -> None: """ Append a batch of temporal edges to the dataset. diff --git a/src/tguf/store.cpp b/src/tguf/store.cpp index 65259ac..0428db4 100644 --- a/src/tguf/store.cpp +++ b/src/tguf/store.cpp @@ -353,8 +353,10 @@ class TGStoreImpl final : public TGStore { std::to_string(s) + " but negative storage starts at " + std::to_string(negatives_start_e_id_)); } - batch_neg = neg_dst_->slice(0, s - negatives_start_e_id_, - e - negatives_start_e_id_).to(device, true); + batch_neg = + neg_dst_ + ->slice(0, s - negatives_start_e_id_, e - negatives_start_e_id_) + .to(device, true); } else { TGUF_LOG_DEBUG("TGStore: get_batch [{}:{}] (NegStrategy::None)", start, end); @@ -369,14 +371,18 @@ class TGStoreImpl final : public TGStore { [[nodiscard]] auto gather_timestamps(const torch::Tensor& e_id) const -> torch::Tensor override { - const auto e_id_cpu = e_id.device().is_cpu() ? e_id.flatten() : e_id.to(torch::kCPU).flatten(); + const auto e_id_cpu = e_id.device().is_cpu() + ? e_id.flatten() + : e_id.to(torch::kCPU).flatten(); auto out = torch::empty({e_id_cpu.size(0)}, t_.options()); return at::index_select_out(out, t_, 0, e_id_cpu).to(e_id.device(), true); } [[nodiscard]] auto gather_msgs(const torch::Tensor& e_id) const -> torch::Tensor override { - const auto e_id_cpu = e_id.device().is_cpu() ? e_id.flatten() : e_id.to(torch::kCPU).flatten(); + const auto e_id_cpu = e_id.device().is_cpu() + ? e_id.flatten() + : e_id.to(torch::kCPU).flatten(); auto out = torch::empty({e_id_cpu.size(0), msg_.size(1)}, msg_.options()); return at::index_select_out(out, msg_, 0, e_id_cpu).to(e_id.device(), true); } @@ -390,8 +396,10 @@ class TGStoreImpl final : public TGStore { // Every ID outside [0, num_nodes-1] hits the padded row (all zeros) const auto n_id_cpu = n_id.device().is_cpu() ? n_id : n_id.to(torch::kCPU); const auto safe_ids = n_id.clamp(0, node_feats_->size(0) - 1).flatten(); - auto out = torch::empty({safe_ids.size(0), node_feats_->size(1)}, node_feats_->options()); - return at::index_select_out(out, node_feats_.value(), 0, safe_ids).to(n_id.device(), true); + auto out = torch::empty({safe_ids.size(0), node_feats_->size(1)}, + node_feats_->options()); + return at::index_select_out(out, node_feats_.value(), 0, safe_ids) + .to(n_id.device(), true); } [[nodiscard]] auto get_edge_cutoff_for_label_event(std::size_t l_id) const @@ -402,15 +410,16 @@ class TGStoreImpl final : public TGStore { return stop_e_ids_.at(l_id); } - [[nodiscard]] auto get_label_event(std::size_t l_id, torch::Device device) const + [[nodiscard]] auto get_label_event(std::size_t l_id, + torch::Device device) const -> LabelEvent override { TORCH_CHECK(l_id < label_events_.size(), "TGStore: Requested LabelEvent at index ", l_id, " but only ", label_events_.size(), " events exist."); const auto le = label_events_.at(l_id); - return LabelEvent { - .n_id = le.n_id.to(device, true), - .target= le.target.to(device, true), + return LabelEvent{ + .n_id = le.n_id.to(device, true), + .target = le.target.to(device, true), }; } diff --git a/src/tguf_models/common/sampler.cpp b/src/tguf_models/common/sampler.cpp index be1c69b..6eff482 100644 --- a/src/tguf_models/common/sampler.cpp +++ b/src/tguf_models/common/sampler.cpp @@ -10,7 +10,8 @@ namespace tgn { LastNeighborLoader::LastNeighborLoader(std::size_t num_nbrs, - std::size_t num_nodes, torch::Device device) + std::size_t num_nodes, + torch::Device device) : buffer_size_(static_cast(num_nbrs)), buffer_nbrs_(torch::empty({static_cast(num_nodes), static_cast(num_nbrs)}, diff --git a/src/tguf_models/common/sampler.h b/src/tguf_models/common/sampler.h index a9891af..9bbc8cd 100644 --- a/src/tguf_models/common/sampler.h +++ b/src/tguf_models/common/sampler.h @@ -22,7 +22,8 @@ class LastNeighborLoader { * @param num_nodes The total capacity of the node index space ($N$). * @param device The torch::Device to run on. */ - LastNeighborLoader(std::size_t num_nbrs, std::size_t num_nodes, torch::Device device = torch::kCPU); + LastNeighborLoader(std::size_t num_nbrs, std::size_t num_nodes, + torch::Device device = torch::kCPU); /** * @brief Samples the temporal neighborhood and performs local relabeling. diff --git a/src/tguf_models/common/scatter_ops.cpp b/src/tguf_models/common/scatter_ops.cpp index ad073be..be0d399 100644 --- a/src/tguf_models/common/scatter_ops.cpp +++ b/src/tguf_models/common/scatter_ops.cpp @@ -50,7 +50,8 @@ auto scatter_softmax(const torch::Tensor& src, const torch::Tensor& index, auto scatter_argmax(const torch::Tensor& src, const torch::Tensor& index, std::int64_t dim_size) -> torch::Tensor { auto res = scatter_max(src, index, dim_size); - auto out = torch::full({dim_size}, /*fill_value*/ dim_size - 1, src.options().dtype(torch::kLong)); + auto out = torch::full({dim_size}, /*fill_value*/ dim_size - 1, + src.options().dtype(torch::kLong)); // Find where edge values match the winning max for each node const auto mask = src == res.index_select(0, index); diff --git a/src/tguf_models/tgn.cpp b/src/tguf_models/tgn.cpp index b986178..adad28e 100644 --- a/src/tguf_models/tgn.cpp +++ b/src/tguf_models/tgn.cpp @@ -20,7 +20,8 @@ namespace tgn { namespace detail { struct TimeEncoderImpl : torch::nn::Module { - explicit TimeEncoderImpl(std::size_t out_channels, torch::Device device = torch::kCPU) { + explicit TimeEncoderImpl(std::size_t out_channels, + torch::Device device = torch::kCPU) { lin_ = register_module("lin_", torch::nn::Linear(1, out_channels)); this->to(device); TGUF_LOG_INFO("TimeEncoder: Initialized on {} (time_embedding_dim={})", @@ -103,11 +104,16 @@ struct TGNMemoryImpl : torch::nn::Module { struct MsgStore { torch::Tensor src_, dst_, time_, msg_; - MsgStore(std::int64_t num_nodes, std::int64_t msg_dim, torch::Device device = torch::kCPU) { - src_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); - dst_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); - time_ = torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); - msg_ = torch::zeros({num_nodes, msg_dim}, torch::device(device).dtype(torch::kFloat)); + MsgStore(std::int64_t num_nodes, std::int64_t msg_dim, + torch::Device device = torch::kCPU) { + src_ = + torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + dst_ = + torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + time_ = + torch::zeros({num_nodes}, torch::device(device).dtype(torch::kLong)); + msg_ = torch::zeros({num_nodes, msg_dim}, + torch::device(device).dtype(torch::kFloat)); } auto reset() -> void { @@ -138,10 +144,11 @@ struct TGNMemoryImpl : torch::nn::Module { std::int64_t msg_dim, std::int64_t num_nodes) : msg_dim_(msg_dim), num_nodes_(num_nodes), - memory_(torch::empty( - {num_nodes, static_cast(cfg.memory_dim)}, torch::device(cfg.device))), - last_update_(torch::empty({num_nodes}, - torch::device(cfg.device).dtype(torch::kLong))), + memory_( + torch::empty({num_nodes, static_cast(cfg.memory_dim)}, + torch::device(cfg.device))), + last_update_(torch::empty( + {num_nodes}, torch::device(cfg.device).dtype(torch::kLong))), assoc_(torch::empty({num_nodes}, torch::device(cfg.device).dtype(torch::kLong))), time_encoder_(time_encoder), @@ -169,10 +176,11 @@ struct TGNMemoryImpl : torch::nn::Module { assoc_.nbytes() + get_store_bytes(src_store_) + get_store_bytes(dst_store_); TGUF_LOG_INFO( - "TGNMemory: ~{:.2f} MiB allocated on {} ({} nodes, memory_dim: {}, msg_dim: " + "TGNMemory: ~{:.2f} MiB allocated on {} ({} nodes, memory_dim: {}, " + "msg_dim: " "{}, gru_cell_dim: {})", - bytes / (1024.0 * 1024.0), cfg.device.str(), num_nodes_, cfg.memory_dim, msg_dim_, - cell_dim); + bytes / (1024.0 * 1024.0), cfg.device.str(), num_nodes_, cfg.memory_dim, + msg_dim_, cell_dim); } auto reset_state() -> void { @@ -214,7 +222,8 @@ struct TGNMemoryImpl : torch::nn::Module { TGUF_LOG_DEBUG( "TGNMemory: Switching to Eval. Flushing memory for all {} nodes", num_nodes_); - update_memory(torch::arange(static_cast(num_nodes_), memory_.options().dtype(torch::kLong))); + update_memory(torch::arange(static_cast(num_nodes_), + memory_.options().dtype(torch::kLong))); src_store_.reset(); dst_store_.reset(); } @@ -292,9 +301,10 @@ struct TGNImpl::Impl { assoc_(torch::full({static_cast(store->node_count())}, -1, torch::device(cfg.device).dtype(torch::kLong))) { time_encoder_ = detail::TimeEncoder(cfg.time_dim, cfg.device); - conv_ = detail::TransformerConv( - cfg.memory_dim + store_->node_feat_dim(), cfg.embedding_dim / 2, - store->msg_dim() + cfg.time_dim, cfg.num_heads, cfg.dropout, cfg.device); + conv_ = detail::TransformerConv(cfg.memory_dim + store_->node_feat_dim(), + cfg.embedding_dim / 2, + store->msg_dim() + cfg.time_dim, + cfg.num_heads, cfg.dropout, cfg.device); memory_ = detail::TGNMemory(cfg, time_encoder_, store->msg_dim(), store->node_count()); } @@ -328,7 +338,9 @@ auto TGNImpl::reset_state() -> void { impl_->nbr_loader_.reset_state(); } -auto TGNImpl::device() const -> torch::Device {return this->parameters()[0].device();} +auto TGNImpl::device() const -> torch::Device { + return this->parameters()[0].device(); +} auto TGNImpl::update_state(const torch::Tensor& src, const torch::Tensor& dst, const torch::Tensor& time, const torch::Tensor& msg) @@ -351,16 +363,21 @@ auto TGNImpl::forward_internal(const std::vector& input_list) {n_id}, torch::arange(n_id.size(0), impl_->assoc_.options())); // Transformer conv with relative time encoding - const auto t_edges = impl_->store_->gather_timestamps(e_id).to(impl_->cfg_.device, true); + const auto t_edges = + impl_->store_->gather_timestamps(e_id).to(impl_->cfg_.device, true); const auto rel_t = last_update.index_select(0, edge_index[0]) - t_edges; const auto rel_t_z = impl_->time_encoder_->forward(rel_t.to(torch::kFloat32)); const auto edge_feat = impl_->store_->msg_dim() > 0 - ? torch::cat({rel_t_z, impl_->store_->gather_msgs(e_id).to(impl_->cfg_.device, true)}, -1) + ? torch::cat({rel_t_z, impl_->store_->gather_msgs(e_id).to( + impl_->cfg_.device, true)}, + -1) : rel_t_z; const auto node_feat = impl_->store_->node_feat_dim() > 0 - ? torch::cat({memory, impl_->store_->gather_node_feats(n_id).to(impl_->cfg_.device, true)}, -1) + ? torch::cat({memory, impl_->store_->gather_node_feats(n_id).to( + impl_->cfg_.device, true)}, + -1) : memory; const auto z = impl_->conv_->forward(node_feat, edge_index, edge_feat); From 5657efae111a3f819b6363c3e01182f8a1811283 Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Mon, 23 Mar 2026 15:52:09 -0400 Subject: [PATCH 4/5] WIP --- examples/link_pred.cpp | 5 ++--- examples/node_pred.cpp | 14 +++++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/link_pred.cpp b/examples/link_pred.cpp index 6807e7f..baa87c7 100644 --- a/examples/link_pred.cpp +++ b/examples/link_pred.cpp @@ -111,12 +111,11 @@ auto eval(tgn::TGN& encoder, LinkPredictor& decoder, std::vector perf_list; const auto e_range = store->val_split(); + const auto device = encoder->device(); for (auto e_id = e_range.start(); e_id < e_range.end(); e_id += args.batch_size) { - const auto batch = store->get_batch(e_id, args.batch_size, - tguf::TGStore::NegStrategy::PreComputed, - encoder->device()); + const auto batch = store->get_batch(e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, device; const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); diff --git a/examples/node_pred.cpp b/examples/node_pred.cpp index 2f0f319..c5def09 100644 --- a/examples/node_pred.cpp +++ b/examples/node_pred.cpp @@ -71,6 +71,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, float total_loss{0}; + const auto device = encoder->device(); const auto e_range = store->train_split(); const auto l_range = store->train_label_split(); auto e_id = e_range.start(); @@ -82,8 +83,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; const auto batch = - store->get_batch(e_id, num_edges_to_process, - tguf::TGStore::NegStrategy::None, encoder->device()); + store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device; encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; @@ -91,7 +91,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, opt.zero_grad(); - const auto label_event = store->get_label_event(l_id++, encoder->device()); + const auto label_event = store->get_label_event(l_id++, device); const auto [z] = encoder->forward(label_event.n_id); const auto y_pred = decoder->forward(z); @@ -123,6 +123,7 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder, std::vector perf_list; + const auto device = encoder->device(); const auto e_range = store->val_split(); const auto l_range = store->val_label_split(); auto e_id = e_range.start(); @@ -132,15 +133,14 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = - store->get_batch(e_id, num_edges_to_process, - tguf::TGStore::NegStrategy::None, encoder->device()); + const auto batch = store->get_batch( + e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id; } - const auto label_event = store->get_label_event(l_id++, encoder->device()); + const auto label_event = store->get_label_event(l_id++, device); const auto [z] = encoder->forward(label_event.n_id); const auto y_pred = decoder->forward(z); perf_list.push_back(compute_ndcg(y_pred, label_event.target)); From 9707755836c6f4d4851d199c312c1c5d552e2cfb Mon Sep 17 00:00:00 2001 From: Jacob-Chmura Date: Mon, 23 Mar 2026 15:53:57 -0400 Subject: [PATCH 5/5] WIP --- examples/link_pred.cpp | 9 +++++---- examples/node_pred.cpp | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/link_pred.cpp b/examples/link_pred.cpp index baa87c7..fa45397 100644 --- a/examples/link_pred.cpp +++ b/examples/link_pred.cpp @@ -65,14 +65,14 @@ auto train(tgn::TGN& encoder, LinkPredictor& decoder, torch::optim::Adam& opt, float total_loss{0}; const auto e_range = store->train_split(); + const auto device = encoder->device(); for (auto e_id = e_range.start(); e_id < e_range.end(); e_id += args.batch_size) { opt.zero_grad(); - const auto batch = - store->get_batch(e_id, args.batch_size, - tguf::TGStore::NegStrategy::Random, encoder->device()); + const auto batch = store->get_batch( + e_id, args.batch_size, tguf::TGStore::NegStrategy::Random, device); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); @@ -115,7 +115,8 @@ auto eval(tgn::TGN& encoder, LinkPredictor& decoder, for (auto e_id = e_range.start(); e_id < e_range.end(); e_id += args.batch_size) { - const auto batch = store->get_batch(e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, device; + const auto batch = store->get_batch( + e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, device); const auto [z_src, z_dst, z_neg] = encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten()); diff --git a/examples/node_pred.cpp b/examples/node_pred.cpp index c5def09..02db32f 100644 --- a/examples/node_pred.cpp +++ b/examples/node_pred.cpp @@ -82,8 +82,8 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt, const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id); if (e_id < stop_e_id) { const auto num_edges_to_process = stop_e_id - e_id; - const auto batch = - store->get_batch(e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device; + const auto batch = store->get_batch( + e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device); encoder->update_state(batch.src, batch.dst, batch.time, batch.msg); e_id = stop_e_id;