Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ else() # Target Linux x86_64
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu${CUDA_TAG}/libtorch-shared-with-deps-2.10.0%2Bcu${CUDA_TAG}.zip")
message(STATUS "TGUF: Target System is LINUX (CUDA ${CUDA_VERSION}).")
enable_language(CUDA)
add_compile_definitions(TGN_WITH_CUDA)
endif()
endif()

message(STATUS "Downloading Libtorch (${LIBTORCH_URL})")
FetchContent_Declare(
libtorch
URL ${LIBTORCH_URL}
DOWNLOAD_EXTRACT_TIMESTAMP ON
)
FetchContent_MakeAvailable(libtorch)
message(STATUS "Downloading Libtorch (${LIBTORCH_URL}) - done")

set(CMAKE_PREFIX_PATH ${libtorch_SOURCE_DIR})
find_package(Torch REQUIRED)
Expand Down
13 changes: 8 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ GPU_ARCH ?= native
CMAKE_FLAGS := -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCUDA_VERSION=$(CUDA_VERSION)

ifneq ($(CUDA_VERSION), cpu)
CMAKE_FLAGS += -DCMAKE_CUDA_ARCHITECTURES=$(GPU_ARCH)

# Add Torch-specific Arch list (converts 80 -> 8.0)
ifneq ($(GPU_ARCH), native)
TORCH_ARCH := $(shell echo $(GPU_ARCH) | sed 's/\([0-9]\)\([0-9]\)/\1.\2/')
Expand All @@ -22,6 +20,11 @@ ifneq ($(CUDA_VERSION), cpu)
endif
endif

PYTHON_BUILD_FLAGS := -DTGN_BUILD_PYTHON=ON
ifneq ($(CUDA_VERSION), cpu)
PYTHON_BUILD_FLAGS += -DCMAKE_CUDA_ARCHITECTURES=$(GPU_ARCH)
endif

NPROCS := $(shell nproc 2>/dev/null || sysctl -n hw.logicalcpu)

EXAMPLE_LINK := $(BUILD_DIR)/examples/tgn_link_pred
Expand Down Expand Up @@ -147,11 +150,11 @@ data/%.tguf: python

.PHONY: run-link-%
run-link-%: examples data/%.tguf
@$(EXAMPLE_LINK) data/$*.tguf
@$(EXAMPLE_LINK) data/$*.tguf $(ARGS)

.PHONY: run-node-%
run-node-%: examples data/%.tguf
@$(EXAMPLE_NODE) data/$*.tguf
@$(EXAMPLE_NODE) data/$*.tguf $(ARGS)

.PHONY: profile-build
profile-build: $(PROFILE_DIR)/CMakeCache.txt
Expand All @@ -173,7 +176,7 @@ download-%: data/%.tguf
python:
@(cd python && \
uv sync --group dev --no-install-project && \
SKBUILD_CMAKE_ARGS="-DTGN_BUILD_PYTHON=ON" \
SKBUILD_CMAKE_ARGS="$(PYTHON_BUILD_FLAGS)" \
uv pip install -e . --no-build-isolation)

.PHONY: test-python
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ make run-node-tgbn-trade

```sh
# Example: Cuda 12.6 on an A100 (Arch 80)
CUDA_VERSION=12.6 GPU_ARCH=80 make run-link-tgbl-wiki
CUDA_VERSION=12.6 GPU_ARCH=80 make run-link-tgbl-wiki ARGS="--device cuda:0"
```

> \[!TIP\]
Expand Down
20 changes: 13 additions & 7 deletions examples/link_pred.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include <torch/torch.h>

Check warning on line 1 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:1:1 [misc-include-cleaner]

included header torch.h is not used directly

#include <chrono>
#include <cstddef>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>

Check warning on line 8 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:8:1 [misc-include-cleaner]

included header string is not used directly
#include <utility>

Check warning on line 9 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:9:1 [misc-include-cleaner]

included header utility is not used directly
#include <vector>

Check warning on line 10 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:10:1 [misc-include-cleaner]

included header vector is not used directly

#include "logging.h"
#include "tgn.h"
Expand All @@ -19,16 +19,19 @@
util::TGNArgs args{};
std::size_t current_epoch = 1;

struct LinkPredictorImpl : torch::nn::Module {

Check warning on line 22 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:22:39 [misc-include-cleaner]

no header providing "torch::nn::Module" is directly included
explicit LinkPredictorImpl(std::size_t in_dim) {
explicit LinkPredictorImpl(std::size_t in_dim,
torch::Device device = torch::kCPU) {

Check warning on line 24 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:24:60 [misc-include-cleaner]

no header providing "c10::kCPU" is directly included

Check warning on line 24 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:24:37 [misc-include-cleaner]

no header providing "c10::Device" is directly included
w_src_ = register_module("w_src_", torch::nn::Linear(in_dim, in_dim));
w_dst_ = register_module("w_dst_", torch::nn::Linear(in_dim, in_dim));
w_final_ = register_module("w_final_", torch::nn::Linear(in_dim, 1));
TGUF_LOG_INFO("LinkDecoder: Initialized (in_channels={})", in_dim);
this->to(device);
TGUF_LOG_INFO("LinkDecoder: Initialized on {} (in_channels={})",
device.str(), in_dim);
}

auto forward(const torch::Tensor& z_src, const torch::Tensor& z_dst)

Check warning on line 33 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:33:16 [bugprone-easily-swappable-parameters]

2 adjacent parameters of 'forward' of similar type ('const torch::Tensor &') are easily swapped by mistake

Check warning on line 33 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:33:8 [readability-convert-member-functions-to-static]

method 'forward' can be made static
-> torch::Tensor {

Check warning on line 34 in examples/link_pred.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

examples/link_pred.cpp:34:17 [misc-include-cleaner]

no header providing "torch::Tensor" is directly included
const auto z = torch::relu(w_src_->forward(z_src) + w_dst_->forward(z_dst));
return w_final_->forward(z).view(-1);
}
Expand Down Expand Up @@ -62,13 +65,14 @@

float total_loss{0};
const auto e_range = store->train_split();
const auto device = encoder->device();

for (auto e_id = e_range.start(); e_id < e_range.end();
e_id += args.batch_size) {
opt.zero_grad();

const auto batch = store->get_batch(e_id, args.batch_size,
tguf::TGStore::NegStrategy::Random);
const auto batch = store->get_batch(
e_id, args.batch_size, tguf::TGStore::NegStrategy::Random, device);
const auto [z_src, z_dst, z_neg] =
encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten());

Expand Down Expand Up @@ -107,11 +111,12 @@

std::vector<float> perf_list;
const auto e_range = store->val_split();
const auto device = encoder->device();

for (auto e_id = e_range.start(); e_id < e_range.end();
e_id += args.batch_size) {
const auto batch = store->get_batch(
e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed);
e_id, args.batch_size, tguf::TGStore::NegStrategy::PreComputed, device);
const auto [z_src, z_dst, z_neg] =
encoder->forward(batch.src, batch.dst, batch.neg_dst->flatten());

Expand Down Expand Up @@ -149,14 +154,15 @@

const std::shared_ptr<tguf::TGStore> store =
tguf::TGStore::from_tguf(args.tguf_path);
const auto cfg = tgn::TGNConfig{.embedding_dim = args.embedding_dim,
const auto cfg = tgn::TGNConfig{.device = args.device,
.embedding_dim = args.embedding_dim,
.memory_dim = args.memory_dim,
.time_dim = args.time_dim,
.num_heads = args.num_heads,
.num_nbrs = args.num_nbrs,
.dropout = args.dropout};
tgn::TGN encoder(cfg, store);
LinkPredictor decoder{cfg.embedding_dim};
LinkPredictor decoder{cfg.embedding_dim, cfg.device};

auto params = encoder->parameters();
auto dec_params = decoder->parameters();
Expand Down
28 changes: 18 additions & 10 deletions examples/node_pred.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@ std::size_t current_epoch = 1;

struct NodePredictorImpl : torch::nn::Module {
explicit NodePredictorImpl(std::size_t in_dim, std::size_t out_dim,
torch::Device device = torch::kCPU,
std::size_t hidden_dim = 64) {
model_ = torch::nn::Sequential(torch::nn::Linear(in_dim, hidden_dim),
torch::nn::ReLU(),
torch::nn::Linear(hidden_dim, out_dim));
register_module("model_", model_);
this->to(device);
TGUF_LOG_INFO(
"NodeDecoder: Initialized (in_channels={}, hidden_dim={}, "
"NodeDecoder: Initialized on {} (in_channels={}, hidden_dim={}, "
"out_channels={})",
in_dim, hidden_dim, out_dim);
device.str(), in_dim, hidden_dim, out_dim);
}

auto forward(const torch::Tensor& z_node) -> torch::Tensor {
Expand All @@ -45,7 +47,8 @@ TORCH_MODULE(NodePredictor);
auto compute_ndcg(const torch::Tensor& y_pred, const torch::Tensor& y_true,
std::int64_t k = 10) -> float {
k = std::min(k, y_pred.size(-1));
const auto ranks = torch::arange(1, k + 1).to(torch::kFloat32);
const auto ranks =
torch::arange(1, k + 1, y_pred.options().dtype(torch::kFloat32));
const auto discounts = torch::log2(ranks + 1.0);

const auto [pred_labels, pred_indices] = y_pred.topk(k, -1);
Expand All @@ -68,6 +71,7 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt,

float total_loss{0};

const auto device = encoder->device();
const auto e_range = store->train_split();
const auto l_range = store->train_label_split();
auto e_id = e_range.start();
Expand All @@ -78,15 +82,16 @@ auto train(tgn::TGN& encoder, NodePredictor& decoder, torch::optim::Adam& opt,
const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id);
if (e_id < stop_e_id) {
const auto num_edges_to_process = stop_e_id - e_id;
const auto batch = store->get_batch(e_id, num_edges_to_process);
const auto batch = store->get_batch(
e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device);

encoder->update_state(batch.src, batch.dst, batch.time, batch.msg);
e_id = stop_e_id;
}

opt.zero_grad();

const auto label_event = store->get_label_event(l_id++);
const auto label_event = store->get_label_event(l_id++, device);
const auto [z] = encoder->forward(label_event.n_id);
const auto y_pred = decoder->forward(z);

Expand Down Expand Up @@ -118,6 +123,7 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder,

std::vector<float> perf_list;

const auto device = encoder->device();
const auto e_range = store->val_split();
const auto l_range = store->val_label_split();
auto e_id = e_range.start();
Expand All @@ -127,13 +133,14 @@ auto eval(tgn::TGN& encoder, NodePredictor& decoder,
const auto stop_e_id = store->get_edge_cutoff_for_label_event(l_id);
if (e_id < stop_e_id) {
const auto num_edges_to_process = stop_e_id - e_id;
const auto batch = store->get_batch(e_id, num_edges_to_process);
const auto batch = store->get_batch(
e_id, num_edges_to_process, tguf::TGStore::NegStrategy::None, device);

encoder->update_state(batch.src, batch.dst, batch.time, batch.msg);
e_id = stop_e_id;
}

const auto label_event = store->get_label_event(l_id++);
const auto label_event = store->get_label_event(l_id++, device);
const auto [z] = encoder->forward(label_event.n_id);
const auto y_pred = decoder->forward(z);
perf_list.push_back(compute_ndcg(y_pred, label_event.target));
Expand All @@ -157,15 +164,16 @@ auto main(int argc, char** argv) -> int {

const std::shared_ptr<tguf::TGStore> store =
tguf::TGStore::from_tguf(args.tguf_path);
const auto cfg = tgn::TGNConfig{.embedding_dim = args.embedding_dim,
const auto cfg = tgn::TGNConfig{.device = args.device,
.embedding_dim = args.embedding_dim,
.memory_dim = args.memory_dim,
.time_dim = args.time_dim,
.num_heads = args.num_heads,
.num_nbrs = args.num_nbrs,
.dropout = args.dropout};
tgn::TGN encoder(cfg, store);
NodePredictor decoder{cfg.embedding_dim,
store->label_dim() /* num_classes */};
NodePredictor decoder{cfg.embedding_dim, store->label_dim() /* num_classes */,
cfg.device};

auto params = encoder->parameters();
auto dec_params = decoder->parameters();
Expand Down
34 changes: 32 additions & 2 deletions examples/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@
#include <iostream>
#include <string>

#ifdef TGN_WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif

namespace util {

struct TGNArgs {
std::string tguf_path;

torch::Device device = torch::kCPU;

std::size_t epochs = 10;
std::size_t batch_size = 200;
double lr = 1e-4;
Expand All @@ -29,6 +35,7 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs {
auto print_usage = [argv]() {
std::cerr << "Usage: " << argv[0] << " <path_to_tguf> [options]\n"
<< "Options:\n"
<< " --device <device> (default: cpu)\n"
<< " --epochs <N> (default: 10)\n"
<< " --batch-size <N> (default: 200)\n"
<< " --lr <val> (default: 1e-4)\n"
Expand Down Expand Up @@ -73,7 +80,9 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs {
std::string_view arg{argv[i]};
std::string_view val{argv[i + 1]};

if (arg == "--epochs") {
if (arg == "--device") {
args.device = torch::Device(std::string(val));
} else if (arg == "--epochs") {
args.epochs = to_type.template operator()<std::size_t>(val);
} else if (arg == "--batch-size") {
args.batch_size = to_type.template operator()<std::size_t>(val);
Expand All @@ -98,6 +107,7 @@ inline auto parse_args(int argc, char** argv) -> TGNArgs {
}

TGUF_LOG_INFO(" TGUF Path: {}", args.tguf_path);
TGUF_LOG_INFO(" Device: {}", args.device.str());
TGUF_LOG_INFO(" Epochs: {}", args.epochs);
TGUF_LOG_INFO(" Batch Size: {}", args.batch_size);
TGUF_LOG_INFO(" Learning Rate:{:.2e}", args.lr);
Expand Down Expand Up @@ -182,6 +192,26 @@ inline auto log_torch_backend_info() -> void {
#else
TGUF_LOG_WARN("LibTorch Backend | BLAS: Intel MKL not found");
#endif
};

#ifdef TGN_WITH_CUDA
if (torch::cuda::is_available()) {
const auto device_count = torch::cuda::device_count();
TGUF_LOG_INFO("LibTorch Backend | CUDA: Enabled ({} device(s) found)",
device_count);

for (auto i = 0; i < device_count; ++i) {
const auto prop = at::cuda::getDeviceProperties(i);
TGUF_LOG_INFO(
"LibTorch Backend | Device {}: {} | Compute Capability: {}.{}", i,
prop->name, prop->major, prop->minor);
}
} else {
TGUF_LOG_WARN(
"LibTorch Backend | CUDA backend linked but no GPU devices found");
}
#else
TGUF_LOG_INFO("LibTorch Backend | CUDA: Not linked with CUDA backend");
#endif
}; // namespace util

} // namespace util
10 changes: 7 additions & 3 deletions include/tgn.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ namespace tgn {
* @brief Configuration parameters for the TGN model architecture.
*/
struct TGNConfig {
std::size_t embedding_dim = 100; ///< TransformerConv embedding size.
std::size_t memory_dim = 100; ///< TGNMemory embedding size.
std::size_t time_dim = 100; ///< TimeEncoder embedding size.
torch::Device device = torch::kCPU; ///< Default to CPU.
std::size_t embedding_dim = 100; ///< TransformerConv embedding size.
std::size_t memory_dim = 100; ///< TGNMemory embedding size.
std::size_t time_dim = 100; ///< TimeEncoder embedding size.
std::size_t num_heads = 2; ///< TransformerConv multi-head attention heads.
std::size_t num_nbrs = 10; ///< RecencySampler neighbor buffer size.
float dropout = 0.1; ///< TransformerConv dropout.
Expand All @@ -49,6 +50,9 @@ class TGNImpl : public torch::nn::Module {
const torch::Tensor& time, const torch::Tensor& msg)
-> void;

/** @brief Get the torch::Device used by the module */
auto device() const -> torch::Device;

/**
* @brief Variadic forward pass.
* @param inputs Tensors of node IDs to compute embeddings for.
Expand Down
12 changes: 8 additions & 4 deletions include/tguf.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,12 @@ class TGStore {
* @param start The starting edge ID.
* @param size The number of edges to include.
* @param strategy The negative sampling strategy to apply.
* @param device The torch::Device to materialize the batch on.
*/
[[nodiscard]] virtual auto get_batch(
std::size_t start, std::size_t size,
NegStrategy strategy = NegStrategy::None) const -> Batch = 0;
[[nodiscard]] virtual auto get_batch(std::size_t start, std::size_t size,
NegStrategy strategy = NegStrategy::None,
torch::Device device = torch::kCPU) const
-> Batch = 0;

/** * @brief Performs a vectorized random-access gather of edge timestamps.
* @param e_id Tensor of edge indices [num_edges].
Expand Down Expand Up @@ -205,9 +207,11 @@ class TGStore {

/** * @brief Retrieves the metadata and target for a specific label event.
* @param l_id The index of the label event.
* @param device The torch::Device to materialize data on.
* @return A @ref LabelEvent containing affected node IDs and target values.
*/
[[nodiscard]] virtual auto get_label_event(std::size_t l_id) const
[[nodiscard]] virtual auto get_label_event(
std::size_t l_id, torch::Device device = torch::kCPU) const
-> LabelEvent = 0;
};

Expand Down
1 change: 1 addition & 0 deletions python/bind.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <nanobind/nanobind.h>

Check failure on line 1 in python/bind.cpp

View workflow job for this annotation

GitHub Actions / cpp-lint

python/bind.cpp:1:10 [clang-diagnostic-error]

'nanobind/nanobind.h' file not found
#include <nanobind/ndarray.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
Expand Down Expand Up @@ -342,6 +342,7 @@

.def_prop_ro("edge_count", &tguf::TGStore::edge_count)
.def_prop_ro("node_count", &tguf::TGStore::node_count)
.def_prop_ro("label_count", &tguf::TGStore::label_count)
.def_prop_ro("msg_dim", &tguf::TGStore::msg_dim)
.def_prop_ro("label_dim", &tguf::TGStore::label_dim)
.def_prop_ro("node_feat_dim", &tguf::TGStore::node_feat_dim)
Expand Down
1 change: 1 addition & 0 deletions python/test/test_csv_tguf_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def test_csv_tguf_roundtrip(resource_dir, output_tguf):

assert store.edge_count == 3
assert store.node_count == 31
assert store.label_count == 2
assert store.msg_dim == 2
assert store.label_dim == 2
assert store.node_feat_dim == 3
Expand Down
2 changes: 2 additions & 0 deletions python/test/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_tgstore_from_memory(get_data):

assert store.edge_count == num_edges
assert store.node_count == num_edges + 1
assert store.label_count == 0
assert store.msg_dim == 8
assert store.label_dim == 0
assert store.train_split.end == 15
Expand Down Expand Up @@ -46,6 +47,7 @@ def test_tgstore_from_tguf(schema, get_data):

assert store.edge_count == num_edges
assert store.node_count == num_edges + 1
assert store.label_count == 0
assert store.msg_dim == 8
assert store.label_dim == 0
assert store.train_split.end == 15
Expand Down
Loading
Loading