8 changes: 8 additions & 0 deletions gridfm_graphkit/datasets/globals.py
@@ -52,3 +52,11 @@
ANG_MAX = 8 # Angle max (deg)
RATE_A = 9 # Thermal limit
B_ON = 10 # Branch on/off



# ====================================
# === EXPECTED OUTPUT DIMENSIONS ===
# ====================================
BUS_OUT_DIMENSIONS = 4  # per-bus outputs: [Vm, Va, Pg, Qg]
GEN_OUT_DIMENSIONS = 1  # per-generator output: Pg
104 changes: 68 additions & 36 deletions gridfm_graphkit/datasets/hetero_powergrid_datamodule.py
@@ -12,13 +12,15 @@
from gridfm_graphkit.datasets.utils import (
split_dataset,
split_dataset_by_load_scenario_idx,
split_from_existing_files,
)
from gridfm_graphkit.datasets.powergrid_hetero_dataset import HeteroGridDatasetDisk
import numpy as np
import random
import warnings
import os
import lightning as L
from pathlib import Path
from typing import List
from lightning.pytorch.loggers import MLFlowLogger

@@ -96,6 +98,11 @@ def __init__(
"split_by_load_scenario_idx",
False,
)
self.split_from_existing_files = getattr(
args.data,
"split_from_existing_files",
None,
)
self.args = args
self.normalizer_stats_path = normalizer_stats_path
self.data_normalizers = []
@@ -108,6 +115,15 @@ def __init__(
self.test_scenario_ids: List[List[int]] = []
self._is_setup_done = False

if self.split_by_load_scenario_idx:
    assert self.split_from_existing_files is None, (
        "`split_by_load_scenario_idx` and `split_from_existing_files` are mutually exclusive; set at most one"
    )

if self.split_from_existing_files is not None:
    assert isinstance(self.split_from_existing_files, str), (
        "`split_from_existing_files` must be the path of an existing folder, given as a string"
    )
    self.split_from_existing_files = Path(self.split_from_existing_files)
    assert self.split_from_existing_files.is_dir(), (
        f"`split_from_existing_files` is not an existing folder: {self.split_from_existing_files}"
    )

def setup(self, stage: str):
if self._is_setup_done:
print(f"Setup already done for stage={stage}, skipping...")
@@ -161,49 +177,64 @@ def setup(self, stage: str):

# Create a subset
all_indices = list(range(len(dataset)))
splits_dir = Path(data_path_network) / "splits"

if self.split_from_existing_files is not None:
    (train_dataset, val_dataset, test_dataset), subset_indices = (
        split_from_existing_files(
            dataset,
            splits_dir,
            self.split_from_existing_files,
        )
    )
    train_scenario_ids = subset_indices["train"]
    val_scenario_ids = subset_indices["val"]
    test_scenario_ids = subset_indices["test"]
else:
    # Random seed set before every shuffle for reproducibility in case the power grid datasets are analyzed in a different order
    random.seed(self.args.seed)
    random.shuffle(all_indices)
    subset_indices = all_indices[:num_scenarios]

    # load_scenario for each scenario in the subset
    load_scenarios = dataset.load_scenarios[subset_indices]

    dataset = Subset(dataset, subset_indices)

    # Random seed set before every split, same as above
    np.random.seed(self.args.seed)
    if self.split_by_load_scenario_idx:
        train_dataset, val_dataset, test_dataset = (
            split_dataset_by_load_scenario_idx(
                dataset,
                self.data_dir,
                load_scenarios,
                self.args.data.val_ratio,
                self.args.data.test_ratio,
            )
        )
    else:
        train_dataset, val_dataset, test_dataset = split_dataset(
            dataset,
            self.data_dir,
            load_scenarios,
            self.args.data.val_ratio,
            self.args.data.test_ratio,
        )

    # Extract scenario IDs for each split
    train_scenario_ids = self._extract_scenario_ids(
        train_dataset,
        subset_indices,
    )
    val_scenario_ids = self._extract_scenario_ids(
        val_dataset,
        subset_indices,
    )
    test_scenario_ids = self._extract_scenario_ids(
        test_dataset,
        subset_indices,
    )

# Fit normalizer: restore from saved stats only for fit_on_train
# normalizers (global baseMVA must match the model's training run).
@@ -214,7 +245,8 @@ def setup(self, stage: str):
and network in saved_stats
and data_normalizer.fit_strategy == "fit_on_train"
)
if use_saved:
print(f"Restoring normalizer for {network} from saved stats")
data_normalizer.fit_from_dict(saved_stats[network])
else:
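Note for reviewers: the new option is read from `args.data` via `getattr`, and it is mutually exclusive with `split_by_load_scenario_idx`. A minimal sketch of an `args` object that exercises the new path, assuming the namespace-style config used elsewhere in gridfm_graphkit (the `SimpleNamespace` stand-in and the path are illustrative, not the project's actual config loader):

    from types import SimpleNamespace

    # Hypothetical args object; the real one is built from a YAML config.
    args = SimpleNamespace(
        seed=42,
        data=SimpleNamespace(
            val_ratio=0.1,
            test_ratio=0.1,
            split_by_load_scenario_idx=False,
            # Folder holding train.pt / val.pt / test.pt index files:
            split_from_existing_files="/path/to/splits",
        ),
    )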
23 changes: 23 additions & 0 deletions gridfm_graphkit/datasets/utils.py
@@ -90,3 +90,26 @@ def split_dataset_by_load_scenario_idx(
test_dataset = Subset(dataset, test_indices)

return train_dataset, val_dataset, test_dataset


def split_from_existing_files(
    dataset,
    save_to_folder: Path,
    splits_folder: Path,
) -> Tuple[Tuple[Subset, Subset, Subset], dict]:
    # save_to_folder is currently unused; it is kept so the call site matches
    # the signatures of the other split helpers.
    output = []
    indices = {}

    for split in ["train", "val", "test"]:
        split_file = splits_folder / f"{split}.pt"
        assert split_file.is_file(), f"{split_file} does not exist"
        split_indices = torch.load(str(split_file), weights_only=True)
        output.append(Subset(dataset, split_indices))
        print(f"{split}: {len(split_indices)} scenarios")
        indices[split] = [int(t.item()) for t in split_indices]

    return tuple(output), indices
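`split_from_existing_files` expects three files, `train.pt`, `val.pt` and `test.pt`, inside the splits folder, each holding a tensor of scenario indices loadable with `weights_only=True`. A minimal sketch of how such files could be produced; the folder path and index ranges are illustrative only:

    import torch
    from pathlib import Path

    splits_folder = Path("/path/to/splits")  # illustrative path
    splits_folder.mkdir(parents=True, exist_ok=True)

    # Each tensor holds positions into the dataset being split.
    torch.save(torch.arange(0, 800), splits_folder / "train.pt")
    torch.save(torch.arange(800, 900), splits_folder / "val.pt")
    torch.save(torch.arange(900, 1000), splits_folder / "test.pt")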
5 changes: 5 additions & 0 deletions gridfm_graphkit/models/__init__.py
@@ -1,4 +1,7 @@
from gridfm_graphkit.models.gnn_heterogeneous_gns import GNS_heterogeneous
from gridfm_graphkit.models.fcnn import FullyConnectedNN
from gridfm_graphkit.models.gnn_heterogeneous import HeterogeneousGNN

from gridfm_graphkit.models.utils import (
PhysicsDecoderOPF,
PhysicsDecoderPF,
@@ -7,6 +10,8 @@

__all__ = [
"GNS_heterogeneous",
"FullyConnectedNN",
"HeterogeneousGNN",
"PhysicsDecoderOPF",
"PhysicsDecoderPF",
"PhysicsDecoderSE",
166 changes: 166 additions & 0 deletions gridfm_graphkit/models/fcnn.py
@@ -0,0 +1,166 @@
import torch
from torch import nn
from gridfm_graphkit.io.registries import MODELS_REGISTRY
from gridfm_graphkit.datasets.globals import (
    # Expected output dimensions
    BUS_OUT_DIMENSIONS,
    GEN_OUT_DIMENSIONS,
)


@MODELS_REGISTRY.register("FullyConnectedNN")
class FullyConnectedNN(nn.Module):
"""

"""

def __init__(self, args) -> None:
super().__init__()
self.num_layers = args.model.num_layers
self.hidden_dim = args.model.hidden_size
self.input_bus_dim = args.model.input_bus_dim
self.input_gen_dim = args.model.input_gen_dim
self.edge_dim = args.model.edge_dim
self.task = args.task.task_name
self.dropout = getattr(args.model, "dropout", 0.0)

# projections for each node type
self.input_proj_bus = nn.Sequential(
nn.Linear(self.input_bus_dim, self.hidden_dim),
nn.LeakyReLU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
)

self.input_proj_gen = nn.Sequential(
nn.Linear(self.input_gen_dim, self.hidden_dim),
nn.LeakyReLU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
)

self.input_proj_edge = nn.Sequential(
nn.Linear(self.edge_dim, self.hidden_dim),
nn.LeakyReLU(),
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
)

# Shared MLP blocks applied to the stacked bus + gen embeddings
self.layers = nn.ModuleList()
self.norm = nn.LayerNorm(self.hidden_dim)  # single norm shared across blocks
for _ in range(self.num_layers):
    # Each block: 10 Linear layers with LeakyReLU in between; the last
    # Linear is left un-activated (norm + activation are applied in forward).
    modules = []
    for _ in range(9):
        modules += [nn.Linear(self.hidden_dim, self.hidden_dim), nn.LeakyReLU()]
    modules.append(nn.Linear(self.hidden_dim, self.hidden_dim))
    self.layers.append(nn.Sequential(*modules))

# Separate MLP heads to produce the final bus/gen outputs (predictions y)
self.mlp_bus = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
nn.LeakyReLU(),
nn.Linear(self.hidden_dim, BUS_OUT_DIMENSIONS),
)

self.mlp_gen = nn.Sequential(
nn.Linear(self.hidden_dim, self.hidden_dim),
nn.LayerNorm(self.hidden_dim),
nn.LeakyReLU(),
nn.Linear(self.hidden_dim, GEN_OUT_DIMENSIONS),
)

self.activation = nn.LeakyReLU()


def forward(self, x_dict, edge_index_dict, edge_attr_dict, mask_dict):
    """
    x_dict: {"bus": Tensor[num_bus, bus_feat], "gen": Tensor[num_gen, gen_feat]}
    edge_index_dict / edge_attr_dict: graph structure and edge attributes;
        accepted for API parity with the GNN models but unused by this
        fully connected baseline.
    mask_dict: optional per-node masks; also unused, kept for API parity.
    """

# 1) initial projections
h_bus = self.input_proj_bus(x_dict["bus"]) # [num_bus, hidden_dim]
h_gen = self.input_proj_gen(x_dict["gen"]) # [num_gen, hidden_dim]

# concatenate data for forward propagation
combined_data = torch.cat((h_bus, h_gen), dim=0)

num_bus = x_dict["bus"].size(0)

# iterate over the shared blocks
for layer in self.layers:
    layer_output = layer(combined_data)        # [Nb+Ng, hidden_dim]
    layer_output = self.norm(layer_output)     # [Nb+Ng, hidden_dim]
    layer_output = self.activation(layer_output)

    # skip connection
    combined_data = combined_data + layer_output

# split back into bus and generator rows
out_bus = combined_data[:num_bus]
out_gen = combined_data[num_bus:]


# Decode bus and generator predictions
output_temp = self.mlp_bus(out_bus)  # [num_bus, BUS_OUT_DIMENSIONS] -> [Vm, Va, Pg, Qg]
gen_temp = self.mlp_gen(out_gen)     # [num_gen, GEN_OUT_DIMENSIONS] -> Pg

return {"bus": output_temp, "gen": gen_temp}