From c8e30098fe7c51fd9c52faa36dc18cbcae437c73 Mon Sep 17 00:00:00 2001 From: Yohann Dudouit Date: Thu, 23 Apr 2026 18:12:35 -0700 Subject: [PATCH 1/2] Support graph surrogate model execution. - Add tests for homogeneous and heterogenous graph surrogate models. --- src/AMSlib/ml/surrogate.hpp | 18 ++ src/AMSlib/wf/interface.cpp | 120 ++++++--- src/AMSlib/wf/workflow.hpp | 7 + tests/AMSlib/ams_interface/CMakeLists.txt | 3 + .../ams_interface/test_graph_surrogate.cpp | 228 ++++++++++++++++++ tests/AMSlib/models/CMakeLists.txt | 5 + tests/AMSlib/models/ams_model.py | 172 ++++++++++++- tests/AMSlib/models/generate.sh | 2 + tests/AMSlib/models/generate_graph_models.py | 143 +++++++++++ 9 files changed, 657 insertions(+), 41 deletions(-) create mode 100644 tests/AMSlib/ams_interface/test_graph_surrogate.cpp create mode 100644 tests/AMSlib/models/generate_graph_models.py diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp index 05916510..c9cf07a1 100644 --- a/src/AMSlib/ml/surrogate.hpp +++ b/src/AMSlib/ml/surrogate.hpp @@ -25,12 +25,30 @@ #include "ArrayRef.hpp" #include "wf/debug.h" +// Forward declarations for graph surrogate friend functions +namespace ams { +class AMSWorkflow; +bool tryGraphSurrogate(AMSWorkflow*, + const AMSHomogeneousGraph&, + SmallVector&); +bool tryGraphSurrogate(AMSWorkflow*, + const AMSHeterogeneousGraph&, + SmallVector&); +} //! ---------------------------------------------------------------------------- //! An implementation for a surrogate model //! ---------------------------------------------------------------------------- class SurrogateModel { + // Friend declarations for graph surrogate execution + // Note: These are defined in interface.cpp within the ams namespace + friend bool ams::tryGraphSurrogate(ams::AMSWorkflow*, + const ams::AMSHomogeneousGraph&, + ams::SmallVector&); + friend bool ams::tryGraphSurrogate(ams::AMSWorkflow*, + const ams::AMSHeterogeneousGraph&, + ams::SmallVector&); private: const std::string _model_path; diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp index 0f763423..b6f60db9 100644 --- a/src/AMSlib/wf/interface.cpp +++ b/src/AMSlib/wf/interface.cpp @@ -292,28 +292,30 @@ static ams::AMSHeterogeneousGraph torchToAMSHeterogeneousGraph( return out; } -static c10::impl::GenericDict amsNodeStoresToTorchDict( +static c10::Dict> +amsNodeStoresToTorchDict( const std::unordered_map& node_stores) { - c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get()); + c10::Dict> out; for (const auto& [store_name, store] : node_stores) { - out.insert(store_name, c10::IValue(amsTensorMapToTorchDict(store))); + out.insert(store_name, amsTensorMapToTorchDict(store)); } return out; } -static c10::impl::GenericDict amsEdgeStoresToTorchDict( +static c10::Dict> +amsEdgeStoresToTorchDict( const std::unordered_map& edge_stores) { - c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get()); + c10::Dict> out; for (const auto& [edge_type, store] : edge_stores) { out.insert(ams::edgeTypeToString(edge_type), - c10::IValue(amsTensorMapToTorchDict(store))); + amsTensorMapToTorchDict(store)); } return out; @@ -324,12 +326,9 @@ static c10::impl::GenericDict amsToTorchHeterogeneousGraph( { c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get()); - out.insert("node_stores", - c10::IValue(amsNodeStoresToTorchDict(g.node_stores))); - out.insert("edge_stores", - c10::IValue(amsEdgeStoresToTorchDict(g.edge_stores))); - out.insert("global_store", - c10::IValue(amsTensorMapToTorchDict(g.global_store))); + out.insert("node_stores", amsNodeStoresToTorchDict(g.node_stores)); + out.insert("edge_stores", amsEdgeStoresToTorchDict(g.edge_stores)); + out.insert("global_store", amsTensorMapToTorchDict(g.global_store)); return out; } @@ -380,42 +379,83 @@ void callApplication(ams::HeterogeneousGraphDomainFn CallBack, } // ============================================================================ -// Graph surrogate execution stub (seam for future implementation) +// Graph surrogate execution (in ams namespace for friend access) // ============================================================================ -static bool tryGraphSurrogate(ams::AMSWorkflow* executor, - const ams::AMSHomogeneousGraph& graph, - ams::SmallVector& outs) +namespace ams { + +bool tryGraphSurrogate(AMSWorkflow* executor, + const AMSHomogeneousGraph& graph, + SmallVector& outs) { - // TODO: Implement graph surrogate execution when models support graphs - // This is the integration point for future graph-based ML inference - // - // Future implementation should: - // 1. Check if executor has a model that accepts graph inputs - // 2. Convert AMSHomogeneousGraph to model input format - // 3. Run model inference and uncertainty quantification - // 4. If UQ passes threshold, populate outs and return true - // 5. Otherwise return false to trigger fallback - // - // For now, always return false (no surrogate available) - (void)executor; - (void)graph; - (void)outs; - return false; + // Check if model is available + if (!executor || !executor->MLModel) { + return false; + } + + try { + // Convert AMS graph → Torch Dict[str, Tensor] + auto torch_graph = amsToTorchHomogeneousGraph(graph); + + // Call model forward pass + std::vector inputs = {torch::jit::IValue(torch_graph)}; + auto result = executor->MLModel->module.forward(inputs); + + // Extract prediction from tuple [prediction, uncertainty] + auto result_tuple = result.toTuple(); + if (result_tuple->elements().size() < 1) { + return false; + } + torch::Tensor prediction = result_tuple->elements()[0].toTensor(); + + // Convert prediction → AMSTensor and populate outs + outs.clear(); + outs.push_back(torchToAMSTensorView(prediction)); + + return true; + } catch (const std::exception& e) { + // Model invocation failed - fallback to physics + return false; + } } -static bool tryGraphSurrogate(ams::AMSWorkflow* executor, - const ams::AMSHeterogeneousGraph& graph, - ams::SmallVector& outs) +bool tryGraphSurrogate(AMSWorkflow* executor, + const AMSHeterogeneousGraph& graph, + SmallVector& outs) { - // TODO: Implement graph surrogate execution when models support graphs - // See homogeneous version for implementation notes - (void)executor; - (void)graph; - (void)outs; - return false; + // Check if model is available + if (!executor || !executor->MLModel) { + return false; + } + + try { + // Convert AMS graph → Torch GenericDict + auto torch_graph = amsToTorchHeterogeneousGraph(graph); + + // Call model forward pass + std::vector inputs = {torch::jit::IValue(torch_graph)}; + auto result = executor->MLModel->module.forward(inputs); + + // Extract prediction from tuple [prediction, uncertainty] + auto result_tuple = result.toTuple(); + if (result_tuple->elements().size() < 1) { + return false; + } + torch::Tensor prediction = result_tuple->elements()[0].toTensor(); + + // Convert prediction → AMSTensor and populate outs + outs.clear(); + outs.push_back(torchToAMSTensorView(prediction)); + + return true; + } catch (const std::exception& e) { + // Model invocation failed - fallback to physics + return false; + } } +} // namespace ams + // ============================================================================ // Graph-based callAMS overloads // ============================================================================ diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 1c7e6328..6825b3ec 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -35,6 +35,13 @@ namespace ams { class AMSWorkflow { + // Friend declarations for graph surrogate execution access to private MLModel + friend bool tryGraphSurrogate(AMSWorkflow*, + const AMSHomogeneousGraph&, + SmallVector&); + friend bool tryGraphSurrogate(AMSWorkflow*, + const AMSHeterogeneousGraph&, + SmallVector&); /** @brief A string identifier describing the domain-model being solved. */ std::string domainName; diff --git a/tests/AMSlib/ams_interface/CMakeLists.txt b/tests/AMSlib/ams_interface/CMakeLists.txt index aba6adaf..17ee7034 100644 --- a/tests/AMSlib/ams_interface/CMakeLists.txt +++ b/tests/AMSlib/ams_interface/CMakeLists.txt @@ -30,3 +30,6 @@ ADD_AMS_UNIT_TEST(AMS_INT_INTERFACE int_interface) BUILD_UNIT_TEST(ams_graph_fallback test_graph_fallback.cpp) ADD_AMS_UNIT_TEST(AMS_GRAPH_FALLBACK ams_graph_fallback) +BUILD_UNIT_TEST(ams_graph_surrogate test_graph_surrogate.cpp) +ADD_AMS_UNIT_TEST(AMS_GRAPH_SURROGATE ams_graph_surrogate) + diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp new file mode 100644 index 00000000..b23b55d4 --- /dev/null +++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp @@ -0,0 +1,228 @@ +#include + +#include "AMS.h" +#include "AMSGraph.hpp" +#include "AMSTensor.hpp" + +using namespace ams; + +// Paths to graph test models (relative to test executable) +static const char* HOMOGENEOUS_GRAPH_MODEL_PATH = "../models/homogeneous_graph.pt"; +static const char* HETEROGENEOUS_GRAPH_MODEL_PATH = "../models/heterogeneous_graph.pt"; + +CATCH_TEST_CASE("AMSExecute homogeneous graph surrogate execution", + "[wf][graph][surrogate]") +{ + AMSInit(); + + // Setup: Register model with actual generated model path + auto model = AMSRegisterAbstractModel("test_homo_surrogate", 0.5, + HOMOGENEOUS_GRAPH_MODEL_PATH, false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create simple homogeneous graph with 'x' field (expected by test model) + AMSHomogeneousGraph graph; + + // Insert node features tensor named 'x' + AMSTensor::IntDimType node_shape[] = {10, 16}; // 10 nodes, 16 features + AMSTensor::IntDimType node_strides[] = {16, 1}; + auto node_features = AMSTensor::create( + ams::ArrayRef(node_shape, 2), + ams::ArrayRef(node_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with test data + float* features_data = node_features.data(); + for (int i = 0; i < 160; ++i) { + features_data[i] = static_cast(i) * 0.1f; + } + + insertTensor(graph, "x", std::move(node_features)); + + // Define callback (should NOT be called if surrogate succeeds) + bool callback_invoked = false; + HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(containsTensor(g, "x")); + const auto* x = findTensor(g, "x"); + CATCH_REQUIRE(x != nullptr); + CATCH_REQUIRE(x->shape()[0] == 10); + CATCH_REQUIRE(x->shape()[1] == 16); + + // Create output tensor (8 features per node, matching model output) + AMSTensor::IntDimType out_shape[] = {10, 8}; + AMSTensor::IntDimType out_strides[] = {8, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + // Fill with physics computation result + float* out_data = out_tensor.data(); + for (int i = 0; i < 80; ++i) { + out_data[i] = static_cast(i); + } + + outputs.clear(); + outputs.push_back(std::move(out_tensor)); + }; + + // Execute + SmallVector outs; + AMSExecute(executor, callback, graph, outs); + + // Model is available, so surrogate should be used (callback NOT invoked) + CATCH_REQUIRE_FALSE(callback_invoked); + CATCH_REQUIRE(outs.size() == 1); + CATCH_REQUIRE(outs[0].shape()[0] == 10); + CATCH_REQUIRE(outs[0].shape()[1] == 8); +} + +CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution", + "[wf][graph][surrogate]") +{ + AMSInit(); + + // Setup: Register model with actual generated model path + auto model = AMSRegisterAbstractModel("test_hetero_surrogate", 0.5, + HETEROGENEOUS_GRAPH_MODEL_PATH, false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create heterogeneous graph + AMSHeterogeneousGraph graph; + + // Add node store for "node" type with 'x' features + // Note: Using fixed "node" name to match test fixture model expectation + AMSTensorMap node_store; + AMSTensor::IntDimType node_shape[] = {10, 16}; + AMSTensor::IntDimType node_strides[] = {16, 1}; + auto node_features = AMSTensor::create( + ams::ArrayRef(node_shape, 2), + ams::ArrayRef(node_strides, 2), + AMSResourceType::AMS_HOST); + + float* node_data = node_features.data(); + for (int i = 0; i < 160; ++i) { + node_data[i] = static_cast(i) * 0.1f; + } + + insertTensor(node_store, "x", std::move(node_features)); + graph.node_stores["node"] = std::move(node_store); + + // Add edge store with dummy data (empty dicts can cause TorchScript issues) + AMSTensorMap edge_store; + AMSTensor::IntDimType dummy_shape[] = {1, 1}; + AMSTensor::IntDimType dummy_strides[] = {1, 1}; + auto dummy_edge = AMSTensor::create( + ams::ArrayRef(dummy_shape, 2), + ams::ArrayRef(dummy_strides, 2), + AMSResourceType::AMS_HOST); + dummy_edge.data()[0] = 0.0f; + insertTensor(edge_store, "dummy", std::move(dummy_edge)); + EdgeType edge_type{"node", "edge", "node"}; + graph.edge_stores[edge_type] = std::move(edge_store); + + // Add global store with dummy data + AMSTensor::IntDimType global_shape[] = {1, 1}; + AMSTensor::IntDimType global_strides[] = {1, 1}; + auto dummy_global = AMSTensor::create( + ams::ArrayRef(global_shape, 2), + ams::ArrayRef(global_strides, 2), + AMSResourceType::AMS_HOST); + dummy_global.data()[0] = 0.0f; + insertTensor(graph.global_store, "dummy", std::move(dummy_global)); + + // Define callback (should NOT be called if surrogate succeeds) + bool callback_invoked = false; + HeterogeneousGraphDomainFn callback = + [&](const AMSHeterogeneousGraph& g, SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(g.containsNodeStore("node")); + const auto* node_store = g.findNodeStore("node"); + CATCH_REQUIRE(node_store != nullptr); + CATCH_REQUIRE(containsTensor(*node_store, "x")); + + // Create output tensor + AMSTensor::IntDimType out_shape[] = {10, 8}; + AMSTensor::IntDimType out_strides[] = {8, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + float* out_data = out_tensor.data(); + for (int i = 0; i < 80; ++i) { + out_data[i] = static_cast(i); + } + + outputs.clear(); + outputs.push_back(std::move(out_tensor)); + }; + + // Execute + SmallVector outs; + AMSExecute(executor, callback, graph, outs); + + // Model is available, so surrogate should be used (callback NOT invoked) + CATCH_REQUIRE_FALSE(callback_invoked); + CATCH_REQUIRE(outs.size() == 1); + CATCH_REQUIRE(outs[0].shape()[0] == 10); + CATCH_REQUIRE(outs[0].shape()[1] == 8); +} + +CATCH_TEST_CASE("Graph surrogate with no model triggers fallback", + "[wf][graph][surrogate][fallback]") +{ + AMSInit(); + + // Setup: Register model with empty path (no model available) + auto model = AMSRegisterAbstractModel( + "test_no_model", 0.5, "", false); + AMSExecutor executor = AMSCreateExecutor(model, 0, 1); + + // Create simple homogeneous graph + AMSHomogeneousGraph graph; + AMSTensor::IntDimType shape[] = {5, 16}; + AMSTensor::IntDimType strides[] = {16, 1}; + auto features = AMSTensor::create( + ams::ArrayRef(shape, 2), + ams::ArrayRef(strides, 2), + AMSResourceType::AMS_HOST); + + float* data = features.data(); + for (int i = 0; i < 80; ++i) { + data[i] = 1.0f; + } + + insertTensor(graph, "x", std::move(features)); + + // Define callback + bool callback_invoked = false; + HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + AMSTensor::IntDimType out_shape[] = {5, 8}; + AMSTensor::IntDimType out_strides[] = {8, 1}; + auto out = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + outputs.clear(); + outputs.push_back(std::move(out)); + }; + + // Execute + SmallVector outs; + AMSExecute(executor, callback, graph, outs); + + // Invalid model should trigger fallback + CATCH_REQUIRE(callback_invoked); + CATCH_REQUIRE(outs.size() == 1); +} diff --git a/tests/AMSlib/models/CMakeLists.txt b/tests/AMSlib/models/CMakeLists.txt index 0d58e4fd..6ce3ca6a 100644 --- a/tests/AMSlib/models/CMakeLists.txt +++ b/tests/AMSlib/models/CMakeLists.txt @@ -3,6 +3,7 @@ set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/generate.py" "${CMAKE_CURRENT_SOURCE_DIR}/generate_linear_model.py" "${CMAKE_CURRENT_SOURCE_DIR}/generate_base_models.py" + "${CMAKE_CURRENT_SOURCE_DIR}/generate_graph_models.py" ) # Where to drop the .pt files @@ -28,6 +29,8 @@ set(GENERATED_CPU_MODELS ${CMAKE_CURRENT_BINARY_DIR}/linear_scripted_single_cpu_duq_max.pt ${CMAKE_CURRENT_BINARY_DIR}/linear_scripted_single_cpu_duq_mean.pt ${CMAKE_CURRENT_BINARY_DIR}/linear_scripted_single_cpu_random.pt + ${CMAKE_CURRENT_BINARY_DIR}/homogeneous_graph.pt + ${CMAKE_CURRENT_BINARY_DIR}/heterogeneous_graph.pt ) if (WITH_CUDA OR WITH_HIP) @@ -61,6 +64,8 @@ set(GENERATOR_SCRIPTS "${CMAKE_CURRENT_SOURCE_DIR}/generate.py" "${CMAKE_CURRENT_SOURCE_DIR}/generate_linear_model.py" "${CMAKE_CURRENT_SOURCE_DIR}/generate_base_models.py" + "${CMAKE_CURRENT_SOURCE_DIR}/generate_graph_models.py" + "${CMAKE_CURRENT_SOURCE_DIR}/ams_model.py" ) # 1. Stamp missing -> regenerate diff --git a/tests/AMSlib/models/ams_model.py b/tests/AMSlib/models/ams_model.py index 24533b3c..4e195943 100644 --- a/tests/AMSlib/models/ams_model.py +++ b/tests/AMSlib/models/ams_model.py @@ -1,10 +1,15 @@ from pathlib import Path -from typing import Dict, Tuple +from typing import Any, Dict, Tuple import torch import torch.nn as nn from torch import Tensor + +# ============================================================================== +# Tensor model wrapper +# ============================================================================== + class AMSModel(nn.Module): _ams_dtype: torch.dtype _ams_device: torch.device @@ -56,6 +61,171 @@ def create_ams_model( return scripted +# ============================================================================== +# Graph model wrappers +# ============================================================================== + +class AMSHomogeneousGraphModel(nn.Module): + """AMS wrapper for homogeneous graph models. + + This wrapper exposes AMS metadata methods and forwards graph dictionaries + directly to the wrapped model without signature mismatch. + """ + ams_info: Dict[str, str] + + def __init__(self, model: nn.Module, dtype: torch.dtype, device: torch.device): + super().__init__() + self._model = model + + # Convert dtype to string + if dtype == torch.float32: + ams_dtype = "float32" + elif dtype == torch.float64: + ams_dtype = "float64" + else: + raise RuntimeError(f"AMS library does not support dtype {dtype}") + + # Device type as string + ams_device = device.type + + # Store in old-style format for compatibility + self.ams_info = {"ams_type": ams_dtype, "ams_device": ams_device} + + @torch.jit.export + def get_ams_info(self) -> Dict[str, str]: + return self.ams_info + + def forward(self, graph: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + """Forward pass for homogeneous graph. + + Args: + graph: Dict[str, Tensor] representing homogeneous graph + + Returns: + Tuple of (prediction, uncertainty) tensors + """ + return self._model(graph) + + +def create_ams_homogeneous_graph_model( + model: nn.Module, + device: torch.device, + precision: torch.dtype, +): + """Create AMS-wrapped homogeneous graph model. + + Args: + model: PyTorch model with forward(graph: Dict[str, Tensor]) -> Tuple[Tensor, Tensor] + device: Target device + precision: Target dtype + + Returns: + TorchScript scripted module ready for AMS + """ + if not isinstance(device, torch.device): + raise RuntimeError(f"Expected device to be torch.device, got {type(device)}") + + if not isinstance(precision, torch.dtype): + raise RuntimeError(f"Expected precision to be torch.dtype, got {type(precision)}") + + model = model.eval().to(device=device, dtype=precision) + + # Script the inner model + inner = torch.jit.script(model) + + # Wrap in AMS metadata wrapper + ams = AMSHomogeneousGraphModel(inner, precision, device) + + # Script the wrapper + scripted = torch.jit.script(ams) + return scripted + + +class AMSHeterogeneousGraphModel(nn.Module): + """AMS wrapper for heterogeneous graph models. + + This wrapper exposes AMS metadata methods and forwards heterogeneous graph + structures to the wrapped model. + + Note: The forward signature uses a generic Dict type to work around + TorchScript limitations with deeply nested dict structures. + """ + ams_info: Dict[str, str] + + def __init__(self, model: nn.Module, dtype: torch.dtype, device: torch.device): + super().__init__() + self._model = model + + # Convert dtype to string + if dtype == torch.float32: + ams_dtype = "float32" + elif dtype == torch.float64: + ams_dtype = "float64" + else: + raise RuntimeError(f"AMS library does not support dtype {dtype}") + + # Device type as string + ams_device = device.type + + # Store in old-style format for compatibility + self.ams_info = {"ams_type": ams_dtype, "ams_device": ams_device} + + @torch.jit.export + def get_ams_info(self) -> Dict[str, str]: + return self.ams_info + + def forward(self, graph: Dict[str, Any]) -> Tuple[Tensor, Tensor]: + """Forward pass for heterogeneous graph. + + Args: + graph: Nested dict with node_stores/edge_stores/global_store + Uses Dict[str, Any] to handle mixed value types at top level + + Returns: + Tuple of (prediction, uncertainty) tensors + """ + return self._model(graph) + + +def create_ams_heterogeneous_graph_model( + model: nn.Module, + device: torch.device, + precision: torch.dtype, +): + """Create AMS-wrapped heterogeneous graph model. + + Args: + model: PyTorch model with forward(graph: Dict[str, Any]) -> Tuple[Tensor, Tensor] + Uses Dict[str, Any] to handle mixed top-level value types + device: Target device + precision: Target dtype + + Returns: + TorchScript scripted module ready for AMS + """ + if not isinstance(device, torch.device): + raise RuntimeError(f"Expected device to be torch.device, got {type(device)}") + + if not isinstance(precision, torch.dtype): + raise RuntimeError(f"Expected precision to be torch.dtype, got {type(precision)}") + + model = model.eval().to(device=device, dtype=precision) + + # Script the inner model + inner = torch.jit.script(model) + + # Wrap in AMS metadata wrapper + ams = AMSHeterogeneousGraphModel(inner, precision, device) + + # Script the wrapper + scripted = torch.jit.script(ams) + return scripted + + +# ============================================================================== +# Legacy model wrapper (deprecated) +# ============================================================================== + class AMSModelOld(nn.Module): ams_info: Dict[str, str] diff --git a/tests/AMSlib/models/generate.sh b/tests/AMSlib/models/generate.sh index f6f1dbb4..409379fd 100755 --- a/tests/AMSlib/models/generate.sh +++ b/tests/AMSlib/models/generate.sh @@ -19,3 +19,5 @@ python ${root_dir}/generate_base_models.py --out-dir ${directory} python ${root_dir}/generate_linear_model.py ${directory} 8 9 +python ${root_dir}/generate_graph_models.py --out-dir ${directory} + diff --git a/tests/AMSlib/models/generate_graph_models.py b/tests/AMSlib/models/generate_graph_models.py new file mode 100644 index 00000000..5444c25c --- /dev/null +++ b/tests/AMSlib/models/generate_graph_models.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +"""Generate test TorchScript models for graph surrogate execution. + +This script creates minimal test models for homogeneous and heterogeneous graphs +that can be used to test the graph surrogate execution path in AMS. +""" + +import os +import sys +from pathlib import Path +from typing import Any, Dict + +import torch +import torch.nn as nn + +# Add current directory to path to import ams_model helper +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(CURRENT_DIR) +from ams_model import ( + create_ams_homogeneous_graph_model, + create_ams_heterogeneous_graph_model, +) + + +class HomogeneousGraphModel(nn.Module): + """Simple model for homogeneous graphs. + + Accepts a Dict[str, Tensor] representing a homogeneous graph. + Reads the 'x' field (node features) and applies a linear transformation. + Returns (prediction, uncertainty) tuple where uncertainty is fixed low value. + """ + + def __init__(self): + super().__init__() + self.linear = nn.Linear(16, 8) + + def forward( + self, graph: Dict[str, torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + # Read 'x' field from graph (node features) + x = graph['x'] + + # Simple prediction: linear transform + prediction = self.linear(x) + + # Low fixed uncertainty (always accept for testing) + uncertainty = torch.full( + (x.shape[0], 1), 0.01, dtype=x.dtype, device=x.device + ) + + return prediction, uncertainty + + +class HeterogeneousGraphModel(nn.Module): + """Simple test model for heterogeneous graphs. + + This is a narrow test fixture designed to be TorchScript-scriptable. + It expects a specific node store named "node" with an 'x' field. + + Input structure: + { + 'node_stores': {'node': Dict[str, Tensor], ...}, + 'edge_stores': {...}, + 'global_store': Dict[str, Tensor] + } + + Reads the 'x' field from the 'node' node store and applies transformation. + Returns (prediction, uncertainty) tuple where uncertainty is fixed low value. + """ + + def __init__(self): + super().__init__() + self.linear = nn.Linear(16, 8) + + def forward( + self, graph: Dict[str, Any] + ) -> tuple[torch.Tensor, torch.Tensor]: + # Extract node_stores with proper type recovery for TorchScript + # The top-level dict has mixed types (node_stores/edge_stores are dicts, + # global_store is also a dict but with different structure) + # Use Any and isinstance to work around TorchScript limitations + node_stores_any = graph['node_stores'] + + # Recover proper type using torch.jit.isinstance + assert torch.jit.isinstance(node_stores_any, Dict[str, Dict[str, torch.Tensor]]) + node_stores = torch.jit.annotate(Dict[str, Dict[str, torch.Tensor]], node_stores_any) + + # Use fixed node store name "node" (test fixture, not generic) + node_store = node_stores['node'] + x = node_store['x'] + + # Simple prediction + prediction = self.linear(x) + + # Low fixed uncertainty + uncertainty = torch.full( + (x.shape[0], 1), 0.01, dtype=x.dtype, device=x.device + ) + + return prediction, uncertainty + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Generate test graph models for AMS graph surrogate execution" + ) + parser.add_argument( + "--out-dir", + type=Path, + required=True, + help="Output directory where models will be written", + ) + args = parser.parse_args() + + out_dir = args.out_dir.resolve() + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cpu") + dtype = torch.float32 + + # Generate homogeneous graph model + print("[info] Generating homogeneous graph model...") + homo_model = HomogeneousGraphModel().to(device=device, dtype=dtype) + homo_wrapped = create_ams_homogeneous_graph_model(homo_model, device, dtype) + homo_path = out_dir / "homogeneous_graph.pt" + homo_wrapped.save(str(homo_path)) + print(f"[info] Saved homogeneous graph model: {homo_path}") + + # Generate heterogeneous graph model + print("[info] Generating heterogeneous graph model...") + hetero_model = HeterogeneousGraphModel().to(device=device, dtype=dtype) + hetero_wrapped = create_ams_heterogeneous_graph_model(hetero_model, device, dtype) + hetero_path = out_dir / "heterogeneous_graph.pt" + hetero_wrapped.save(str(hetero_path)) + print(f"[info] Saved heterogeneous graph model: {hetero_path}") + + print("[info] Done generating graph models") + + +if __name__ == "__main__": + main() From 7ba9e071d01dd1c84bca7402890ce0235de81c19 Mon Sep 17 00:00:00 2001 From: Yohann Date: Fri, 24 Apr 2026 14:10:52 -0700 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/AMSlib/ml/surrogate.hpp | 5 +- src/AMSlib/wf/interface.cpp | 3 +- .../ams_interface/test_graph_surrogate.cpp | 81 ++++++++++--------- 3 files changed, 48 insertions(+), 41 deletions(-) diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp index c9cf07a1..e9cb326a 100644 --- a/src/AMSlib/ml/surrogate.hpp +++ b/src/AMSlib/ml/surrogate.hpp @@ -26,7 +26,8 @@ #include "wf/debug.h" // Forward declarations for graph surrogate friend functions -namespace ams { +namespace ams +{ class AMSWorkflow; bool tryGraphSurrogate(AMSWorkflow*, const AMSHomogeneousGraph&, @@ -34,7 +35,7 @@ bool tryGraphSurrogate(AMSWorkflow*, bool tryGraphSurrogate(AMSWorkflow*, const AMSHeterogeneousGraph&, SmallVector&); -} +} // namespace ams //! ---------------------------------------------------------------------------- //! An implementation for a surrogate model diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp index b6f60db9..53dc3625 100644 --- a/src/AMSlib/wf/interface.cpp +++ b/src/AMSlib/wf/interface.cpp @@ -382,7 +382,8 @@ void callApplication(ams::HeterogeneousGraphDomainFn CallBack, // Graph surrogate execution (in ams namespace for friend access) // ============================================================================ -namespace ams { +namespace ams +{ bool tryGraphSurrogate(AMSWorkflow* executor, const AMSHomogeneousGraph& graph, diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp index b23b55d4..24249d94 100644 --- a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp +++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp @@ -7,8 +7,10 @@ using namespace ams; // Paths to graph test models (relative to test executable) -static const char* HOMOGENEOUS_GRAPH_MODEL_PATH = "../models/homogeneous_graph.pt"; -static const char* HETEROGENEOUS_GRAPH_MODEL_PATH = "../models/heterogeneous_graph.pt"; +static const char* HOMOGENEOUS_GRAPH_MODEL_PATH = + "../models/homogeneous_graph.pt"; +static const char* HETEROGENEOUS_GRAPH_MODEL_PATH = + "../models/heterogeneous_graph.pt"; CATCH_TEST_CASE("AMSExecute homogeneous graph surrogate execution", "[wf][graph][surrogate]") @@ -16,8 +18,10 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph surrogate execution", AMSInit(); // Setup: Register model with actual generated model path - auto model = AMSRegisterAbstractModel("test_homo_surrogate", 0.5, - HOMOGENEOUS_GRAPH_MODEL_PATH, false); + auto model = AMSRegisterAbstractModel("test_homo_surrogate", + 0.5, + HOMOGENEOUS_GRAPH_MODEL_PATH, + false); AMSExecutor executor = AMSCreateExecutor(model, 0, 1); // Create simple homogeneous graph with 'x' field (expected by test model) @@ -87,8 +91,10 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution", AMSInit(); // Setup: Register model with actual generated model path - auto model = AMSRegisterAbstractModel("test_hetero_surrogate", 0.5, - HETEROGENEOUS_GRAPH_MODEL_PATH, false); + auto model = AMSRegisterAbstractModel("test_hetero_surrogate", + 0.5, + HETEROGENEOUS_GRAPH_MODEL_PATH, + false); AMSExecutor executor = AMSCreateExecutor(model, 0, 1); // Create heterogeneous graph @@ -137,32 +143,32 @@ CATCH_TEST_CASE("AMSExecute heterogeneous graph surrogate execution", // Define callback (should NOT be called if surrogate succeeds) bool callback_invoked = false; - HeterogeneousGraphDomainFn callback = - [&](const AMSHeterogeneousGraph& g, SmallVector& outputs) { - callback_invoked = true; - - // Verify graph structure - CATCH_REQUIRE(g.containsNodeStore("node")); - const auto* node_store = g.findNodeStore("node"); - CATCH_REQUIRE(node_store != nullptr); - CATCH_REQUIRE(containsTensor(*node_store, "x")); - - // Create output tensor - AMSTensor::IntDimType out_shape[] = {10, 8}; - AMSTensor::IntDimType out_strides[] = {8, 1}; - auto out_tensor = AMSTensor::create( - ams::ArrayRef(out_shape, 2), - ams::ArrayRef(out_strides, 2), - AMSResourceType::AMS_HOST); - - float* out_data = out_tensor.data(); - for (int i = 0; i < 80; ++i) { - out_data[i] = static_cast(i); - } - - outputs.clear(); - outputs.push_back(std::move(out_tensor)); - }; + HeterogeneousGraphDomainFn callback = [&](const AMSHeterogeneousGraph& g, + SmallVector& outputs) { + callback_invoked = true; + + // Verify graph structure + CATCH_REQUIRE(g.containsNodeStore("node")); + const auto* node_store = g.findNodeStore("node"); + CATCH_REQUIRE(node_store != nullptr); + CATCH_REQUIRE(containsTensor(*node_store, "x")); + + // Create output tensor + AMSTensor::IntDimType out_shape[] = {10, 8}; + AMSTensor::IntDimType out_strides[] = {8, 1}; + auto out_tensor = AMSTensor::create( + ams::ArrayRef(out_shape, 2), + ams::ArrayRef(out_strides, 2), + AMSResourceType::AMS_HOST); + + float* out_data = out_tensor.data(); + for (int i = 0; i < 80; ++i) { + out_data[i] = static_cast(i); + } + + outputs.clear(); + outputs.push_back(std::move(out_tensor)); + }; // Execute SmallVector outs; @@ -181,18 +187,17 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback", AMSInit(); // Setup: Register model with empty path (no model available) - auto model = AMSRegisterAbstractModel( - "test_no_model", 0.5, "", false); + auto model = AMSRegisterAbstractModel("test_no_model", 0.5, "", false); AMSExecutor executor = AMSCreateExecutor(model, 0, 1); // Create simple homogeneous graph AMSHomogeneousGraph graph; AMSTensor::IntDimType shape[] = {5, 16}; AMSTensor::IntDimType strides[] = {16, 1}; - auto features = AMSTensor::create( - ams::ArrayRef(shape, 2), - ams::ArrayRef(strides, 2), - AMSResourceType::AMS_HOST); + auto features = + AMSTensor::create(ams::ArrayRef(shape, 2), + ams::ArrayRef(strides, 2), + AMSResourceType::AMS_HOST); float* data = features.data(); for (int i = 0; i < 80; ++i) {