gridfm-graphkit
This library is brought to you by the GridFM team to train, finetune and interact with a foundation model for the electric power grid.
Create and activate a virtual environment. Use Python 3.10, 3.11, or 3.12; 3.12 is highly recommended.
python -m venv venv
source venv/bin/activate

Install gridfm-graphkit in editable mode

pip install -e .

Get PyTorch + CUDA version for torch-scatter

TORCH_CUDA_VERSION=$(python -c "import torch; print(torch.__version__ + ('+cpu' if torch.version.cuda is None else ''))")

Install the correct torch-scatter wheel

pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH_CUDA_VERSION}.html

For documentation generation and unit testing, install with the optional dev and test extras:

pip install -e .[dev,test]

Interface to train, fine-tune, evaluate, and run inference on GridFM models using YAML configs and MLflow tracking.
gridfm_graphkit <command> [OPTIONS]

Available commands:

- `train` – Train a new model from scratch
- `finetune` – Fine-tune an existing pre-trained model
- `evaluate` – Evaluate model performance on a dataset
- `predict` – Run inference and save predictions
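To make the `<command> [OPTIONS]` layout concrete, here is a minimal sketch of a subcommand-style parser built with Python's standard `argparse`. It is a toy that mirrors the commands and shared options documented in this README, not the library's actual parser; `build_parser` and `COMMANDS` are hypothetical names, and the `--exp_name` default is left as `None` here because the real default is a timestamp generated at run time.

```python
import argparse

# The four commands listed in this README.
COMMANDS = ("train", "finetune", "evaluate", "predict")


def build_parser() -> argparse.ArgumentParser:
    """Toy parser mirroring the documented `<command> [OPTIONS]` layout."""
    parser = argparse.ArgumentParser(prog="gridfm_graphkit")
    subparsers = parser.add_subparsers(dest="command", required=True)
    for name in COMMANDS:
        sub = subparsers.add_parser(name)
        # Options that appear in the tables for all four commands.
        sub.add_argument("--config", type=str, required=True)
        sub.add_argument("--exp_name", type=str, default=None)  # real default: timestamp
        sub.add_argument("--run_name", type=str, default="run")
        sub.add_argument("--log_dir", type=str, default="mlruns")
        sub.add_argument("--data_path", type=str, default="data")
    return parser


args = build_parser().parse_args(["train", "--config", "path/to/config.yaml"])
print(args.command, args.config, args.log_dir)
# prints: train path/to/config.yaml mlruns
```

Command-specific options such as `--model_path` or `--normalizer_stats` would be added only to the relevant subparser in a fuller version of this sketch.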
gridfm_graphkit train --config path/to/config.yaml

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to the training configuration YAML file. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow tracking/logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |

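If you want to launch training from a Python script rather than a shell, one option is to assemble the documented arguments into a command list and hand it to `subprocess`. The sketch below only builds and prints the command; `build_train_command` is a hypothetical helper, not part of gridfm-graphkit, and it uses only the flags listed in the table above.

```python
import shlex


def build_train_command(config, exp_name=None, run_name=None,
                        log_dir=None, data_path=None):
    """Assemble a `gridfm_graphkit train` command from the documented flags."""
    cmd = ["gridfm_graphkit", "train", "--config", str(config)]
    optional = {
        "--exp_name": exp_name,
        "--run_name": run_name,
        "--log_dir": log_dir,
        "--data_path": data_path,
    }
    for flag, value in optional.items():
        # Omitted options fall back to the CLI defaults.
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


cmd = build_train_command(
    "examples/config/case30_ieee_base.yaml",
    data_path="examples/data",
)
print(shlex.join(cmd))

# To actually start training (requires gridfm-graphkit to be installed):
#   import subprocess
#   subprocess.run(cmd, check=True)
```

Passing the command as a list avoids shell quoting problems when paths contain spaces.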
Standard Training:
gridfm_graphkit train --config examples/config/case30_ieee_base.yaml --data_path examples/data

gridfm_graphkit finetune --config path/to/config.yaml --model_path path/to/model.pt

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Fine-tuning configuration file. | `None` |
| `--model_path` | `str` | Required. Path to a pre-trained model state dict. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Root dataset directory. | `data` |

gridfm_graphkit evaluate --config path/to/eval.yaml --model_path path/to/model.pt

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to evaluation config. | `None` |
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics instead of re-fitting on the current data split. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--compute_dc_ac_metrics` | flag | Compute ground-truth AC/DC power balance metrics on the test split. | `False` |
| `--save_output` | flag | Save predictions as `<grid_name>_predictions.parquet` under MLflow artifacts (`.../artifacts/test`). | `False` |

When evaluating a model on a dataset, you can pass the normalizer statistics from the original training run to ensure the same normalization parameters are used:
gridfm_graphkit evaluate \
--config examples/config/HGNS_PF_datakit_case118.yaml \
--model_path mlruns/<experiment_id>/<run_id>/artifacts/model/best_model_state_dict.pt \
--normalizer_stats mlruns/<experiment_id>/<run_id>/artifacts/stats/normalizer_stats.pt \
--data_path data

Note: The `--normalizer_stats` flag only affects normalizers with `fit_strategy = "fit_on_train"` (e.g. `HeteroDataMVANormalizer`). Per-sample normalizers (`HeteroDataPerSampleMVANormalizer`) always recompute their statistics from the current dataset regardless of this flag.
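The distinction in the note can be illustrated with a toy example. The sketch below is not the library's normalizer code: `fit_scale` and `normalize` are hypothetical helpers, and the mean absolute value is an arbitrary stand-in for whatever statistics the real normalizers compute. It only shows why reusing statistics fitted on the training data gives different results from re-fitting on the evaluation split, and why a per-sample scheme is unaffected by saved statistics.

```python
from statistics import mean


def fit_scale(samples):
    """Fit one global scale (mean absolute value) over all given samples."""
    return mean(abs(x) for sample in samples for x in sample)


def normalize(sample, scale):
    """Divide every value in a sample by the given scale."""
    return [x / scale for x in sample]


train = [[100.0, 200.0], [300.0, 400.0]]
test = [[10.0, 30.0]]

# "fit_on_train" style: statistics are fitted once on the training data.
# Saving and restoring them reproduces the same scale at evaluation time.
saved_scale = fit_scale(train)               # 250.0
reused = normalize(test[0], saved_scale)     # [0.04, 0.12]

# Re-fitting on the evaluation split yields a different scale (20.0),
# so the model would see inputs scaled differently than during training.
refit = normalize(test[0], fit_scale(test))  # [0.5, 1.5]

# Per-sample style: each sample is scaled by its own statistics, so
# saved training statistics play no role.
per_sample = [normalize(s, fit_scale([s])) for s in test]

print(saved_scale, reused, refit, per_sample)
```

In this toy setup the re-fitted and per-sample results coincide only because the evaluation split contains a single sample.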
gridfm_graphkit predict --config path/to/config.yaml --model_path path/to/model.pt

| Argument | Type | Description | Default |
|---|---|---|---|
| `--config` | `str` | Required. Path to prediction config file. | `None` |
| `--model_path` | `str` | Path to the trained model state dict. | `None` |
| `--normalizer_stats` | `str` | Path to `normalizer_stats.pt` from a training run. Restores `fit_on_train` normalizers from saved statistics. | `None` |
| `--exp_name` | `str` | MLflow experiment name. | timestamp |
| `--run_name` | `str` | MLflow run name. | `run` |
| `--log_dir` | `str` | MLflow logging directory. | `mlruns` |
| `--data_path` | `str` | Dataset directory. | `data` |
| `--output_path` | `str` | Directory where predictions are saved as `<grid_name>_predictions.parquet`. | `data` |

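After a `predict` run, the output directory should contain one `<grid_name>_predictions.parquet` file per grid. A minimal sketch for locating those files is shown below, assuming only the naming pattern documented in the table. `find_prediction_files` is a hypothetical helper, the grid name `case30_ieee` is an assumed example, and the demo uses an empty placeholder file rather than real predictions.

```python
import tempfile
from pathlib import Path

SUFFIX = "_predictions.parquet"


def find_prediction_files(output_path):
    """Map each grid name to its prediction file under output_path."""
    root = Path(output_path)
    return {p.name[: -len(SUFFIX)]: p for p in sorted(root.glob(f"*{SUFFIX}"))}


# Demo with an empty placeholder file standing in for real predictions.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / f"case30_ieee{SUFFIX}").touch()
    found = find_prediction_files(tmp)
    grid_names = list(found)
    print(grid_names)
    # prints: ['case30_ieee']
```

Each real file could then be loaded with a Parquet reader such as `pandas.read_parquet`; the column layout is not documented here, so inspect it before relying on specific column names.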
Use built-in help for full command details:
gridfm_graphkit --help
gridfm_graphkit <command> --help