Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/AMSlib/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,31 @@
#include "ArrayRef.hpp"
#include "wf/debug.h"

// Forward declarations for graph surrogate friend functions
namespace ams
{
class AMSWorkflow;
bool tryGraphSurrogate(AMSWorkflow*,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally I would like to avoid AMSWorkflow spilling into the surrogate part of the code. Can we move these to workflow.hpp?

const AMSHomogeneousGraph&,
SmallVector<AMSTensor>&);
bool tryGraphSurrogate(AMSWorkflow*,
const AMSHeterogeneousGraph&,
SmallVector<AMSTensor>&);
} // namespace ams

//! ----------------------------------------------------------------------------
//! An implementation for a surrogate model
//! ----------------------------------------------------------------------------
class SurrogateModel
{
// Friend declarations for graph surrogate execution
// Note: These are defined in interface.cpp within the ams namespace
friend bool ams::tryGraphSurrogate(ams::AMSWorkflow*,
const ams::AMSHomogeneousGraph&,
ams::SmallVector<ams::AMSTensor>&);
friend bool ams::tryGraphSurrogate(ams::AMSWorkflow*,
const ams::AMSHeterogeneousGraph&,
ams::SmallVector<ams::AMSTensor>&);

private:
const std::string _model_path;
Expand Down
121 changes: 81 additions & 40 deletions src/AMSlib/wf/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,28 +292,30 @@ static ams::AMSHeterogeneousGraph torchToAMSHeterogeneousGraph(
return out;
}

static c10::impl::GenericDict amsNodeStoresToTorchDict(
static c10::Dict<std::string, c10::Dict<std::string, torch::Tensor>>
amsNodeStoresToTorchDict(
const std::unordered_map<std::string, ams::AMSTensorMap>& node_stores)
{
c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get());
c10::Dict<std::string, c10::Dict<std::string, torch::Tensor>> out;

for (const auto& [store_name, store] : node_stores) {
out.insert(store_name, c10::IValue(amsTensorMapToTorchDict(store)));
out.insert(store_name, amsTensorMapToTorchDict(store));
}

return out;
}

static c10::impl::GenericDict amsEdgeStoresToTorchDict(
static c10::Dict<std::string, c10::Dict<std::string, torch::Tensor>>
amsEdgeStoresToTorchDict(
const std::unordered_map<ams::EdgeType,
ams::AMSTensorMap,
ams::EdgeTypeHash>& edge_stores)
{
c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get());
c10::Dict<std::string, c10::Dict<std::string, torch::Tensor>> out;

for (const auto& [edge_type, store] : edge_stores) {
out.insert(ams::edgeTypeToString(edge_type),
c10::IValue(amsTensorMapToTorchDict(store)));
amsTensorMapToTorchDict(store));
}

return out;
Expand All @@ -324,12 +326,9 @@ static c10::impl::GenericDict amsToTorchHeterogeneousGraph(
{
c10::impl::GenericDict out(c10::StringType::get(), c10::AnyType::get());

out.insert("node_stores",
c10::IValue(amsNodeStoresToTorchDict(g.node_stores)));
out.insert("edge_stores",
c10::IValue(amsEdgeStoresToTorchDict(g.edge_stores)));
out.insert("global_store",
c10::IValue(amsTensorMapToTorchDict(g.global_store)));
out.insert("node_stores", amsNodeStoresToTorchDict(g.node_stores));
out.insert("edge_stores", amsEdgeStoresToTorchDict(g.edge_stores));
out.insert("global_store", amsTensorMapToTorchDict(g.global_store));

return out;
}
Expand Down Expand Up @@ -380,42 +379,84 @@ void callApplication(ams::HeterogeneousGraphDomainFn CallBack,
}

// ============================================================================
// Graph surrogate execution stub (seam for future implementation)
// Graph surrogate execution (in ams namespace for friend access)
// ============================================================================

static bool tryGraphSurrogate(ams::AMSWorkflow* executor,
const ams::AMSHomogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
namespace ams
{

bool tryGraphSurrogate(AMSWorkflow* executor,
const AMSHomogeneousGraph& graph,
SmallVector<AMSTensor>& outs)
{
// TODO: Implement graph surrogate execution when models support graphs
// This is the integration point for future graph-based ML inference
//
// Future implementation should:
// 1. Check if executor has a model that accepts graph inputs
// 2. Convert AMSHomogeneousGraph to model input format
// 3. Run model inference and uncertainty quantification
// 4. If UQ passes threshold, populate outs and return true
// 5. Otherwise return false to trigger fallback
//
// For now, always return false (no surrogate available)
(void)executor;
(void)graph;
(void)outs;
return false;
// Check if model is available
if (!executor || !executor->MLModel) {
return false;
}

try {
// Convert AMS graph → Torch Dict[str, Tensor]
auto torch_graph = amsToTorchHomogeneousGraph(graph);

// Call model forward pass
std::vector<torch::jit::IValue> inputs = {torch::jit::IValue(torch_graph)};
auto result = executor->MLModel->module.forward(inputs);

// Extract prediction from tuple [prediction, uncertainty]
auto result_tuple = result.toTuple();
if (result_tuple->elements().size() < 1) {
return false;
}
torch::Tensor prediction = result_tuple->elements()[0].toTensor();

// Convert prediction → AMSTensor and populate outs
outs.clear();
outs.push_back(torchToAMSTensorView(prediction));

return true;
} catch (const std::exception& e) {
// Model invocation failed - fallback to physics
return false;
}
}

static bool tryGraphSurrogate(ams::AMSWorkflow* executor,
const ams::AMSHeterogeneousGraph& graph,
ams::SmallVector<ams::AMSTensor>& outs)
bool tryGraphSurrogate(AMSWorkflow* executor,
const AMSHeterogeneousGraph& graph,
SmallVector<AMSTensor>& outs)
{
// TODO: Implement graph surrogate execution when models support graphs
// See homogeneous version for implementation notes
(void)executor;
(void)graph;
(void)outs;
return false;
// Check if model is available
if (!executor || !executor->MLModel) {
return false;
}

try {
// Convert AMS graph → Torch GenericDict
auto torch_graph = amsToTorchHeterogeneousGraph(graph);

// Call model forward pass
std::vector<torch::jit::IValue> inputs = {torch::jit::IValue(torch_graph)};
auto result = executor->MLModel->module.forward(inputs);

// Extract prediction from tuple [prediction, uncertainty]
auto result_tuple = result.toTuple();
if (result_tuple->elements().size() < 1) {
return false;
}
torch::Tensor prediction = result_tuple->elements()[0].toTensor();

// Convert prediction → AMSTensor and populate outs
outs.clear();
outs.push_back(torchToAMSTensorView(prediction));

return true;
} catch (const std::exception& e) {
// Model invocation failed - fallback to physics
return false;
}
}

} // namespace ams

// ============================================================================
// Graph-based callAMS overloads
// ============================================================================
Expand Down
7 changes: 7 additions & 0 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ namespace ams
{
class AMSWorkflow
{
// Friend declarations for graph surrogate execution access to private MLModel
friend bool tryGraphSurrogate(AMSWorkflow*,
const AMSHomogeneousGraph&,
SmallVector<AMSTensor>&);
friend bool tryGraphSurrogate(AMSWorkflow*,
const AMSHeterogeneousGraph&,
SmallVector<AMSTensor>&);

/** @brief A string identifier describing the domain-model being solved. */
std::string domainName;
Expand Down
3 changes: 3 additions & 0 deletions tests/AMSlib/ams_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ ADD_AMS_UNIT_TEST(AMS_INT_INTERFACE int_interface)
BUILD_UNIT_TEST(ams_graph_fallback test_graph_fallback.cpp)
ADD_AMS_UNIT_TEST(AMS_GRAPH_FALLBACK ams_graph_fallback)

BUILD_UNIT_TEST(ams_graph_surrogate test_graph_surrogate.cpp)
ADD_AMS_UNIT_TEST(AMS_GRAPH_SURROGATE ams_graph_surrogate)

Loading
Loading