diff --git a/feature_reconstruction_task.py b/feature_reconstruction_task.py new file mode 100644 index 0000000..5975cfe --- /dev/null +++ b/feature_reconstruction_task.py @@ -0,0 +1,359 @@ +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau +import lightning as L +from pytorch_lightning.utilities import rank_zero_only +import numpy as np +import os +import pandas as pd + +from lightning.pytorch.loggers import MLFlowLogger +from gridfm_graphkit.io.param_handler import load_model, get_loss_function +import torch.nn.functional as F +from gridfm_graphkit.datasets.globals import PQ, PV, REF, PD, QD, PG, QG, VM, VA + + +class FeatureReconstructionTask(L.LightningModule): + """ + PyTorch Lightning task for node feature reconstruction on power grid graphs. + + This task wraps a GridFM model inside a LightningModule and defines the full + training, validation, testing, and prediction logic. It is designed to + reconstruct masked node features from graph-structured input data, using + datasets and normalizers provided by `gridfm-graphkit`. + + Args: + args (NestedNamespace): Experiment configuration. Expected fields include `training.batch_size`, `optimizer.*`, etc. + node_normalizers (list): One normalizer per dataset to (de)normalize node features. + edge_normalizers (list): One normalizer per dataset to (de)normalize edge features. + + Attributes: + model (torch.nn.Module): model loaded via `load_model`. + loss_fn (callable): Loss function resolved from configuration. + batch_size (int): Training batch size. From ``args.training.batch_size`` + node_normalizers (list): Dataset-wise node feature normalizers. + edge_normalizers (list): Dataset-wise edge feature normalizers. + + Methods: + forward(x, pe, edge_index, edge_attr, batch, mask=None): + Forward pass with optional feature masking. + training_step(batch): + One training step: computes loss, logs metrics, returns loss. + validation_step(batch, batch_idx): + One validation step: computes losses and logs metrics. + test_step(batch, batch_idx, dataloader_idx=0): + Evaluate on test data, compute per-node-type MSEs, and log per-dataset metrics. + predict_step(batch, batch_idx, dataloader_idx=0): + Run inference and return denormalized outputs + node masks. + configure_optimizers(): + Setup Adam optimizer and ReduceLROnPlateau scheduler. + on_fit_start(): + Save normalization statistics at the beginning of training. + on_test_end(): + Collect test metrics across datasets and export summary CSV reports. + + Notes: + - Node types are distinguished using the global constants (`PQ`, `PV`, `REF`). + - The datamodule must provide `batch.mask` for masking node features. + - Test metrics include per-node-type RMSE for [Pd, Qd, Pg, Qg, Vm, Va]. + - Reports are saved under `/test/.csv`. + + Example: + ```python + model = FeatureReconstructionTask(args, node_normalizers, edge_normalizers) + output = model(batch.x, batch.pe, batch.edge_index, batch.edge_attr, batch.batch) + ``` + """ + + def __init__(self, args, node_normalizers, edge_normalizers): + super().__init__() + self.model = load_model(args=args) + self.args = args + self.loss_fn = get_loss_function(args) + self.batch_size = int(args.training.batch_size) + self.node_normalizers = node_normalizers + self.edge_normalizers = edge_normalizers + self.save_hyperparameters() + + def forward(self, x, pe, edge_index, edge_attr, batch, mask=None): + if mask is not None: + mask_value_expanded = self.model.mask_value.expand(x.shape[0], -1) + x[:, : mask.shape[1]][mask] = mask_value_expanded[mask] + return self.model(x, pe, edge_index, edge_attr, batch) + + @rank_zero_only + def on_fit_start(self): + # Determine save path + if isinstance(self.logger, MLFlowLogger): + log_dir = os.path.join( + self.logger.save_dir, + self.logger.experiment_id, + self.logger.run_id, + "artifacts", + "stats", + ) + else: + log_dir = os.path.join(self.logger.save_dir, "stats") + + os.makedirs(log_dir, exist_ok=True) + log_stats_path = os.path.join(log_dir, "normalization_stats.txt") + + # Collect normalization stats + with open(log_stats_path, "w") as log_file: + for i, normalizer in enumerate(self.node_normalizers): + log_file.write( + f"Node Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", + ) + + for i, normalizer in enumerate(self.edge_normalizers): + log_file.write( + f"Edge Normalizer {self.args.data.networks[i]} stats:\n{normalizer.get_stats()}\n\n", + ) + + def shared_step(self, batch): + output = self.forward( + x=batch.x, + pe=batch.pe, + edge_index=batch.edge_index, + edge_attr=batch.edge_attr, + batch=batch.batch, + mask=batch.mask, + ) + + loss_dict = self.loss_fn( + output, + batch.y, + batch.edge_index, + batch.edge_attr, + batch.mask, + x=batch.x, + ) + return output, loss_dict + + def training_step(self, batch): + _, loss_dict = self.shared_step(batch) + current_lr = self.optimizer.param_groups[0]["lr"] + metrics = {} + metrics["Training Loss"] = loss_dict["loss"].detach() + metrics["Learning Rate"] = current_lr + for metric, value in metrics.items(): + self.log( + metric, + value, + batch_size=batch.num_graphs, + sync_dist=True, + on_epoch=True, + prog_bar=True, + logger=True, + on_step=False, + ) + + return loss_dict["loss"] + + def validation_step(self, batch, batch_idx): + _, loss_dict = self.shared_step(batch) + loss_dict["loss"] = loss_dict["loss"].detach() + for metric, value in loss_dict.items(): + metric_name = f"Validation {metric}" + self.log( + metric_name, + value, + batch_size=batch.num_graphs, + sync_dist=True, + on_epoch=True, + prog_bar=True, + logger=True, + on_step=False, + ) + + return loss_dict["loss"] + + def test_step(self, batch, batch_idx, dataloader_idx=0): + output, loss_dict = self.shared_step(batch) + + dataset_name = self.args.data.networks[dataloader_idx] + + output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) + target_denorm = self.node_normalizers[dataloader_idx].inverse_transform(batch.y) + + mask_PQ = batch.x[:, PQ] == 1 + mask_PV = batch.x[:, PV] == 1 + mask_REF = batch.x[:, REF] == 1 + + mse_PQ = F.mse_loss( + output_denorm[mask_PQ], + target_denorm[mask_PQ], + reduction="none", + ) + mse_PV = F.mse_loss( + output_denorm[mask_PV], + target_denorm[mask_PV], + reduction="none", + ) + mse_REF = F.mse_loss( + output_denorm[mask_REF], + target_denorm[mask_REF], + reduction="none", + ) + + mse_PQ = mse_PQ.mean(dim=0) + mse_PV = mse_PV.mean(dim=0) + mse_REF = mse_REF.mean(dim=0) + + loss_dict["MSE PQ nodes - PD"] = mse_PQ[PD] + loss_dict["MSE PV nodes - PD"] = mse_PV[PD] + loss_dict["MSE REF nodes - PD"] = mse_REF[PD] + + loss_dict["MSE PQ nodes - QD"] = mse_PQ[QD] + loss_dict["MSE PV nodes - QD"] = mse_PV[QD] + loss_dict["MSE REF nodes - QD"] = mse_REF[QD] + + loss_dict["MSE PQ nodes - PG"] = mse_PQ[PG] + loss_dict["MSE PV nodes - PG"] = mse_PV[PG] + loss_dict["MSE REF nodes - PG"] = mse_REF[PG] + + loss_dict["MSE PQ nodes - QG"] = mse_PQ[QG] + loss_dict["MSE PV nodes - QG"] = mse_PV[QG] + loss_dict["MSE REF nodes - QG"] = mse_REF[QG] + + loss_dict["MSE PQ nodes - VM"] = mse_PQ[VM] + loss_dict["MSE PV nodes - VM"] = mse_PV[VM] + loss_dict["MSE REF nodes - VM"] = mse_REF[VM] + + loss_dict["MSE PQ nodes - VA"] = mse_PQ[VA] + loss_dict["MSE PV nodes - VA"] = mse_PV[VA] + loss_dict["MSE REF nodes - VA"] = mse_REF[VA] + + loss_dict["Test loss"] = loss_dict.pop("loss").detach() + for metric, value in loss_dict.items(): + metric_name = f"{dataset_name}/{metric}" + if "p.u." in metric: + # Denormalize metrics expressed in p.u. + value *= self.node_normalizers[dataloader_idx].baseMVA + metric_name = metric_name.replace("in p.u.", "").strip() + self.log( + metric_name, + value, + batch_size=batch.num_graphs, + add_dataloader_idx=False, + sync_dist=True, + logger=False, + ) + return + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + output, _ = self.shared_step(batch) + output_denorm = self.node_normalizers[dataloader_idx].inverse_transform(output) + + # Count buses and generate per-node scenario_id + bus_counts = batch.batch.unique(return_counts=True)[1] + scenario_ids = batch.scenario_id # shape: [num_graphs] + scenario_per_node = torch.cat( + [ + torch.full((count,), sid, dtype=torch.int32) + for count, sid in zip(bus_counts, scenario_ids) + ], + ) + + bus_numbers = np.concatenate([np.arange(count.item()) for count in bus_counts]) + + return { + "output": output_denorm.cpu().numpy(), + "scenario_id": scenario_per_node, + "bus_number": bus_numbers, + } + + @rank_zero_only + def on_test_end(self): + if isinstance(self.logger, MLFlowLogger): + artifact_dir = os.path.join( + self.logger.save_dir, + self.logger.experiment_id, + self.logger.run_id, + "artifacts", + ) + else: + artifact_dir = self.logger.save_dir + + final_metrics = self.trainer.callback_metrics + grouped_metrics = {} + + for full_key, value in final_metrics.items(): + try: + value = value.item() + except AttributeError: + pass + + if "/" in full_key: + dataset_name, metric = full_key.split("/", 1) + if dataset_name not in grouped_metrics: + grouped_metrics[dataset_name] = {} + grouped_metrics[dataset_name][metric] = value + + for dataset, metrics in grouped_metrics.items(): + rmse_PQ = [ + metrics.get(f"MSE PQ nodes - {label}", float("nan")) ** 0.5 + for label in ["PD", "QD", "PG", "QG", "VM", "VA"] + ] + rmse_PV = [ + metrics.get(f"MSE PV nodes - {label}", float("nan")) ** 0.5 + for label in ["PD", "QD", "PG", "QG", "VM", "VA"] + ] + rmse_REF = [ + metrics.get(f"MSE REF nodes - {label}", float("nan")) ** 0.5 + for label in ["PD", "QD", "PG", "QG", "VM", "VA"] + ] + + avg_active_res = metrics.get("Active Power Loss", " ") + avg_reactive_res = metrics.get("Reactive Power Loss", " ") + + data = { + "Metric": [ + "RMSE-PQ", + "RMSE-PV", + "RMSE-REF", + "Avg. active res. (MW)", + "Avg. reactive res. (MVar)", + ], + "Pd (MW)": [ + rmse_PQ[0], + rmse_PV[0], + rmse_REF[0], + avg_active_res, + avg_reactive_res, + ], + "Qd (MVar)": [rmse_PQ[1], rmse_PV[1], rmse_REF[1], " ", " "], + "Pg (MW)": [rmse_PQ[2], rmse_PV[2], rmse_REF[2], " ", " "], + "Qg (MVar)": [rmse_PQ[3], rmse_PV[3], rmse_REF[3], " ", " "], + "Vm (p.u.)": [rmse_PQ[4], rmse_PV[4], rmse_REF[4], " ", " "], + "Va (degree)": [rmse_PQ[5], rmse_PV[5], rmse_REF[5], " ", " "], + } + + df = pd.DataFrame(data) + + test_dir = os.path.join(artifact_dir, "test") + os.makedirs(test_dir, exist_ok=True) + csv_path = os.path.join(test_dir, f"{dataset}.csv") + df.to_csv(csv_path, index=False) + + def configure_optimizers(self): + self.optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.args.optimizer.learning_rate, + betas=(self.args.optimizer.beta1, self.args.optimizer.beta2), + ) + + self.scheduler = ReduceLROnPlateau( + self.optimizer, + mode="min", + factor=self.args.optimizer.lr_decay, + patience=self.args.optimizer.lr_patience, + ) + config_optim = { + "optimizer": self.optimizer, + "lr_scheduler": { + "scheduler": self.scheduler, + "monitor": "Validation loss", + "reduce_on_plateau": True, + }, + } + return config_optim diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..ac3d098 --- /dev/null +++ b/loss.py @@ -0,0 +1,444 @@ +from gridfm_graphkit.datasets.globals import PD, QD, PG, QG, VM, VA, G, B, REF + +import torch.nn.functional as F +import torch +from torch_geometric.utils import to_torch_coo_tensor +import torch.nn as nn + +from torch_geometric.utils import to_dense_adj +from collections import deque + + +class MaskedMSELoss(nn.Module): + """ + Mean Squared Error loss computed only on masked elements. + """ + + def __init__(self, reduction="mean"): + super(MaskedMSELoss, self).__init__() + self.reduction = reduction + + def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None, x=None): + loss = F.mse_loss(pred[mask], target[mask], reduction=self.reduction) + return {"loss": loss, "Masked MSE loss": loss.detach()} + + +class MSELoss(nn.Module): + """Standard Mean Squared Error loss.""" + + def __init__(self, reduction="mean"): + super(MSELoss, self).__init__() + self.reduction = reduction + + def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None, x=None): + loss = F.mse_loss(pred, target, reduction=self.reduction) + return {"loss": loss, "MSE loss": loss.detach()} + + +class SCELoss(nn.Module): + """Scaled Cosine Error Loss with optional masking and normalization.""" + + def __init__(self, alpha=3): + super(SCELoss, self).__init__() + self.alpha = alpha + + def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None, x=None): + if mask is not None: + pred = F.normalize(pred[mask], p=2, dim=-1) + target = F.normalize(target[mask], p=2, dim=-1) + else: + pred = F.normalize(pred, p=2, dim=-1) + target = F.normalize(target, p=2, dim=-1) + + loss = ((1 - (pred * target).sum(dim=-1)).pow(self.alpha)).mean() + + return { + "loss": loss, + "SCE loss": loss.detach(), + } + + +class PBELoss(nn.Module): + """ + Loss based on the Power Balance Equations. + """ + + def __init__(self, visualization=False): + super(PBELoss, self).__init__() + + self.visualization = visualization + + def forward(self, pred, target, edge_index, edge_attr, mask, x=None): + # Create a temporary copy of pred to avoid modifying it + temp_pred = pred.clone() + + # If a value is not masked, then use the original one + unmasked = ~mask + temp_pred[unmasked] = target[unmasked] + + # Voltage magnitudes and angles + V_m = temp_pred[:, VM] # Voltage magnitudes + V_a = temp_pred[:, VA] # Voltage angles + + # Compute the complex voltage vector V + V = V_m * torch.exp(1j * V_a) + + # Compute the conjugate of V + V_conj = torch.conj(V) + + # Extract edge attributes for Y_bus + edge_complex = edge_attr[:, G] + 1j * edge_attr[:, B] + + # Construct sparse admittance matrix (real and imaginary parts separately) + Y_bus_sparse = to_torch_coo_tensor( + edge_index, + edge_complex, + size=(target.size(0), target.size(0)), + ) + + # Conjugate of the admittance matrix + Y_bus_conj = torch.conj(Y_bus_sparse) + + # Compute the complex power injection S_injection + S_injection = torch.diag(V) @ Y_bus_conj @ V_conj + + # Compute net power balance + net_P = temp_pred[:, PG] - temp_pred[:, PD] + net_Q = temp_pred[:, QG] - temp_pred[:, QD] + S_net_power_balance = net_P + 1j * net_Q + + # Power balance loss + loss = torch.mean( + torch.abs(S_net_power_balance - S_injection), + ) # Mean of absolute complex power value + + real_loss_power = torch.mean( + torch.abs(torch.real(S_net_power_balance - S_injection)), + ) + imag_loss_power = torch.mean( + torch.abs(torch.imag(S_net_power_balance - S_injection)), + ) + if self.visualization: + return { + "loss": loss, + "Power loss in p.u.": loss.detach(), + "Active Power Loss in p.u.": real_loss_power.detach(), + "Reactive Power Loss in p.u.": imag_loss_power.detach(), + "Nodal Active Power Loss in p.u.": torch.abs( + torch.real(S_net_power_balance - S_injection), + ), + "Nodal Reactive Power Loss in p.u.": torch.abs( + torch.imag(S_net_power_balance - S_injection), + ), + } + else: + return { + "loss": loss, + "Power loss in p.u.": loss.detach(), + "Active Power Loss in p.u.": real_loss_power.detach(), + "Reactive Power Loss in p.u.": imag_loss_power.detach(), + } + + +class MixedLoss(nn.Module): + """ + Combines multiple loss functions with weighted sum. + + Args: + loss_functions (list[nn.Module]): List of loss functions. + weights (list[float]): Corresponding weights for each loss function. + """ + + def __init__(self, loss_functions, weights): + super(MixedLoss, self).__init__() + + if len(loss_functions) != len(weights): + raise ValueError( + "The number of loss functions must match the number of weights.", + ) + + self.loss_functions = nn.ModuleList(loss_functions) + self.weights = weights + + def forward(self, pred, target, edge_index=None, edge_attr=None, mask=None, x=None): + """ + Compute the weighted sum of all specified losses. + + Parameters: + + - pred: Predictions. + - target: Ground truth. + - edge_index: Optional edge index for graph-based losses. + - edge_attr: Optional edge attributes for graph-based losses. + - mask: Optional mask to filter the inputs for certain losses. + + Returns: + - A dictionary with the total loss and individual losses. + """ + total_loss = 0.0 + loss_details = {} + + for i, loss_fn in enumerate(self.loss_functions): + loss_output = loss_fn( + pred, + target, + edge_index=edge_index, + edge_attr=edge_attr, + mask=mask, + x=x, + ) + + # Assume each loss function returns a dictionary with a "loss" key + individual_loss = loss_output.pop("loss") + weighted_loss = self.weights[i] * individual_loss + + total_loss += weighted_loss + + # Add other keys from the loss output to the details + for key, val in loss_output.items(): + loss_details[key] = val + + loss_details["loss"] = total_loss + return loss_details + + +class VLDLoss(nn.Module): + """ + Global connectivity / energization constraint to the REF bus, + considering both failed nodes (low Vm) and failed edges (|G|,|B| below threshold). + + Args: + voltage_threshold: Below this, node is considered failed. + edge_GB_threshold: Edge is considered failed if both |G| and |B| are below this threshold (in normalized units). + beta: Sharpness of sigmoid in connectivity update. + penalty_scale: Global scale for this loss. + margin_factor: margin = margin_factor * voltage_threshold. + safety_margin: Extra hops added to max BFS distance for K. + undirected: If True, treat connectivity as undirected. + max_K_cap: Upper cap on K to avoid extreme propagation depth. + + Training procedure + - Builds a "healthy" connectivity graph using only edges with + sufficiently large |G| and |B|. + - Uses BFS from REF on this healthy-edge graph to choose K per graph. + - Propagates a soft connectivity score h from REF through: + * healthy edges, and + * healthy nodes. + - Penalizes: + * reachable nodes: high V but h ~ 0 (disconnected but energized) + * unreachable nodes: any non-zero V (must be ~0 if islanded). + """ + + def __init__(self, visualization=False): + super(VLDLoss, self).__init__() + + self.visualization = visualization + + self.voltage_threshold = 1e-3 + self.edge_GB_threshold = 1e-6 + self.beta = 8.0 + self.penalty_scale = 1.0 + self.margin_factor = 0.5 + self.safety_margin = 1 + self.undirected = True + self.max_K_cap = 64 + self.INF = 10**9 # sentinel for unreachable + + + """ + self.voltage_threshold = args.voltage_loss_detector.voltage_threshold + self.edge_GB_threshold = args.voltage_loss_detector.edge_GB_threshold + self.beta = args.voltage_loss_detector.beta + self.penalty_scale = args.voltage_loss_detector.penalty_scale + self.margin_factor = args.voltage_loss_detector.margin_factor + self.safety_margin = args.voltage_loss_detector.safety_margin + self.undirected = args.voltage_loss_detector.undirected + self.max_K_cap = args.voltage_loss_detector.max_K_caps + """ + + # ---------- build healthy-edge adjacency / edge_index ---------- + + def healthy_edge_mask(self, edge_attr): + """ + Decide which edges are "healthy" based on G,B magnitude. + + Args: + edge_attr: (E, F_edge), with G,B at indices G,B. + + Returns: + mask: (E,) bool tensor, True for healthy edges. + """ + G_vals = edge_attr[:, G] + B_vals = edge_attr[:, B] + # Edge is failed if both |G| and |B| are below threshold + healthy = (G_vals.abs() >= self.edge_GB_threshold) | ( + B_vals.abs() >= self.edge_GB_threshold + ) + return healthy + + def pruned_edge_index(self, edge_index, edge_attr): + """ + Keep only healthy edges for connectivity graph. + """ + healthy = self.healthy_edge_mask(edge_attr) # (E,) + return edge_index[:, healthy] + + # ---------- BFS utilities ---------- + + def bfs_distance_from_ref(self, edge_index, num_nodes, ref_idx: int): + """ + Unweighted BFS distances from REF node on the connectivity graph. + + Returns: + dist: (N,) tensor with hop distances (INF for unreachable nodes). + """ + row, col = edge_index + adj = [[] for _ in range(num_nodes)] + # undirected BFS on connectivity + for u, v in zip(row.tolist(), col.tolist()): + adj[u].append(v) + adj[v].append(u) + + dist = [self.INF] * num_nodes + dist[ref_idx] = 0 + q = deque([ref_idx]) + + while q: + u = q.popleft() + for v in adj[u]: + if dist[v] == self.INF: + dist[v] = dist[u] + 1 + q.append(v) + + return torch.tensor(dist, dtype=torch.long, device=edge_index.device) + + def choose_K_for_graph(self, edge_index, num_nodes, ref_idx): + """ + Per-graph K from BFS on healthy-edge graph. + """ + if edge_index.numel() == 0: + # no healthy edges: only REF is trivially connected + dist = torch.full((num_nodes,), self.INF, dtype=torch.long, device=edge_index.device) + dist[ref_idx] = 0 + reachable = dist < self.INF + return 1, dist, reachable + + dist = self.bfs_distance_from_ref(edge_index, num_nodes, ref_idx) + reachable = dist < self.INF + + if reachable.sum() <= 1: + return 1, dist, reachable + + max_dist = dist[reachable].max().item() + K = max_dist + self.safety_margin + K = max(1, min(K, self.max_K_cap)) + return K, dist, reachable + + # ---------- adjacency for propagation (healthy edges only) ---------- + + def build_normalized_adj(self, edge_index, num_nodes): + A = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0].float() # (N, N) + + if self.undirected: + A = ((A + A.t()) > 0).float() + + deg = A.sum(dim=1, keepdim=True) # (N, 1) + deg = torch.clamp(deg, min=1.0) + A_norm = A / deg + return A_norm + + # ---------- main forward ---------- + + def forward(self, pred, target, edge_index, edge_attr, mask, x): + """ + Args: + pred: (N, F) predicted node features. + target: (N, F) ground truth node features. + edge_index: (2, E) full graph edges (can include failed edges). + edge_attr: (E, F_edge) edge features with G,B. + mask: (N, M) boolean mask used to form hybrid predictions. + + Returns: + dict with: + - "loss": scalar Voltage Detector loss + - "Voltage Detector Loss": detached scalar + """ + device = pred.device + + # 1) Hybrid prediction: target on unmasked, pred on masked + temp_pred = pred.clone() + unmasked = ~mask + temp_pred[unmasked] = target[unmasked] + + N = temp_pred.size(0) + + # 2) Voltage and REF indicator + Vm = temp_pred[:, VM] # (N,) + ref_indicator = x[:, REF] # (N,) + ref_idx = torch.argmax(ref_indicator).item() + + # 3) Build connectivity graph using only healthy edges + edge_index_healthy = self.pruned_edge_index(edge_index, edge_attr) + + # 4) Per-graph K from BFS on healthy-edge graph + K, dist, reachable = self.choose_K_for_graph(edge_index_healthy, N, ref_idx) + unreachable = ~reachable + + # 5) Node health from voltage + failed = (Vm < self.voltage_threshold).float() # (N,) + healthy_nodes = 1.0 - failed # (N,) + + # 6) Degree-normalized adjacency on healthy-edge graph + A_norm = self.build_normalized_adj(edge_index_healthy, N).to(device) # (N, N) + + # 7) Initial connectivity: only REF is connected + h = torch.zeros_like(healthy_nodes) + h[ref_idx] = 1.0 + h = h.unsqueeze(-1) # (N, 1) + + # 8) Propagate connectivity K steps through healthy nodes & edges + for _ in range(K): + neigh = torch.matmul(A_norm.t(), h) # (N, 1) + # soft AND with node health + h = torch.sigmoid(self.beta * (neigh * healthy_nodes.unsqueeze(-1))) + h[ref_idx] = 1.0 + h = h.squeeze(-1) # (N,) + + # 9) Hinge margin on voltage + margin = self.margin_factor * self.voltage_threshold + excess = torch.clamp(Vm - margin, min=0.0) + + # 10) Reachable nodes: disconnected-but-energized violations + reachable_violation_mask = reachable & (h < 0.5) & (excess > 0.0) + violation_reach = excess[reachable_violation_mask] ** 2 + loss_reach = ( + violation_reach.mean() + if violation_reach.numel() > 0 + else torch.tensor(0.0, device=device) + ) + + # 11) Unreachable (islanded) nodes: any non-zero voltage is a violation + excess_unreach = excess[unreachable] + loss_unreach = ( + (excess_unreach**2).mean() + if excess_unreach.numel() > 0 + else torch.tensor(0.0, device=device) + ) + + loss = self.penalty_scale * (loss_reach + loss_unreach) + + if self.visualization: + return { + "loss": loss, + "Voltage Detector Loss": loss.detach(), + "Loss of reachable nodes.": loss_reach.detach(), + "Loss of unreachable nodes.": loss_unreach.detach() + } + else: + return { + "loss": loss, + "Voltage Detector Loss": loss.detach(), + "Loss of reachable nodes.": loss_reach.detach(), + "Loss of unreachable nodes.": loss_unreach.detach(), + "K_used": torch.tensor(K, device=device) + } diff --git a/param_handler.py b/param_handler.py new file mode 100644 index 0000000..d615d33 --- /dev/null +++ b/param_handler.py @@ -0,0 +1,141 @@ +from gridfm_graphkit.training.loss import ( + PBELoss, + MaskedMSELoss, + SCELoss, + MixedLoss, + MSELoss, + VLDLoss, +) +from gridfm_graphkit.io.registries import ( + MASKING_REGISTRY, + NORMALIZERS_REGISTRY, + MODELS_REGISTRY, +) + +import argparse + + +class NestedNamespace(argparse.Namespace): + """ + A namespace object that supports nested structures, allowing for + easy access and manipulation of hierarchical configurations. + + """ + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + if isinstance(value, dict): + # Recursively convert dictionaries to NestedNamespace + setattr(self, key, NestedNamespace(**value)) + else: + setattr(self, key, value) + + def to_dict(self): + # Recursively convert NestedNamespace back to dictionary + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, NestedNamespace): + result[key] = value.to_dict() + else: + result[key] = value + return result + + def flatten(self, parent_key="", sep="."): + # Flatten the dictionary with dot-separated keys + items = [] + for key, value in self.__dict__.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, NestedNamespace): + items.extend(value.flatten(new_key, sep=sep).items()) + else: + items.append((new_key, value)) + return dict(items) + + +def load_normalizer(args): + """ + Load the appropriate data normalization methods + + Args: + args (NestedNamespace): contains configs. + + Returns: + tuple: Node and edge normalizers + + Raises: + ValueError: If an unknown normalization method is specified. + """ + method = args.data.normalization + + try: + return NORMALIZERS_REGISTRY.create( + method, + True, + args, + ), NORMALIZERS_REGISTRY.create(method, False, args) + except KeyError: + raise ValueError(f"Unknown transformation: {method}") + + +def get_loss_function(args): + """ + Load the appropriate loss function + + Args: + args (NestedNamespace): contains configs. + + Returns: + nn.Module: Loss function + + Raises: + ValueError: If an unknown loss function is specified. + """ + loss_functions = [] + for loss_name in args.training.losses: + if loss_name == "MSE": + loss_functions.append(MSELoss()) + elif loss_name == "MaskedMSE": + loss_functions.append(MaskedMSELoss()) + elif loss_name == "SCE": + loss_functions.append(SCELoss()) + elif loss_name == "PBE": + loss_functions.append(PBELoss()) + elif loss_name == "VLDLoss": + loss_functions.append(VLDLoss()) + else: + raise ValueError(f"Unknown loss function: {loss_name}") + + return MixedLoss(loss_functions=loss_functions, weights=args.training.loss_weights) + + +def load_model(args): + """ + Load the appropriate model + + Args: + args (NestedNamespace): contains configs. + + Returns: + nn.Module: The selected model initialized with the provided configurations. + + Raises: + ValueError: If an unknown model type is specified. + """ + model_type = args.model.type + + try: + return MODELS_REGISTRY.create(model_type, args) + except KeyError: + raise ValueError(f"Unknown model type: {model_type}") + + +def get_transform(args): + """ + Load the appropriate dataset transform from the registry. + """ + mask_type = args.data.mask_type + + try: + return MASKING_REGISTRY.create(mask_type, args) + except KeyError: + raise ValueError(f"Unknown transformation: {mask_type}") diff --git a/powergrid_dataset.py b/powergrid_dataset.py new file mode 100644 index 0000000..e11db41 --- /dev/null +++ b/powergrid_dataset.py @@ -0,0 +1,406 @@ +from gridfm_graphkit.datasets.normalizers import Normalizer, BaseMVANormalizer +from gridfm_graphkit.datasets.transforms import ( + AddEdgeWeights, + AddNormalizedRandomWalkPE, +) + +import os.path as osp +import os +import torch +from torch_geometric.data import Data, Dataset, InMemoryDataset +import pandas as pd +from tqdm import tqdm +from typing import Optional, Callable +import glob +import re +import numpy as np + + +class GridDatasetDisk(Dataset): + """ + A PyTorch Geometric `Dataset` for power grid data stored on disk. + This dataset reads node and edge CSV files, applies normalization, + and saves each graph separately on disk as a processed file. + Data is loaded from disk lazily on demand. + + Args: + root (str): Root directory where the dataset is stored. + norm_method (str): Identifier for normalization method (e.g., "minmax", "standard"). + node_normalizer (Normalizer): Normalizer used for node features. + edge_normalizer (Normalizer): Normalizer used for edge features. + pe_dim (int): Length of the random walk used for positional encoding. + mask_dim (int, optional): Number of features per-node that could be masked. + transform (callable, optional): Transformation applied at runtime. + pre_transform (callable, optional): Transformation applied before saving to disk. + pre_filter (callable, optional): Filter to determine which graphs to keep. + """ + + def __init__( + self, + root: str, + norm_method: str, + node_normalizer: Normalizer, + edge_normalizer: Normalizer, + pe_dim: int, + mask_dim: int = 6, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + self.norm_method = norm_method + self.node_normalizer = node_normalizer + self.edge_normalizer = edge_normalizer + self.pe_dim = pe_dim + self.mask_dim = mask_dim + self.length = None + self.files = None + + super().__init__(root, transform, pre_transform, pre_filter) + + # Load normalization stats if available + node_stats_path = osp.join( + self.processed_dir, + f"node_stats_{self.norm_method}.pt", + ) + edge_stats_path = osp.join( + self.processed_dir, + f"edge_stats_{self.norm_method}.pt", + ) + if osp.exists(node_stats_path) and osp.exists(edge_stats_path): + self.node_stats = torch.load(node_stats_path, weights_only=False) + self.edge_stats = torch.load(edge_stats_path, weights_only=False) + self.node_normalizer.fit_from_dict(self.node_stats) + self.edge_normalizer.fit_from_dict(self.edge_stats) + + def scan_batch_files(self) -> tuple[list[str], list[str]]: + """ + Scan directory for batch CSV files + + Returns: + tuple: (bus_files, branch_files) sorted lists of file paths + """ + # Pattern to match batch files + raw_dir = "./data/scenario_33meshed/raw" + bus_pattern = osp.join(osp.abspath(raw_dir), "bus_batch_*.csv") + branch_pattern = osp.join(osp.abspath(raw_dir), "branch_batch_*.csv") + + # Find all matching files + bus_files = glob.glob(bus_pattern) + branch_files = glob.glob(branch_pattern) + + # Sort files numerically by batch number + def sort_files_numerically(file_list): + def extract_batch_number(filename): + match = re.search(r'batch_(\d+)', filename) + return int(match.group(1)) if match else 0 + + return sorted(file_list, key=extract_batch_number) + + bus_files_sorted = sort_files_numerically(bus_files) + branch_files_sorted = sort_files_numerically(branch_files) + + print(f"Found {len(bus_files_sorted)} bus batch files") + print(f"Found {len(branch_files_sorted)} branch batch files") + + return bus_files_sorted, branch_files_sorted + + def load_all_batch_data(self) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Load all bus and branch batch files into DataFrames + Replace empty cells with zeros + """ + bus_files, branch_files = self.scan_batch_files() + + if not bus_files or not branch_files: + # List what files were actually found + all_files = os.listdir(self.raw_dir) + csv_files = [f for f in all_files if f.endswith('.csv')] + + error_msg = ( + f"No batch files found in {self.raw_dir}.\n" + f"Expected files like: bus_batch_0001.csv, branch_batch_0001.csv\n" + f"Files found in directory: {all_files}\n" + f"CSV files found: {csv_files}" + ) + raise FileNotFoundError(error_msg) + + # Load bus batches + bus_dfs = [] + print("Loading bus batch files...") + for i, bus_file in enumerate(tqdm(bus_files, desc="Bus batches")): + try: + # Read CSV and replace empty cells with 0 + df = pd.read_csv(bus_file) + + # Replace empty strings, NaN, and None with 0 for all numeric columns + df = self._fill_empty_cells_with_zero(df) + + bus_dfs.append(df) + + except Exception as e: + print(f"Error loading {bus_file}: {e}") + continue + + if not bus_dfs: + raise ValueError("No bus data loaded from batch files") + + combined_bus_df = pd.concat(bus_dfs, ignore_index=True) + print(f"✓ Combined {len(bus_dfs)} bus batch files: {len(combined_bus_df):,} rows") + + # Load branch batches + branch_dfs = [] + for i, branch_file in enumerate(tqdm(branch_files, desc="Branch batches")): + try: + # Read CSV and replace empty cells with 0 + df = pd.read_csv(branch_file) + + # Replace empty strings, NaN, and None with 0 for all numeric columns + df = self._fill_empty_cells_with_zero(df) + + branch_dfs.append(df) + + except Exception as e: + print(f"Error loading {branch_file}: {e}") + continue + + if not branch_dfs: + raise ValueError("No branch data loaded from batch files") + + combined_branch_df = pd.concat(branch_dfs, ignore_index=True) + print(f"✓ Combined {len(branch_dfs)} branch batch files: {len(combined_branch_df):,} rows") + + return combined_bus_df, combined_branch_df + + def _fill_empty_cells_with_zero(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Replace empty cells with zeros in a DataFrame + + Parameters: + ----------- + df : pd.DataFrame + Input DataFrame + + Returns: + -------- + pd.DataFrame + DataFrame with empty cells replaced by 0 + """ + # Make a copy to avoid modifying the original + df_filled = df.copy() + + # Identify numeric columns (including integer and float) + numeric_columns = df_filled.select_dtypes(include=[np.number]).columns + + # Identify non-numeric columns that should be handled differently + non_numeric_columns = df_filled.select_dtypes(exclude=[np.number]).columns + + print(f" Processing {len(df_filled)} rows, {len(numeric_columns)} numeric columns") + + # For numeric columns: replace NaN with 0 + if len(numeric_columns) > 0: + # Count NaN values before replacement + nan_count_before = df_filled[numeric_columns].isna().sum().sum() + if nan_count_before > 0: + # Replace NaN with 0 for numeric columns + df_filled[numeric_columns] = df_filled[numeric_columns].fillna(0) + + # For non-numeric columns that might contain numeric data as strings + for col in non_numeric_columns: + # Try to convert to numeric, replacing non-convertible values with NaN first + # then fill with 0 + converted = pd.to_numeric(df_filled[col], errors='coerce') + if not converted.isna().all(): # If at least some values could be converted to numeric + df_filled[col] = converted.fillna(0) + + # Also handle empty strings in numeric columns that might have been read as objects + for col in df_filled.columns: + if df_filled[col].dtype == 'object': + # Replace empty strings with 0 + empty_string_mask = df_filled[col] == '' + empty_count = empty_string_mask.sum() + if empty_count > 0: + df_filled.loc[empty_string_mask, col] = 0 + + # Also try to convert the entire column to numeric if possible + try: + converted = pd.to_numeric(df_filled[col]) + df_filled[col] = converted + except (ValueError, TypeError): + # Column contains non-numeric values, leave as is + pass + + return df_filled + + @property + def raw_file_names(self): + #return ["pf_node.csv", "pf_edge.csv"] + return [] + + @property + def processed_done_file(self): + return f"processed_{self.norm_method}_{self.mask_dim}_{self.pe_dim}.done" + + @property + def processed_file_names(self): + return [self.processed_done_file] + + def download(self): + pass + + def process(self): + """ + node_df = pd.read_csv(osp.join(self.raw_dir, "pf_node.csv")) + edge_df = pd.read_csv(osp.join(self.raw_dir, "pf_edge.csv")) + """ + node_df, edge_df = self.load_all_batch_data() #load all batch data + + # Check the unique scenarios available + scenarios = node_df["scenario"].unique() + + # Ensure node and edge data match + """ + if not (scenarios == edge_df["scenario"].unique()).all(): + raise ValueError("Mismatch between node and edge scenario values.") + """ + edge_scenarios = edge_df["scenario"].unique() + if not set(scenarios) == set(edge_scenarios): + print(f"Warning: Mismatch between node and edge scenarios.") + print(f"Node scenarios: {len(scenarios)}, Edge scenarios: {len(edge_scenarios)}") + # Use intersection of scenarios + common_scenarios = set(scenarios) & set(edge_scenarios) + node_df = node_df[node_df["scenario"].isin(common_scenarios)] + edge_df = edge_df[edge_df["scenario"].isin(common_scenarios)] + scenarios = node_df["scenario"].unique() + print(f"Using {len(common_scenarios)} common scenarios") + + print(f"Processing {len(scenarios)} scenarios...") + + # normalize node attributes + cols_to_normalize = ["Pd", "Qd", "Pg", "Qg", "Vm", "Va"] + to_normalize = torch.tensor( + node_df[cols_to_normalize].values, + dtype=torch.float, + ) + self.node_stats = self.node_normalizer.fit(to_normalize) + node_df[cols_to_normalize] = self.node_normalizer.transform( + to_normalize, + ).numpy() + + # normalize edge attributes + cols_to_normalize = ["G", "B"] + to_normalize = torch.tensor( + edge_df[cols_to_normalize].values, + dtype=torch.float, + ) + if isinstance(self.node_normalizer, BaseMVANormalizer): + self.edge_stats = self.edge_normalizer.fit( + to_normalize, + self.node_normalizer.baseMVA, + ) + else: + self.edge_stats = self.edge_normalizer.fit(to_normalize) + edge_df[cols_to_normalize] = self.edge_normalizer.transform( + to_normalize, + ).numpy() + + # save stats + node_stats_path = osp.join( + self.processed_dir, + f"node_stats_{self.norm_method}.pt", + ) + edge_stats_path = osp.join( + self.processed_dir, + f"edge_stats_{self.norm_method}.pt", + ) + torch.save(self.node_stats, node_stats_path) + torch.save(self.edge_stats, edge_stats_path) + + # Create groupby objects for scenarios + node_groups = node_df.groupby("scenario") + edge_groups = edge_df.groupby("scenario") + + for scenario_idx in tqdm(scenarios): + # NODE DATA + node_data = node_groups.get_group(scenario_idx) + x = torch.tensor( + node_data[ + ["Pd", "Qd", "Pg", "Qg", "Vm", "Va", "PQ", "PV", "REF"] + ].values, + dtype=torch.float, + ) + y = x[:, : self.mask_dim] + + # EDGE DATA + edge_data = edge_groups.get_group(scenario_idx) + edge_attr = torch.tensor(edge_data[["G", "B"]].values, dtype=torch.float) + edge_index = torch.tensor( + edge_data[["index1", "index2"]].values.T, + dtype=torch.long, + ) + + # Create the Data object + graph_data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + y=y, + scenario_id=scenario_idx, + ) + pe_pre_transform = AddEdgeWeights() + graph_data = pe_pre_transform(graph_data) + pe_transform = AddNormalizedRandomWalkPE( + walk_length=self.pe_dim, + attr_name="pe", + ) + graph_data = pe_transform(graph_data) + torch.save( + graph_data, + osp.join( + self.processed_dir, + f"data_{self.norm_method}_{self.mask_dim}_{self.pe_dim}_index_{scenario_idx}.pt", + ), + ) + with open(osp.join(self.processed_dir, self.processed_done_file), "w") as f: + f.write("done") + + def len(self): + if self.files is None: + self.files = sorted([ + f for f in os.listdir(self.processed_dir) + if f.startswith(f"data_{self.norm_method}_{self.mask_dim}_{self.pe_dim}_index_") + and f.endswith(".pt") + ]) + return len(self.files) + + def get(self, idx): + if self.files is None: + self.len() # populate self.files + + if idx >= len(self.files): + raise IndexError(f"Requested index {idx}, but dataset has only {len(self.files)} files.") + + file_path = osp.join(self.processed_dir, self.files[idx]) + data = torch.load(file_path, weights_only=False) + if self.transform: + data = self.transform(data) + return data + + def change_transform(self, new_transform): + """ + Temporarily switch to a new transform function, used when evaluating different tasks. + + Args: + new_transform (Callable): The new transform to use. + """ + self.original_transform = self.transform + self.transform = new_transform + + def reset_transform(self): + """ + Reverts the transform to the original one set during initialization, usually called after the evaluation step. + """ + if self.original_transform is None: + raise ValueError( + "The original transform is None or the function change_transform needs to be called before", + ) + self.transform = self.original_transform