This repository contains all the code developed as part of a CentraleSupélec project conducted in partnership with HeadMind Partners, focusing on the explainability of weather forecasting models. Specifically, we study precipitation prediction over Europe within a 6-hour forecasting horizon. Although some experiments were carried out with longer time horizons, their predictive performance was significantly lower; as a result, we chose not to include their explainability analyses in this repository. All models were trained using the ERA5 dataset from WeatherBench2.
The project is structured into two main phases:
- Precipitation prediction. We provide scripts to download and preprocess the data, train two types of models (U-Net and ConvLSTM), and evaluate their performance. More detailed information about the files and workflows is provided in a later section.
- Prediction explainability. We implement permutation-based methods and integrated gradients, combined with various aggregation strategies, to extract insights into the most influential input variables and time steps. These methods allow us to analyze which pixels contribute most to individual predictions, identify globally important features, explore patterns that are consistent with meteorological knowledge, and more. More detailed information about the explainability pipeline and related files is provided in a later section.
To get a local copy of this project up and running, follow these steps.
- Clone the repository:

  git clone git@github.com:manonarfib/X_Chaos_Meteo.git
  cd X_Chaos_Meteo

- Install dependencies. We recommend using a virtual environment to manage dependencies.

  pip install -r requirements.txt

- Download the U-Net checkpoint.
If you can, install the Git LFS extension (see https://git-lfs.com/), which handles the versioning of large files. In that case, you only need to run git lfs install once per machine, and the checkpoint is directly usable from checkpoints/unet.
However, if you cannot install the extension (note that it is not installed on the DCE), clone the repository as usual, go to https://github.com/manonarfib/X_Chaos_Meteo/tree/main/checkpoints/unet, and manually download the checkpoint. Then rename the file (we recommend best_mse_true.pt) and place it in checkpoints/unet.
This repository is organized as follows:
X_Chaos_Meteo/
├── checkpoints/
│   ├── convlstm/                      # Checkpoints for the ConvLSTM model according to the loss used during training
│   └── unet/                          # Checkpoint for the U-Net model corresponding to training with MSE loss
│
├── demonstrator/
│   ├── app_avec_calendrier.py         # Main demonstrator file
│   └── ...
│
├── download_dataset_from_gcs/         # Scripts to download the data from WeatherBench2
│
├── era5_visuals/
│   ├── figures/                       # Created visuals
│   └── visuels_era5.ipynb             # Notebook to create pretty representations of ERA5 variables
│
├── explainability/
│   ├── clusters/                      # Explain rain clusters instead of the whole map prediction
│   ├── explainable_by_design/         # WeatherCBM implementation
│   ├── features_permutation/          # Permutation-based importance methods
│   ├── integrated_gradients/          # Integrated Gradients implementation and aggregation methods
│   └── noise/                         # Noise methods for explainability
│
├── inference/
│   ├── compare_3models/               # Contains maps and boxplots for compare_model.py
│   ├── compare_predict_maps_outputs/  # Contains maps and boxplots for compare_predict_maps.py
│   ├── predict_maps_outputs/          # Contains maps and boxplots for predict_maps.py
│   ├── compare_model.py               # Create a boxplot and map for an inference on a test-set sample for different models
│   ├── compare_predict_maps.py        # Create boxplots and maps for an inference on a test-set sample for different checkpoints using the same model architecture
│   └── predict_maps.py                # Create a boxplot and map for an inference on a test-set sample for one checkpoint
│
├── models/
│   ├── ConvLSTM/                      # ConvLSTM architecture and training scripts
│   ├── unet/                          # U-Net architecture and training scripts
│   ├── mixture/                       # Mixing predictions of ConvLSTM and U-Net to improve the final prediction
│   └── utils/                         # Preprocessing, postprocessing and evaluation scripts
│
├── spearman_correlations/             # Script to compute Spearman correlations between our features
│
├── requirements.txt                   # Python dependencies
└── README.md                          # Project documentation
A demonstrator was developed, allowing the user to test most of the functionalities described above. It can be accessed here:
You can also download a short demonstration video if you have trouble using the demonstrator: [Demonstration video]
The notebook era5_visuals/visuels_era5.ipynb allows you to visualize and plot key variables of the ERA5 dataset.
Downloading the dataset is not required to run the code, as two weeks of data have been included in this repository (accessible at ./demonstrator/era5_europe_ml_test_2_weeks.zarr). However, if you wish to retrain the models or add new data, you must follow the format we used.
In ./download_dataset_from_gcs/download_dataset.py, change:
- OUT_ZARR to the desired path,
- TIME_BLOCKS to download the period of time of your choice.
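As an illustration only, the configuration might look like this (the actual variable formats expected by download_dataset.py may differ; both names below are taken from the description above, but their shapes are assumptions):

```python
# Hypothetical configuration block for download_dataset.py.
OUT_ZARR = "./data/era5_europe_ml_custom.zarr"  # desired output path

# Periods to download, here assumed to be (start, end) date pairs.
TIME_BLOCKS = [
    ("2019-01-01", "2019-06-30"),
    ("2019-07-01", "2019-12-31"),
]
```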
Run:
python -m download_dataset_from_gcs.download_dataset

The following data was used to train the available models:
- train set: 1980-01-01 to 2018-01-01
- validation set: 2018-01-01 to 2020-01-01
- test set: 2020-10-10 to 2022-01-01
Before training either model, make sure the ERA5 training and validation datasets are available locally and that the paths defined in the training scripts match your environment.
Typical expected files are:
era5_europe_ml_train.zarr
era5_europe_ml_validation.zarr
The ConvLSTM training pipeline is implemented in models/ConvLSTM/train_convlstm_with_downloaded_data.py. The script includes a configuration block where you can adjust:
- dataset paths: train_dataset_path and val_dataset_path,
- sequence length: T, with a default value of 8 (inputs cover a temporal window from t-42h to t),
- prediction lead time: lead, with a default value of 1 (we predict precipitation at t+6h),
- batch size: batch_size; we recommend keeping a low value since training can consume a lot of memory,
- loss function: loss_type, as a str,
- checkpoint and log locations: checkpoint_dir.
Run training with:
python -m models.ConvLSTM.train_convlstm_with_downloaded_data
Supported loss functions include: MSE, weighted MSE, Dice-based loss, and a custom advanced_torrential loss designed for heavy precipitation events.
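For intuition, here are hedged numpy sketches of a weighted MSE and a Dice-style loss. The repository's exact formulas (including the advanced_torrential loss) live in the training script and may differ; the weighting scheme and threshold below are illustrative assumptions.

```python
import numpy as np

def weighted_mse(pred, target, weight_fn=lambda t: 1.0 + t):
    # Up-weight pixels with heavier precipitation; the weighting function
    # here is an illustrative assumption, not the repository's exact choice.
    w = weight_fn(target)
    return float((w * (pred - target) ** 2).mean())

def soft_dice_loss(pred, target, threshold=0.5, eps=1e-6):
    # Dice-style overlap loss on binarized rain masks (illustrative; actual
    # training would use a differentiable variant on torch tensors).
    p = (pred > threshold).astype(float)
    t = (target > threshold).astype(float)
    inter = (p * t).sum()
    return float(1.0 - (2 * inter + eps) / (p.sum() + t.sum() + eps))

pred = np.array([[0.0, 0.8], [0.6, 0.1]])
target = np.array([[0.0, 1.0], [0.7, 0.0]])
print(weighted_mse(pred, target), soft_dice_loss(pred, target))
```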
Generated checkpoints and logs are saved under checkpoints/convlstm/, and a different subfolder is created according to the loss type used during training. Make sure to change the checkpoint location if you change other parameters (such as lead time or sequence length); otherwise a previous checkpoint may be overwritten.
The U-Net training pipeline is implemented in models/unet/training_optimized.py. The script includes a configuration block where you can adjust:
- dataset paths: dataset_train_path and dataset_val_path,
- sequence length: n_input_steps, with a default value of 8 (inputs cover a temporal window from t-42h to t),
- prediction lead time: lead_steps, with a default value of 1 (we predict precipitation at t+6h),
- batch size: batch_size; we recommend keeping a low value since training can consume a lot of memory,
- loss function: loss_type, as a str.
Run training with:
python -m models.unet.training_optimized --save_path {SAVE_PATH}
Where SAVE_PATH is the path to which your checkpoints will be saved; we recommend a path beginning with checkpoints/unet/.
Supported loss functions include: MSE, weighted MSE, and Dice-based losses.
The pretrained checkpoint provided in this repository corresponds to the U-Net model trained with MSE loss.
We implement a permutation-based feature importance method to quantify the influence of each input variable and timestep on the model predictions.
The idea is to measure how much the model performance degrades when a given feature is randomly permuted. More precisely:
- A baseline prediction is computed on the original input.
- For each feature (defined as a variable at a given timestep), we randomly shuffle its spatial values, run the model again, and measure the increase in prediction error (MSE).
- The importance of a feature is defined as the difference between the permuted error and the baseline error.
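Under toy assumptions (a fixed linear "model", a scalar MSE, and invented names and shapes), the three steps above can be sketched as:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the forecasting model: a fixed linear map over two
# "variables" on a small spatial grid. All names/shapes are illustrative,
# not the repository's actual API.
H, W = 8, 8
weights = np.stack([np.full((H, W), 2.0), np.full((H, W), 0.1)])  # var 0 matters more

def predict(x):
    # x has shape (n_vars, H, W); the "prediction" is a scalar field.
    return (weights * x).sum(axis=0)

x = rng.normal(size=(2, H, W))
target = predict(x)  # pretend ground truth, so the baseline error is 0

def mse(a, b):
    return float(((a - b) ** 2).mean())

# 1) Baseline prediction error on the original input.
baseline_err = mse(predict(x), target)

# 2) Shuffle the spatial values of each feature and measure the error increase.
importances = []
for v in range(x.shape[0]):
    x_perm = x.copy()
    perm = rng.permutation(x_perm[v].size)
    x_perm[v] = x_perm[v].ravel()[perm].reshape(H, W)
    # 3) Importance = permuted error minus baseline error.
    importances.append(mse(predict(x_perm), target) - baseline_err)

print(importances)  # variable 0 should dominate
```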
To compute permutation-based feature importance, one can run:
python -m explainability.features_permutation.permutation_importance
Make sure to configure the following inside the script before execution:
- MODEL_TYPE ("unet" or "convlstm"),
- CKPT_PATH (the checkpoint used for the model),
- DATASET_PATH.
The script produces:
- A .npz file containing raw importance scores for each sample, saved at explainability/features_permutation/permutation_importances_to_stack_time_and_var_<model>.npz
- Aggregated visualizations (importance per variable and importance per timestep), saved in explainability/features_permutation/figures/
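A minimal sketch of how such a .npz file can be written and re-aggregated per variable and per timestep. The key names and array shapes below are assumptions for illustration, not the script's actual layout:

```python
import io
import numpy as np

rng = np.random.default_rng(0)

# Assumed layout: one (T, n_vars) importance array per sample. An in-memory
# buffer stands in for the .npz file on disk.
buf = io.BytesIO()
np.savez(buf, **{f"sample_{i}": rng.random((8, 5)) for i in range(3)})
buf.seek(0)

data = np.load(buf)
stacked = np.stack([data[k] for k in data.files])   # (n_samples, T, n_vars)
per_variable = stacked.mean(axis=(0, 1))            # average over samples and timesteps
per_timestep = stacked.mean(axis=(0, 2))            # average over samples and variables
print(per_variable.shape, per_timestep.shape)
```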
We also provide an explainability pipeline based on Integrated Gradients (IG) to identify which input variables, timesteps, and spatial regions contribute the most to the model prediction.
Integrated Gradients is a gradient-based attribution method. It computes feature attributions by integrating gradients along a path between a baseline input and the actual sample. This makes it possible to obtain both global importance summaries (aggregated over multiple samples of the test set) and local explanations (showing which pixels and variables influenced a specific prediction).
Given an input sample and a baseline, Integrated Gradients computes the attribution of each input feature by accumulating gradients along interpolated inputs between the baseline and the original sample. In this implementation, attributions are computed with respect to a target prediction region defined from the model output itself. More precisely:
- The model predicts a precipitation map.
- A region of interest is defined as the pixels above a chosen prediction quantile.
- Integrated Gradients is computed with respect to the sum of predictions over this region.
This allows the method to focus on the input features that most influence the strongest predicted precipitation areas.
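A self-contained numerical sketch of this idea, using an invented linear model so the gradient is analytic; a real run would use autograd on the trained network, and all names and shapes here are assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy differentiable "model": a linear map from two input variables to a
# precipitation-like map (illustrative only).
H, W = 6, 6
weights = rng.normal(size=(2, H, W))

def predict(x):
    return (weights * x).sum(axis=0)              # (H, W) map

x = rng.normal(size=(2, H, W))
baseline = np.zeros_like(x)                       # "zeros" baseline mode

# Region of interest: pixels above a chosen prediction quantile.
pred = predict(x)
mask = pred >= np.quantile(pred, 0.9)

def target(x):
    return predict(x)[mask].sum()                 # scalar to attribute

def grad_target(x):
    # Analytic gradient of the linear target; a real model would use autograd.
    return weights * mask[None]

# Integrated Gradients: average gradients along the straight path from the
# baseline to the sample, then scale by (x - baseline).
steps = 64
alphas = (np.arange(steps) + 0.5) / steps         # midpoint rule
avg_grad = np.mean(
    [grad_target(baseline + a * (x - baseline)) for a in alphas], axis=0
)
attributions = (x - baseline) * avg_grad

# Completeness axiom: attributions sum to target(x) - target(baseline).
print(attributions.sum(), target(x) - target(baseline))
```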
The script is configured directly through a user-defined configuration block at the top of the file. Main parameters that can be modified are:
- MODEL_TYPE: model to explain ("convlstm" or "unet")
- LOSS_NAME: used for output folder naming
- CKPT_PATH: path to the model checkpoint
- DATASET_PATH: path to the dataset used for explainability
- SAMPLE_IDX: index of the sample used for local visualizations
- T: number of input timesteps
- LEAD: prediction lead time
Aggregation settings:
- DO_AGG: whether to compute global aggregated importance over multiple samples
- N_SAMPLES_AGG: number of samples used for aggregation
- SEED: random seed for reproducibility
Attribution settings:
- IG_STEPS: number of interpolation steps for Integrated Gradients
- BASELINE_MODE: baseline type ("zeros" or "mean_over_space_time")
- REGION_QUANTILE: quantile used to define the region of interest in the predicted map
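The two baseline modes could plausibly be constructed as follows; the exact semantics in the script are assumed here (in particular, that "mean_over_space_time" keeps one mean value per variable):

```python
import numpy as np

# (T, n_vars, H, W); shapes are illustrative only.
x = np.random.default_rng(0).random((8, 5, 32, 32))

def make_baseline(x, mode):
    # Assumed semantics of the two BASELINE_MODE options.
    if mode == "zeros":
        return np.zeros_like(x)
    if mode == "mean_over_space_time":
        mean = x.mean(axis=(0, 2, 3), keepdims=True)  # average over time and space
        return np.broadcast_to(mean, x.shape).copy()
    raise ValueError(f"unknown baseline mode: {mode}")

baseline = make_baseline(x, "mean_over_space_time")
print(baseline.shape)
```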
To run the explainability script, execute:
python -m explainability.integrated_gradients.integrated_gradients
The script produces two types of outputs:
- Aggregated importance over multiple samples, providing a global view of which variables and timesteps are the most influential for the model.
- Detailed visualizations for one selected sample.
Generated files are saved under:
explainability/integrated_gradients/ig_outputs/
with subfolders depending on the model (unet or convlstm), the loss name, the prediction lead time, and the selected sample index.
The WeatherCBM training pipeline is implemented in explainability/explainable_by_design/training_WeatherCBM(_with_reg_on_vars).py. The file _with_reg_on_vars implements the version of WeatherCBM with additional loss terms further constraining the use of input variables by the concepts. In each case, the script includes a configuration block where you can adjust:
- dataset paths: train_dataset_path and val_dataset_path,
- sequence length: T, with a default value of 8 (inputs cover a temporal window from t-42h to t),
- prediction lead time: lead, with a default value of 1 (we predict precipitation at t+6h),
- batch size: batch_size; we recommend keeping a low value since training can consume a lot of memory,
- loss function: loss_type, as a str,
- checkpoint and log locations: checkpoint_dir.
Run training with:
python -m explainability.explainable_by_design.training_WeatherCBM_with_reg_on_vars
Generated checkpoints and logs are saved under checkpoints/weathercbm/, and a different subfolder is created according to the name you specify. Make sure to change the checkpoint location if you change other parameters (such as lead time or sequence length); otherwise a previous checkpoint may be overwritten.
explainability/explainable_by_design/explain_results contains files to interpret the model.
- integrated_gradients allows you to visualize which input variables contribute the most to each concept.
- predict_concept_activation saves maps corresponding to the activation of the concepts on a specific sample.
- analysis_regularization gives you access to the matrix A of the model trained with regularization on the input variables, i.e. the importance of each input variable for each concept.
This repository was created, with equal contributions, by:
- Louisa Arfib: https://github.com/arfiblouisa
- Manon Arfib: https://github.com/manonarfib
- Nathan Morin: https://github.com/Nathan9842
A huge thank you to Florestan Fontaine from HeadMind Partners for his help and valuable advice.
This project is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).
You are permitted to use, share, and adapt the material for non-commercial purposes, provided that appropriate credit is given to the original authors.
Commercial use of this work is strictly prohibited without prior written permission from the authors.
For full license terms, see: https://creativecommons.org/licenses/by-nc/4.0/
