From 5258f59058644bcda3531324db024b7adc7569f9 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:24:16 -0500 Subject: [PATCH 1/9] util Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- .../datasets/hetero_powergrid_datamodule.py | 104 ++++++++++++------ 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py index 4ac0125..a168012 100644 --- a/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py +++ b/gridfm_graphkit/datasets/hetero_powergrid_datamodule.py @@ -12,6 +12,7 @@ from gridfm_graphkit.datasets.utils import ( split_dataset, split_dataset_by_load_scenario_idx, + split_from_existing_files, ) from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk import numpy as np @@ -19,6 +20,7 @@ import warnings import os import lightning as L +from pathlib import Path from typing import List from lightning.pytorch.loggers import MLFlowLogger @@ -96,6 +98,11 @@ def __init__( "split_by_load_scenario_idx", False, ) + self.split_from_existing_files = getattr( + args.data, + "split_from_existing_files", + None, + ) self.args = args self.normalizer_stats_path = normalizer_stats_path self.data_normalizers = [] @@ -108,6 +115,15 @@ def __init__( self.test_scenario_ids: List[List[int]] = [] self._is_setup_done = False + if self.split_by_load_scenario_idx: + assert self.split_from_existing_files is None, " either `split_by_load_scenario_idx` or `split_from_existing_files` may be used, not both" + + if self.split_from_existing_files is not None: + assert isinstance(self.split_from_existing_files, str), "`split_from_existing_files` must be an existing folder in string format" + self.split_from_existing_files = Path(self.split_from_existing_files) + assert self.split_from_existing_files.is_dir(), "`split_from_existing_files` must be an 
existing folder in string format" + + def setup(self, stage: str): if self._is_setup_done: print(f"Setup already done for stage={stage}, skipping...") @@ -161,49 +177,64 @@ def setup(self, stage: str): # Create a subset all_indices = list(range(len(dataset))) - # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order - random.seed(self.args.seed) - random.shuffle(all_indices) - subset_indices = all_indices[:num_scenarios] + splits_dir = Path(data_path_network) + splits_dir = splits_dir / "splits" - # load_scenario for each scenario in the subset - load_scenarios = dataset.load_scenarios[subset_indices] - - dataset = Subset(dataset, subset_indices) - - # Random seed set before every split, same as above - np.random.seed(self.args.seed) - if self.split_by_load_scenario_idx: - train_dataset, val_dataset, test_dataset = ( - split_dataset_by_load_scenario_idx( + if self.split_from_existing_files is not None: + (train_dataset, val_dataset, test_dataset), subset_indices = ( + split_from_existing_files( + dataset, + splits_dir, + self.split_from_existing_files, + ) + ) + train_scenario_ids = subset_indices["train"] + val_scenario_ids = subset_indices["val"] + test_scenario_ids = subset_indices["test"] + else: + # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order + random.seed(self.args.seed) + random.shuffle(all_indices) + subset_indices = all_indices[:num_scenarios] + + # load_scenario for each scenario in the subset + load_scenarios = dataset.load_scenarios[subset_indices] + + dataset = Subset(dataset, subset_indices) + + # Random seed set before every split, same as above + np.random.seed(self.args.seed) + if self.split_by_load_scenario_idx: + train_dataset, val_dataset, test_dataset = ( + split_dataset_by_load_scenario_idx( + dataset, + self.data_dir, + load_scenarios, + self.args.data.val_ratio, + self.args.data.test_ratio, + ) + 
) + else: + train_dataset, val_dataset, test_dataset = split_dataset( dataset, self.data_dir, - load_scenarios, self.args.data.val_ratio, self.args.data.test_ratio, ) + + # Extract scenario IDs for each split + train_scenario_ids = self._extract_scenario_ids( + train_dataset, + subset_indices, ) - else: - train_dataset, val_dataset, test_dataset = split_dataset( - dataset, - self.data_dir, - self.args.data.val_ratio, - self.args.data.test_ratio, + val_scenario_ids = self._extract_scenario_ids( + val_dataset, + subset_indices, + ) + test_scenario_ids = self._extract_scenario_ids( + test_dataset, + subset_indices, ) - - # Extract scenario IDs for each split - train_scenario_ids = self._extract_scenario_ids( - train_dataset, - subset_indices, - ) - val_scenario_ids = self._extract_scenario_ids( - val_dataset, - subset_indices, - ) - test_scenario_ids = self._extract_scenario_ids( - test_dataset, - subset_indices, - ) # Fit normalizer: restore from saved stats only for fit_on_train # normalizers (global baseMVA must match the model's training run). 
@@ -214,7 +245,8 @@ def setup(self, stage: str): and network in saved_stats and data_normalizer.fit_strategy == "fit_on_train" ) - if use_saved: + # if use_saved: + if False: print(f"Restoring normalizer for {network} from saved stats") data_normalizer.fit_from_dict(saved_stats[network]) else: From 58e7acd17690b3bb9e4c24b53fdfddaf9cfbc2eb Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:24:26 -0500 Subject: [PATCH 2/9] util Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/datasets/utils.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/gridfm_graphkit/datasets/utils.py b/gridfm_graphkit/datasets/utils.py index f330d49..1e73097 100644 --- a/gridfm_graphkit/datasets/utils.py +++ b/gridfm_graphkit/datasets/utils.py @@ -90,3 +90,26 @@ def split_dataset_by_load_scenario_idx( test_dataset = Subset(dataset, test_indices) return train_dataset, val_dataset, test_dataset + + +def split_from_existing_files( + dataset, + save_to_folder: str, + splits_folder: str, +) -> Tuple[Subset, Subset, Subset]: + output=[] + + indices = {} + + for split in ["train", "val", "test"]: + split_file = splits_folder / f"{split}.pt" + assert split_file.is_file(), f"{str(split_file)} does not exist" + split_indices = torch.load(str(split_file), weights_only=True) + split_dataset = Subset(dataset, split_indices) + output.append(split_dataset) + split_indices = list(split_indices) + print(f'{split=} {len(split_indices)=}') + indices[split]=[int(t.item()) for t in split_indices] + + output = tuple(output) + return output, indices \ No newline at end of file From 9c12f560e48e6ac38dbd5c0aafd0fcd5e8215f77 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:24:47 -0500 Subject: [PATCH 3/9] configurable optimizer Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- 
gridfm_graphkit/models/__init__.py | 7 +++++++ gridfm_graphkit/tasks/base_task.py | 30 ++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index f824535..91ab717 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,4 +1,8 @@ from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous +from gridfm_graphkit.models.fcnn import FullyConnectedNN +from gridfm_graphkit.models.gnn_heterogeneous import HeterogeneousGNN +# from gridfm_graphkit.models.gnn_homogeneous import HomogeneousGNN + from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, PhysicsDecoderPF, @@ -7,6 +11,9 @@ __all__ = [ "GNS_heterogeneous", + "FullyConnectedNN", + "HeterogeneousGNN", + # "HomogeneousGNN", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", diff --git a/gridfm_graphkit/tasks/base_task.py b/gridfm_graphkit/tasks/base_task.py index ec75ccc..fe5b682 100644 --- a/gridfm_graphkit/tasks/base_task.py +++ b/gridfm_graphkit/tasks/base_task.py @@ -5,6 +5,7 @@ from lightning.pytorch.loggers import MLFlowLogger import torch from torch.optim.lr_scheduler import ReduceLROnPlateau +from collections.abc import Mapping class BaseTask(L.LightningModule, ABC): @@ -70,23 +71,36 @@ def on_fit_start(self): stats_dict[self.args.data.networks[i]] = normalizer.get_stats() torch.save(stats_dict, os.path.join(log_dir, "normalizer_stats.pt")) + def configure_optimizers(self): - self.optimizer = torch.optim.AdamW( + if self.args.optimizer.type is None: + self.args.optimizer.type = "Adam" + optimizer = getattr(torch.optim, self.args.optimizer.type) + print(f'{self.args.optimizer.optimizer_params=}') + if not isinstance(self.args.optimizer.optimizer_params, Mapping): + self.args.optimizer.optimizer_params = self.args.optimizer.optimizer_params.to_dict() + self.optimizer = optimizer( self.model.parameters(), 
lr=self.args.optimizer.learning_rate, - betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), + **self.args.optimizer.optimizer_params, #unpack all other optim parameters ) - self.scheduler = ReduceLROnPlateau( + scheduler_type = getattr(self.args.optimizer, "scheduler_type", None) + if scheduler_type is None: + return {"optimizer": self.optimizer} + + #TODO: add interval handling for scheduler + scheduler = getattr(torch.optim.lr_scheduler, scheduler_type) + if not isinstance(self.args.optimizer.scheduler_params, Mapping): + self.args.optimizer.scheduler_params = self.args.optimizer.scheduler_params.to_dict() + self.scheduler = scheduler( self.optimizer, - mode="min", - factor=self.args.optimizer.lr_decay, - patience=self.args.optimizer.lr_patience, + **self.args.optimizer.scheduler_params ) - return { + config_optim = { "optimizer": self.optimizer, "lr_scheduler": { "scheduler": self.scheduler, "monitor": "Validation loss", - "reduce_on_plateau": True, }, } + return config_optim From 00031c3f96c8535a6ed0e44cfa1091e074863c19 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:25:05 -0500 Subject: [PATCH 4/9] new losses Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 238 +++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index d253d2b..7942ba3 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -19,6 +19,12 @@ PG_OUT, # Generator feature indices PG_H, + C0_H, + C1_H, + C2_H, + # Qg Limits + MIN_QG_H, + MAX_QG_H, ) @@ -36,6 +42,8 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, + batch_dict=None, ): """ Compute the loss. 
@@ -72,6 +80,8 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, + batch_dict=None, ): loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) return {"loss": loss, "Masked MSE loss": loss.detach()} @@ -91,6 +101,8 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, + batch_dict=None, ): loss = F.mse_loss( pred_dict["gen"][mask_dict["gen"][:, : (PG_H + 1)]], @@ -115,6 +127,8 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, + batch_dict=None, ): if self.args.task == "OptimalPowerFlow": pred_cols = [VM_OUT, VA_OUT, QG_OUT] @@ -152,6 +166,8 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, + batch_dict=None, ): loss = F.mse_loss(pred, target, reduction=self.reduction) return {"loss": loss, "MSE loss": loss.detach()} @@ -185,6 +201,8 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, + batch_dict=None, ): """ Compute the weighted sum of all specified losses. @@ -211,6 +229,8 @@ def forward( edge_attr, mask, model, + x_dict, + batch_dict, ) # Assume each loss function returns a dictionary with a "loss" key @@ -241,6 +261,8 @@ def forward( edge_attr=None, mask=None, model=None, + x_dict=None, + batch_dict=None, ): total_loss = 0.0 loss_details = {} @@ -291,6 +313,8 @@ def forward( edge_attr, mask_dict, model=None, + x_dict=None, + batch_dict=None, ): if self.dim == "VM": temp_pred = pred_dict["bus"][:, VM_OUT] @@ -322,3 +346,217 @@ def forward( f"MSE loss {self.dim}": mse_loss.detach(), f"MAE loss {self.dim}": mae_loss.detach(), } + + +@LOSS_REGISTRY.register("QgViolationPenalty") +class QgViolationPenaltyLoss(BaseLoss): + """Standard Mean Squared Error loss.""" + + def __init__(self, loss_args, args): + super().__init__() + + def forward( + self, + pred, + target, + edge_index=None, + edge_attr=None, + mask=None, + model=None, + x_dict=None, + batch_dict=None, + ): + # --- Qg limit violation mask --- + Qg_pred = pred["bus"][:, QG_OUT] + Qg_max = x_dict["bus"][:, 
MAX_QG_H] + Qg_min = x_dict["bus"][:, MIN_QG_H] + + max_penalty_mask = (Qg_pred > Qg_max) + min_penalty_mask = (Qg_pred < Qg_min) + + mask_PQ = mask["PQ"] # PQ buses + mask_PV = mask["PV"] # PV buses + mask_REF = mask["REF"] # Reference buses + + loss = 0.0 + # where there are violations, compute penalty loss + Qg_over = F.relu(Qg_pred - Qg_max) # amount above max limit + Qg_under = F.relu(Qg_min - Qg_pred) # amount below min limit + + Qg_over = Qg_over[max_penalty_mask].mean() + Qg_under = Qg_under[min_penalty_mask].mean() + + if Qg_over!=Qg_over: # replacing nan with 0 + Qg_over = 0.0 + if Qg_under!=Qg_under: # replacing nan with 0 + Qg_under = 0.0 + + penalty_loss = Qg_over + Qg_under + loss += penalty_loss + + try: + output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} + except: + output = {"loss": loss, "Qg Violation Penalty loss": loss} + + return output + + + + +@LOSS_REGISTRY.register("QgViolationBarrier") +class QgViolationBarrierLoss(BaseLoss): + """ + QgViolation Barrier loss function. + * https://en.wikipedia.org/wiki/Barrier_function + Available barrier functions are defined in the self.barriers dictionary. 
+ References for relaxed barrier functions: + * https://arxiv.org/abs/1602.01321 + * https://arxiv.org/abs/1904.04205v2 + * https://ieeexplore.ieee.org/document/7493643/ + + Modified from https://github.com/pnnl/neuromancer + Copyright © 2021, Battelle Memorial Institute + https://github.com/pnnl/neuromancer/blob/master/LICENSE.md + """ + def __init__(self, loss_args, args): + super().__init__() + + self.barrier_name = getattr(loss_args, "barrier", 'log10') + self.shift = getattr(loss_args, "shift", 1) + self.alpha = getattr(loss_args, "alpha", 0.5) + self.upper_bound = getattr(loss_args, "upper_bound", 1.0) + + # choices of barrier functions + # warning: log10, log, inverse, and softlog might get numerically unstable + # softexp is numerically stable and thus a prefered option + self.barriers = { + 'log10': lambda value: -torch.log10(-value), + 'log': lambda value: -torch.log(-value), + 'inverse': lambda value: 1 / (-value), + 'softexp': lambda value: (torch.exp(self.alpha * value) - 1) / self.alpha + self.alpha, + 'softlog': lambda value: -torch.log(1 + self.alpha * (-value - self.alpha)) / self.alpha, + 'expshift': lambda value: torch.exp(value + self.shift) + } + self.barrier = self._set_barrier() + + def _set_barrier(self): + if self.barrier_name in self.barriers: + return self.barriers[self.barrier_name] + else: + assert callable(barrier), \ + f'The barrier, {barrier} must be a key in {self.barriers} or a callable.' + return barrier + + def forward( + self, + pred, + target, + edge_index=None, + edge_attr=None, + mask=None, + model=None, + x_dict=None, + batch_dict=None, + ): + """ + Calculate the magnitudes of constraint violations via log barriers + cviolation > 0 -> penalty (i.e. beyond constraints) + cviolation <= 0 -> barrier (i.e. 
within constraints) + """ + + # --- Qg limit violation mask --- + Qg_pred = pred["bus"][:, QG_OUT] + Qg_max = x_dict["bus"][:, MAX_QG_H] + Qg_min = x_dict["bus"][:, MIN_QG_H] + + max_penalty_mask = (Qg_pred > Qg_max) + min_penalty_mask = (Qg_pred < Qg_min) + + mask_PQ = mask["PQ"] # PQ buses + mask_PV = mask["PV"] # PV buses + mask_REF = mask["REF"] # Reference buses + + loss = 0.0 + + if max_penalty_mask.any() or min_penalty_mask.any(): + # where there are violations, compute penalty loss + Qg_over = F.relu(Qg_pred - Qg_max) # amount above max limit + Qg_under = F.relu(Qg_min - Qg_pred) # amount below min limit + + Qg_over = Qg_over[max_penalty_mask].mean() + Qg_under = Qg_under[min_penalty_mask].mean() + + if Qg_over!=Qg_over: # replacing nan with 0 + Qg_over = 0.0 + if Qg_under!=Qg_under: # replacing nan with 0 + Qg_under = 0.0 + + penalty_loss = Qg_over + Qg_under + loss += penalty_loss + + if (~max_penalty_mask).any() or (~min_penalty_mask).any(): + Qg_barrier_amount_max = Qg_pred - Qg_max + Qg_barrier_amount_min = Qg_min - Qg_pred + + cbarrier_max = self.barrier(Qg_barrier_amount_max) + cbarrier_max[cbarrier_max != cbarrier_max] = 0.0 # replacing nan with 0 -> infeasibility + cbarrier_max[cbarrier_max == float("Inf")] = 0.0 # replacing inf with 0 -> active constraints + cbarrier_max = torch.clamp(cbarrier_max, min=0.0, max=self.upper_bound) + barrier_loss_max = cbarrier_max[~max_penalty_mask].mean() + + cbarrier_min = self.barrier(Qg_barrier_amount_min) + cbarrier_min[cbarrier_min != cbarrier_min] = 0.0 # replacing nan with 0 -> infeasibility + cbarrier_min[cbarrier_min == float("Inf")] = 0.0 # replacing inf with 0 -> active constraints + cbarrier_min = torch.clamp(cbarrier_min, min=0.0, max=self.upper_bound) + barrier_loss_min = cbarrier_min[~min_penalty_mask].mean() + + barrier_loss = barrier_loss_min + barrier_loss_max + loss+=barrier_loss + + return {"loss": loss, "Qg Violation Barrier loss": loss.detach()} + + + +@LOSS_REGISTRY.register("OptimalityLoss") 
+class OptimalityLoss(BaseLoss): + """ + """ + + def __init__(self, loss_args, args): + super().__init__() + + def forward( + self, + pred, + target, + edge_index=None, + edge_attr=None, + mask=None, + model=None, + x_dict=None, + batch_dict=None, + ): + c0 = x_dict["gen"][:, C0_H] + c1 = x_dict["gen"][:, C1_H] + c2 = x_dict["gen"][:, C2_H] + target_pg = target["gen"].squeeze() + pred_pg = pred["gen"].squeeze() + gen_cost_gt = c0 + c1 * target_pg + c2 * target_pg**2 + gen_cost_pred = c0 + c1 * pred_pg + c2 * pred_pg**2 + + gen_batch = batch_dict["gen"] # shape: [N_gen_total] + + cost_gt = scatter_add(gen_cost_gt, gen_batch, dim=0) + cost_pred = scatter_add(gen_cost_pred, gen_batch, dim=0) + + loss = torch.mean(torch.abs((cost_pred - cost_gt) / cost_gt * 100)) + if loss!=loss: + loss=0.0 + + try: + output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} + except: + output = {"loss": loss, "Qg Violation Penalty loss": loss} + + return output From 65009df5383abfd224f02f2bae5b5ea6964f0c0f Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:25:42 -0500 Subject: [PATCH 5/9] baseline models Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/models/fcnn.py | 166 ++++++++++++++++ gridfm_graphkit/models/gnn_heterogeneous.py | 205 ++++++++++++++++++++ 2 files changed, 371 insertions(+) create mode 100644 gridfm_graphkit/models/fcnn.py create mode 100644 gridfm_graphkit/models/gnn_heterogeneous.py diff --git a/gridfm_graphkit/models/fcnn.py b/gridfm_graphkit/models/fcnn.py new file mode 100644 index 0000000..87a9ac8 --- /dev/null +++ b/gridfm_graphkit/models/fcnn.py @@ -0,0 +1,166 @@ +import torch +from torch import nn +from torch_geometric.nn import HeteroConv, TransformerConv +from gridfm_graphkit.io.registries import MODELS_REGISTRY +from gridfm_graphkit.io.param_handler import get_physics_decoder +from torch_scatter import scatter_add +from 
gridfm_graphkit.models.utils import ( + ComputeBranchFlow, + ComputeNodeInjection, + ComputeNodeResiduals, + bound_with_sigmoid, +) +from gridfm_graphkit.datasets.globals import ( + # Bus feature indices + VM_H, + VA_H, + MIN_VM_H, + MAX_VM_H, + # Output feature indices + VM_OUT, + PG_OUT_GEN, + # Generator feature indices + PG_H, + MIN_PG, + MAX_PG, + BUS_OUT_DIMENSIONS, + GEN_OUT_DIMENSIONS, +) + + + + + +@MODELS_REGISTRY.register("FullyConnectedNN") +class FullyConnectedNN(nn.Module): + """ + + """ + + def __init__(self, args) -> None: + super().__init__() + self.num_layers = args.model.num_layers + self.hidden_dim = args.model.hidden_size + self.input_bus_dim = args.model.input_bus_dim + self.input_gen_dim = args.model.input_gen_dim + self.edge_dim = args.model.edge_dim + self.task = args.task.task_name + self.dropout = getattr(args.model, "dropout", 0.0) + + # projections for each node type + self.input_proj_bus = nn.Sequential( + nn.Linear(self.input_bus_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + self.input_proj_gen = nn.Sequential( + nn.Linear(self.input_gen_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + self.input_proj_edge = nn.Sequential( + nn.Linear(self.edge_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + # Build hetero layers: HeteroConv of TransformerConv per relation + self.layers = nn.ModuleList() + self.norms = nn.LayerNorm(self.hidden_dim) + for i in range(self.num_layers): + layer = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + 
nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + ) + self.layers.append(layer) + + # Norms for node representations (note: after HeteroConv each node type will have size out_dim * heads) + + + # Separate shared MLPs to produce final bus/gen outputs (predictions y) + self.mlp_bus = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, BUS_OUT_DIMENSIONS), + ) + + self.mlp_gen = nn.Sequential( + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, GEN_OUT_DIMENSIONS), + ) + + self.activation = nn.LeakyReLU() + + + def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + """ + x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) + batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. 
+ mask: optional mask per node (applies when computing residuals) + """ + + + # 1) initial projections + h_bus = self.input_proj_bus(x_dict["bus"]) # [num_bus, hidden_dim] + h_gen = self.input_proj_gen(x_dict["gen"]) # [num_gen, hidden_dim] + + # concatenate data for forward propagation + combined_data = torch.cat((h_bus, h_gen), dim=0) + + num_bus = x_dict["bus"].size(0) + + # iterate layers + for layer in self.layers: + layer_output = layer(combined_data) # [Nb+Ng, hidden_dim] + layer_output = self.norms(layer_output) # [Nb+Ng, hidden_dim] + layer_output = self.activation(layer_output) + + # # skip connection + combined_data = combined_data + layer_output + + # split data + # print(f'\n\n\n\n\n{combined_data.shape=}') + out_bus = combined_data[:num_bus] + out_gen = combined_data[num_bus:] + # print(f'\n\n\n\n\n{out_bus.shape=}') + # print(f'\n\n\n\n\n{out_gen.shape=}') + + + # Decode bus and generator predictions + output_temp = self.mlp_bus(out_bus) # [num_buses, 4] -> [Vm, Va, Pg, Qg] + gen_temp = self.mlp_gen(out_gen) # [num_gens, 1] -> Pg + + + + # print(f'\n\n\n\n\n{output_temp.shape=}') + # print(f'{gen_temp.shape=}') + + + return {"bus": output_temp, "gen": gen_temp} \ No newline at end of file diff --git a/gridfm_graphkit/models/gnn_heterogeneous.py b/gridfm_graphkit/models/gnn_heterogeneous.py new file mode 100644 index 0000000..11be75c --- /dev/null +++ b/gridfm_graphkit/models/gnn_heterogeneous.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +from torch_geometric.nn import HeteroConv, TransformerConv +from gridfm_graphkit.io.registries import MODELS_REGISTRY +from gridfm_graphkit.io.param_handler import get_physics_decoder +from torch_scatter import scatter_add +from gridfm_graphkit.models.utils import ( + ComputeBranchFlow, + ComputeNodeInjection, + ComputeNodeResiduals, + bound_with_sigmoid, +) +from gridfm_graphkit.datasets.globals import ( + # Bus feature indices + VM_H, + VA_H, + MIN_VM_H, + MAX_VM_H, + # Output feature indices + VM_OUT, 
+ PG_OUT_GEN, + # Generator feature indices + PG_H, + MIN_PG, + MAX_PG, + BUS_OUT_DIMENSIONS, + GEN_OUT_DIMENSIONS, +) + + +@MODELS_REGISTRY.register("HeterogeneousGNN") +class HeterogeneousGNN(nn.Module): + """ + Heterogeneous version of your Transformer-based GNN for buses and generators. + - Expects node features as dict: x_dict = {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + - Expects edge_index_dict and edge_attr_dict with keys: + ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + (edge_attr only needed for bus-bus currently; other relations can be None) + """ + + def __init__(self, args) -> None: + super().__init__() + self.num_layers = args.model.num_layers + self.hidden_dim = args.model.hidden_size + self.input_bus_dim = args.model.input_bus_dim + self.input_gen_dim = args.model.input_gen_dim + self.output_bus_dim = args.model.output_bus_dim + self.output_gen_dim = args.model.output_gen_dim + self.edge_dim = args.model.edge_dim + self.heads = args.model.attention_head + self.task = args.task.task_name + self.dropout = getattr(args.model, "dropout", 0.0) + + # projections for each node type + self.input_proj_bus = nn.Sequential( + nn.Linear(self.input_bus_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + self.input_proj_gen = nn.Sequential( + nn.Linear(self.input_gen_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + self.input_proj_edge = nn.Sequential( + nn.Linear(self.edge_dim, self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + ) + + # a small physics MLP that will take residuals (real, imag) and return a correction + self.physics_mlp = nn.Sequential( + nn.Linear(2, self.hidden_dim * self.heads), + nn.LeakyReLU(), + ) + + # Build hetero layers: HeteroConv of TransformerConv 
per relation + self.layers = nn.ModuleList() + self.norms_bus = nn.ModuleList() + self.norms_gen = nn.ModuleList() + for i in range(self.num_layers): + # in-channels depend on whether it is first layer (hidden_dim) or subsequent (hidden_dim * heads) + in_bus = self.hidden_dim if i == 0 else self.hidden_dim * self.heads + in_gen = self.hidden_dim if i == 0 else self.hidden_dim * self.heads + out_dim = self.hidden_dim # TransformerConv will output hidden_dim (per head reduction in HeteroConv call) + + # relation -> conv module mapping + conv_dict = { + ("bus", "connects", "bus"): TransformerConv( + in_bus, + out_dim, + heads=self.heads, + edge_dim=self.hidden_dim, + dropout=self.dropout, + beta=True, + ), + ("gen", "connected_to", "bus"): TransformerConv( + in_gen, + out_dim, + heads=self.heads, + dropout=self.dropout, + beta=True, + ), + ("bus", "connected_to", "gen"): TransformerConv( + in_bus, + out_dim, + heads=self.heads, + dropout=self.dropout, + beta=True, + ), + } + + hetero_conv = HeteroConv(conv_dict, aggr="sum") + self.layers.append(hetero_conv) + + # Norms for node representations (note: after HeteroConv each node type will have size out_dim * heads) + self.norms_bus.append(nn.LayerNorm(out_dim * self.heads)) + self.norms_gen.append(nn.LayerNorm(out_dim * self.heads)) + + # Separate shared MLPs to produce final bus/gen outputs (predictions y) + self.mlp_bus = nn.Sequential( + nn.Linear(self.hidden_dim * self.heads, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, BUS_OUT_DIMENSIONS), + ) + + self.mlp_gen = nn.Sequential( + nn.Linear(self.hidden_dim * self.heads, self.hidden_dim), + nn.LayerNorm(self.hidden_dim), + nn.LeakyReLU(), + nn.Linear(self.hidden_dim, GEN_OUT_DIMENSIONS), + ) + + # mask param (kept similar to your original) + self.activation = nn.LeakyReLU() + # self.branch_flow_layer = ComputeBranchFlow() + # self.node_injection_layer = ComputeNodeInjection() + # self.node_residuals_layer = 
ComputeNodeResiduals() + # self.physics_decoder = get_physics_decoder(args) + + # container for monitoring residual norms per layer and type + # self.layer_residuals = {} + + def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict): + """ + x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]} + edge_index_dict: keys like ("bus","connects","bus"), ("gen","connected_to","bus"), ("bus","connected_to","gen") + edge_attr_dict: same keys -> edge attributes (bus-bus requires G,B) + batch_dict: dict mapping node types to batch tensors (if using batching). Not used heavily here but kept for API parity. + mask: optional mask per node (applies when computing residuals) + """ + + self.layer_residuals = {} + + # 1) initial projections + h_bus = self.input_proj_bus(x_dict["bus"]) # [num_bus, hidden_dim] + h_gen = self.input_proj_gen(x_dict["gen"]) # [num_gen, hidden_dim] + + num_bus = x_dict["bus"].size(0) + + edge_attr_proj_dict = {} + for key, edge_attr in edge_attr_dict.items(): + if edge_attr is not None: + edge_attr_proj_dict[key] = self.input_proj_edge(edge_attr) + else: + edge_attr_proj_dict[key] = None + + # bus_mask = mask_dict["bus"][:, VM_H : VA_H + 1] + # gen_mask = mask_dict["gen"][:, : (PG_H + 1)] + # bus_fixed = x_dict["bus"][:, VM_H : VA_H + 1] + # gen_fixed = x_dict["gen"][:, : (PG_H + 1)] + + # iterate layers + for i, conv in enumerate(self.layers): + out_dict = conv( + {"bus": h_bus, "gen": h_gen}, + edge_index_dict, + edge_attr_proj_dict, + ) + out_bus = out_dict["bus"] # [Nb, hidden_dim * heads] + out_gen = out_dict["gen"] # [Ng, hidden_dim * heads] + + out_bus = self.activation(self.norms_bus[i](out_bus)) + out_gen = self.activation(self.norms_gen[i](out_gen)) + + # # skip connection + h_bus = h_bus + out_bus if out_bus.shape == h_bus.shape else out_bus + h_gen = h_gen + out_gen if out_gen.shape == h_gen.shape else out_gen + + + # Decode bus and generator predictions + output_temp = self.mlp_bus(h_bus) # [num_buses, 4] 
-> [Vm, Va, Pg, Qg] + gen_temp = self.mlp_gen(h_gen) # [num_gens, 1] -> Pg + + # Decode bus and generator predictions + # bus_temp = self.mlp_bus(h_bus) # [Nb, 2] -> Vm, Va + # gen_temp = self.mlp_gen(h_gen) # [Ng, 1] -> Pg + + return {"bus": output_temp, "gen": gen_temp} \ No newline at end of file From 4d0da4cbf1272c636be4858fd05ce2c1031bbc68 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 6 Mar 2026 18:38:42 -0500 Subject: [PATCH 6/9] style Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/models/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gridfm_graphkit/models/__init__.py b/gridfm_graphkit/models/__init__.py index 91ab717..ed8a750 100644 --- a/gridfm_graphkit/models/__init__.py +++ b/gridfm_graphkit/models/__init__.py @@ -1,7 +1,6 @@ from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous from gridfm_graphkit.models.fcnn import FullyConnectedNN from gridfm_graphkit.models.gnn_heterogeneous import HeterogeneousGNN -# from gridfm_graphkit.models.gnn_homogeneous import HomogeneousGNN from gridfm_graphkit.models.utils import ( PhysicsDecoderOPF, @@ -13,7 +12,6 @@ "GNS_heterogeneous", "FullyConnectedNN", "HeterogeneousGNN", - # "HomogeneousGNN", "PhysicsDecoderOPF", "PhysicsDecoderPF", "PhysicsDecoderSE", From 8e46608b7e37099dc61771fc5ccbeb63126a9937 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Sat, 7 Mar 2026 11:40:59 -0500 Subject: [PATCH 7/9] new globals Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/datasets/globals.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gridfm_graphkit/datasets/globals.py b/gridfm_graphkit/datasets/globals.py index ab3c7e3..11bd4bf 100644 --- a/gridfm_graphkit/datasets/globals.py +++ b/gridfm_graphkit/datasets/globals.py @@ -52,3 +52,11 @@ ANG_MAX = 8 # Angle max (deg) RATE_A = 
9 # Thermal limit B_ON = 10 # Branch on/off + + + +# ==================================== +# === EXPECTED OUTPUT DIMENSIONS === +# ==================================== +BUS_OUT_DIMENSIONS = 4 +GEN_OUT_DIMENSIONS = 1 \ No newline at end of file From 5e3955c8c5f12c2dd3538d8d487f3cab997265fd Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Wed, 11 Mar 2026 15:35:39 -0400 Subject: [PATCH 8/9] update shared step Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/tasks/reconstruction_tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gridfm_graphkit/tasks/reconstruction_tasks.py b/gridfm_graphkit/tasks/reconstruction_tasks.py index 8742646..2d770b4 100644 --- a/gridfm_graphkit/tasks/reconstruction_tasks.py +++ b/gridfm_graphkit/tasks/reconstruction_tasks.py @@ -57,6 +57,8 @@ def shared_step(self, batch): batch.edge_attr_dict, batch.mask_dict, model=self.model, + x_dict=batch.x_dict, + batch_dict=batch.batch_dict, ) return output, loss_dict From fbe838a5ddfde1f176279f2c349935964af16393 Mon Sep 17 00:00:00 2001 From: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> Date: Fri, 13 Mar 2026 04:19:16 -0400 Subject: [PATCH 9/9] rename loss Signed-off-by: Naomi Simumba <7224231+naomi-simumba@users.noreply.github.com> --- gridfm_graphkit/training/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gridfm_graphkit/training/loss.py b/gridfm_graphkit/training/loss.py index 7942ba3..33b905a 100644 --- a/gridfm_graphkit/training/loss.py +++ b/gridfm_graphkit/training/loss.py @@ -555,8 +555,8 @@ def forward( loss=0.0 try: - output = {"loss": loss, "Qg Violation Penalty loss": loss.detach()} + output = {"loss": loss, "Optimality loss": loss.detach()} except: - output = {"loss": loss, "Qg Violation Penalty loss": loss} + output = {"loss": loss, "Optimality loss": loss} return output