Explainability in a chaotic system – Application to weather forecasting

πŸ’‘ Overview

This repository contains all the code developed as part of a CentraleSupΓ©lec project conducted in partnership with HeadMind Partners, focusing on the explainability of weather forecasting models. Specifically, we study precipitation prediction over Europe within a 6-hour forecasting horizon. Although some experiments were carried out with longer time horizons, their predictive performance was significantly lower; as a result, we chose not to include their explainability analyses in this repository. All models were trained using the ERA5 dataset from WeatherBench2.

The project is structured into two main phases:

  1. Precipitation prediction. We provide scripts to download and preprocess the data, train two types of models (U-Net and ConvLSTM), and evaluate their performance. More detailed information about the files and workflows is provided in a later section.

  2. Prediction explainability. We implement permutation-based methods and integrated gradients, combined with various aggregation strategies, to extract insights into the most influential input variables and time steps. These methods allow us to analyze which pixels contribute most to individual predictions, identify globally important features, explore patterns that are consistent with meteorological knowledge, and more. More detailed information about the explainability pipeline and related files is provided in a later section.

πŸ“¦ Getting Started

To get a local copy of this project up and running, follow these steps.

  1. Clone the repository:

    git clone git@github.com:manonarfib/X_Chaos_Meteo.git
    cd X_Chaos_Meteo
  2. Install dependencies:

We recommend using a virtual environment to manage dependencies.

pip install -r requirements.txt
  3. Download the U-Net checkpoint:

If possible, install the Git LFS extension (see https://git-lfs.com/), which handles the versioning of large files. In that case, run git lfs install once (per machine), and the checkpoint is automatically available in checkpoints/unet after cloning. If you cannot install the extension (note that it is not installed on the DCE), clone the repository as usual, go to https://github.com/manonarfib/X_Chaos_Meteo/tree/main/checkpoints/unet, and download the checkpoint manually. Then rename the file (we recommend best_mse_true.pt) and place it in checkpoints/unet.

πŸ“– Usage

πŸ—‚οΈ Repository Structure Description

This repository is organized as follows:

X_Chaos_Meteo/
β”œβ”€β”€ checkpoints/
β”‚   β”œβ”€β”€ convlstm/                       # Checkpoints for the ConvLSTM model according to the loss used during training
β”‚   └── unet/                           # Checkpoint for the U-Net model corresponding to training with MSE loss
β”‚   
β”œβ”€β”€ demonstrator/
β”‚   β”œβ”€β”€ app_avec_calendrier.py          # Main demonstrator file
β”‚   └── ...
β”‚
β”œβ”€β”€ download_dataset_from_gcs/          # Scripts to download the data from WeatherBench2
β”‚
β”œβ”€β”€ era5_visuals/
β”‚   β”œβ”€β”€ figures/                        # Created visuals
β”‚   └── visuels_era5.ipynb              # Notebook to create pretty representations of ERA5 variables
β”‚
β”œβ”€β”€ explainability/
β”‚   β”œβ”€β”€ clusters/                       # Explain rain clusters instead of the whole map prediction
β”‚   β”œβ”€β”€ explainable_by_design/          # WeatherCBM implementation
β”‚   β”œβ”€β”€ features_permutation/           # Permutation-based importance methods
β”‚   β”œβ”€β”€ integrated_gradients/           # Integrated Gradients implementation and aggregation methods
β”‚   └── noise/                          # Noise methods for explainability
β”‚
β”œβ”€β”€ inference/
β”‚   β”œβ”€β”€ compare_3models/                # Contains maps and boxplots for compare_model.py
β”‚   β”œβ”€β”€ compare_predict_maps_outputs/   # Contains maps and boxplots for compare_predict_maps.py
β”‚   β”œβ”€β”€ predict_maps_outputs/           # Contains maps and boxplots for predict_maps.py
β”‚   β”œβ”€β”€ compare_model.py                # Create boxplot and map for an inference on test set sample for different models
β”‚   β”œβ”€β”€ compare_predict_maps.py         # Create boxplots and maps for an inference on test set sample for different checkpoints using the same model architecture
β”‚   └── predict_maps.py                 # Create boxplot and map for an inference on test set sample for one checkpoint 
β”‚
β”œβ”€β”€ models/
β”‚   β”œβ”€β”€ ConvLSTM/                       # ConvLSTM architecture and training scripts
β”‚   β”œβ”€β”€ unet/                           # U-Net architecture and training scripts
β”‚   β”œβ”€β”€ mixture/                        # Mixing predictions of ConvLSTM and U-Net to improve final prediction
β”‚   └── utils/                          # Preprocessing, postprocessing and evaluation scripts
β”‚
β”œβ”€β”€ spearman_correlations/              # Contains script to compute Spearman correlations between our features
β”‚
β”œβ”€β”€ requirements.txt                    # Python dependencies
└── README.md                           # Project documentation

πŸ–₯️ Demonstrator

A demonstrator was developed that lets you test most of the functionality described above. It can be accessed here:

Open in Streamlit

If you have trouble using the demonstrator, you can also download a short demonstration video: [Demonstration video]

πŸ” Visualizing some variables

The notebook era5_visuals/visuels_era5.ipynb lets you visualize and plot key variables of the ERA5 dataset.

πŸ“š Downloading the dataset

Downloading the dataset is not required to run the code: two weeks of data are already included in this repository (at ./demonstrator/era5_europe_ml_test_2_weeks.zarr). However, if you wish to retrain the models or add new data, you must follow the format we used.
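
To take a quick look at the bundled sample, here is a minimal inspection sketch (assuming xarray is installed; it is the usual tool for Zarr-backed ERA5 data, but its use here is an assumption, not the repository's documented workflow):

    import xarray as xr

    # Open the two-week sample shipped with the repository and print its
    # variables, coordinates, and time range.
    ds = xr.open_zarr("demonstrator/era5_europe_ml_test_2_weeks.zarr")
    print(ds)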

In ./download_dataset_from_gcs/download_dataset.py, change:

  • OUT_ZARR to the desired path,
  • TIME_BLOCKS to download the period of time of your choice.
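
For illustration, a sketch of these two settings (the exact structure expected for TIME_BLOCKS is an assumption; the dates reuse the training period documented below):

    # Hypothetical values; check download_dataset.py for the exact expected format.
    OUT_ZARR = "era5_europe_ml_train.zarr"        # where the Zarr store is written
    TIME_BLOCKS = [("1980-01-01", "2018-01-01")]  # (start, end) periods to download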

Run:

python -m download_dataset_from_gcs.download_dataset

The following data was used to train the available models:

  • train set: 1980-01-01 to 2018-01-01
  • validation set: 2018-01-01 to 2020-01-01
  • test set: 2020-10-10 to 2022-01-01

🌧️ Training a weather forecasting model

Before training either model, make sure the ERA5 training and validation datasets are available locally and that the paths defined in the training scripts match your environment.

Typical expected files are:

era5_europe_ml_train.zarr
era5_europe_ml_validation.zarr

ConvLSTM

The ConvLSTM training pipeline is implemented in models/ConvLSTM/train_convlstm_with_downloaded_data.py. The script includes a configuration block where you can adjust the following (see the sketch after this list):

  • dataset paths: train_dataset_path and val_dataset_path,
  • sequence length: T, which defaults to 8 (inputs cover a temporal window from t-42h to t),
  • prediction lead time: lead, which defaults to 1 (precipitation is predicted at t+6h),
  • batch size: batch_size; we recommend keeping it small, since large batches consume a lot of memory,
  • loss function: loss_type (a string),
  • checkpoint and log locations: checkpoint_dir.
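
For reference, the configuration block might look like the following sketch (parameter names and defaults come from the list above; the paths and batch size are placeholder assumptions):

    # Sketch of the configuration block; adjust paths and values to your setup.
    train_dataset_path = "era5_europe_ml_train.zarr"
    val_dataset_path = "era5_europe_ml_validation.zarr"
    T = 8                      # input window from t-42h to t (8 steps of 6 h)
    lead = 1                   # predict precipitation at t+6h
    batch_size = 2             # placeholder; keep small to limit memory usage
    loss_type = "mse"          # MSE, weighted MSE, Dice-based, or advanced_torrential
    checkpoint_dir = "checkpoints/convlstm/mse"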

Run training with:

python -m models.ConvLSTM.train_convlstm_with_downloaded_data

Supported loss functions include: MSE, weighted MSE, Dice-based loss, and a custom advanced_torrential loss designed for heavy precipitation events.

Generated checkpoints and logs are saved under checkpoints/convlstm/, with a separate subfolder created according to the loss type used for training. Make sure to change the checkpoint location if you change other parameters (such as lead time or sequence length); otherwise a previous checkpoint may be overwritten.

U-Net

The U-Net training pipeline is implemented in models/unet/training_optimized.py. The script includes a configuration block where you can adjust:

  • dataset paths: dataset_train_path and dataset_val_path,
  • sequence length: n_input_steps, which defaults to 8 (inputs cover a temporal window from t-42h to t),
  • prediction lead time: lead_steps, which defaults to 1 (precipitation is predicted at t+6h),
  • batch size: batch_size; we recommend keeping it small, since large batches consume a lot of memory,
  • loss function: loss_type (a string).

Run training with:

python -m models.unet.training_optimized --save_path {SAVE_PATH}

where SAVE_PATH is the directory in which your checkpoints will be saved; we recommend a path beginning with checkpoints/unet/.
Supported loss functions include MSE, weighted MSE, and a Dice-based loss.
The pretrained checkpoint provided in this repository corresponds to the U-Net model trained with MSE loss.
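
To sanity-check a downloaded checkpoint before training or inference, here is a minimal loading sketch (it assumes a standard torch-serialized file; the stored keys and layout are not guaranteed):

    import torch

    # Assumption: the checkpoint is a torch-serialized file. Inspecting it shows
    # whether it stores a raw state_dict or a dict with extra metadata.
    ckpt = torch.load("checkpoints/unet/best_mse_true.pt", map_location="cpu")
    print(type(ckpt), list(ckpt.keys()) if isinstance(ckpt, dict) else "(not a dict)")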

πŸ”¬ Explaining a pretrained model

Importance by feature permutation

We implement a permutation-based feature importance method to quantify the influence of each input variable and timestep on the model's predictions.
The idea is to measure how much model performance degrades when a given feature is randomly permuted (a code sketch follows the steps below). More precisely:

  1. A baseline prediction is computed on the original input.
  2. For each feature (defined as a variable at a given timestep), we randomly shuffle its spatial values, run the model again, and measure the increase in prediction error (MSE).
  3. The importance of a feature is defined as the difference between the permuted error and the baseline error.
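
A minimal sketch of these three steps (the function name, the (1, T, V, H, W) input layout, and the direct MSE computation are assumptions, not the repository's exact code):

    import numpy as np
    import torch

    def permutation_importance(model, x, y):
        # x: input of shape (1, T, V, H, W); y: target precipitation map.
        model.eval()
        with torch.no_grad():
            baseline = torch.mean((model(x) - y) ** 2).item()  # step 1: baseline MSE
            n_steps, n_vars = x.shape[1], x.shape[2]
            importance = np.zeros((n_steps, n_vars))
            for t in range(n_steps):
                for v in range(n_vars):
                    x_perm = x.clone()
                    flat = x_perm[0, t, v].reshape(-1)
                    # step 2: shuffle the spatial values of one (timestep, variable) slice
                    x_perm[0, t, v] = flat[torch.randperm(flat.numel())].view_as(x_perm[0, t, v])
                    permuted = torch.mean((model(x_perm) - y) ** 2).item()
                    importance[t, v] = permuted - baseline     # step 3: error increase
        return importance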

To compute permutation-based feature importance, one can run:

python -m explainability.features_permutation.permutation_importance

Make sure to configure:

  • MODEL_TYPE ("unet" or "convlstm")
  • CKPT_PATH (the checkpoint used for the model)
  • DATASET_PATH

inside the script before execution.

The script produces:

  • A .npz file containing raw importance scores for each sample, saved to explainability/features_permutation/permutation_importances_to_stack_time_and_var_<model>.npz
  • Aggregated visualizations: importance per variable and importance per timestep, saved in: explainability/features_permutation/figures/

Integrated Gradients methods

We also provide an explainability pipeline based on Integrated Gradients (IG) to identify which input variables, timesteps, and spatial regions contribute the most to the model prediction.

Integrated Gradients is a gradient-based attribution method. It computes feature attributions by integrating gradients along a path between a baseline input and the actual sample. This makes it possible to obtain both global importance summaries (aggregated over multiple samples of the test set) and local explanations (showing which pixels and variables influenced a specific prediction).

Given an input sample and a baseline, Integrated Gradients computes the attribution of each input feature by accumulating gradients along interpolated inputs between the baseline and the original sample. In this implementation, attributions are computed with respect to a target prediction region defined from the model output itself. More precisely:

  1. The model predicts a precipitation map.
  2. A region of interest is defined as the pixels above a chosen prediction quantile.
  3. Integrated Gradients is computed with respect to the sum of predictions over this region.

This allows the method to focus on the input features that most influence the strongest predicted precipitation areas.
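
A condensed sketch of this scheme (a hand-rolled IG loop; the tensor shapes, the global quantile, and the uniform step weighting are simplifying assumptions relative to the repository's implementation):

    import torch

    def integrated_gradients(model, x, baseline, region_quantile=0.9, steps=50):
        # Steps 1-2: predict, then keep the pixels above the chosen quantile.
        with torch.no_grad():
            pred = model(x)
            mask = (pred >= torch.quantile(pred, region_quantile)).float()
        # Step 3: accumulate gradients of the masked prediction sum along the
        # straight-line path from the baseline to the actual sample.
        total_grad = torch.zeros_like(x)
        for alpha in torch.linspace(0.0, 1.0, steps):
            x_interp = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)
            target = (model(x_interp) * mask).sum()
            total_grad += torch.autograd.grad(target, x_interp)[0]
        return (x - baseline) * total_grad / steps  # attributions, same shape as x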

The script is configured directly through a user-defined configuration block at the top of the file (an illustrative sketch follows the three parameter lists below). The main parameters that can be modified are:

  • MODEL_TYPE: model to explain ("convlstm" or "unet")
  • LOSS_NAME: used for output folder naming
  • CKPT_PATH: path to the model checkpoint
  • DATASET_PATH: path to the dataset used for explainability
  • SAMPLE_IDX: index of the sample used for local visualizations
  • T: number of input timesteps
  • LEAD: prediction lead time

Aggregation settings:

  • DO_AGG: whether to compute global aggregated importance over multiple samples
  • N_SAMPLES_AGG: number of samples used for aggregation
  • SEED: random seed for reproducibility

Attribution settings:

  • IG_STEPS: number of interpolation steps for Integrated Gradients
  • BASELINE_MODE: baseline type ("zeros" or "mean_over_space_time")
  • REGION_QUANTILE: quantile used to define the region of interest in the predicted map
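
Put together, the configuration block might resemble the following (only the parameter names come from the lists above; every value here is a hedged example):

    # Hypothetical values for the configuration block described above.
    MODEL_TYPE = "unet"                # or "convlstm"
    LOSS_NAME = "mse"                  # used for output folder naming
    CKPT_PATH = "checkpoints/unet/best_mse_true.pt"
    DATASET_PATH = "demonstrator/era5_europe_ml_test_2_weeks.zarr"
    SAMPLE_IDX = 0                     # sample used for local visualizations
    T, LEAD = 8, 1                     # input timesteps and prediction lead time
    DO_AGG, N_SAMPLES_AGG, SEED = True, 50, 0
    IG_STEPS = 50                      # interpolation steps
    BASELINE_MODE = "zeros"            # or "mean_over_space_time"
    REGION_QUANTILE = 0.9              # defines the region of interest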

To run the explainability script, execute:

python -m explainability.integrated_gradients.integrated_gradients

The script produces two types of outputs:

  • Aggregated importance over multiple samples, providing a global view of which variables and timesteps are most influential for the model.
  • Detailed visualizations for one selected sample.

Generated files are saved under explainability/integrated_gradients/ig_outputs/, with subfolders depending on the model (unet or convlstm), the loss name, the prediction lead time, and the selected sample index.

Training WeatherCBM (explainable-by-design model) and interpreting it

Training

The WeatherCBM training pipeline is implemented in explainability/explainable_by_design/training_WeatherCBM.py and its variant training_WeatherCBM_with_reg_on_vars.py; the _with_reg_on_vars version adds loss terms that further constrain how the concepts use the input variables. In both cases, the script includes a configuration block where you can adjust:

  • dataset paths: train_dataset_path and val_dataset_path,
  • sequence length: T, which defaults to 8 (inputs cover a temporal window from t-42h to t),
  • prediction lead time: lead, which defaults to 1 (precipitation is predicted at t+6h),
  • batch size: batch_size; we recommend keeping it small, since large batches consume a lot of memory,
  • loss function: loss_type (a string),
  • checkpoint and log locations: checkpoint_dir.

Run training with:

python -m explainability.explainable_by_design.training_WeatherCBM_with_reg_on_vars

Generated checkpoints and logs are saved under checkpoints/weathercbm/, with a separate subfolder created according to the name you specify. Make sure to change the checkpoint location if you change other parameters (such as lead time or sequence length); otherwise a previous checkpoint may be overwritten.

Interpretation

The folder explainability/explainable_by_design/explain_results contains scripts to interpret the model:

  • integrated_gradients visualizes which input variables contribute most to each concept,
  • predict_concept_activation saves maps of the concept activations for a specific sample,
  • analysis_regularization gives access to the model's matrix A when trained with regularization on the input variables, i.e. the importance of each input variable for each concept.

🀝 Authors

This repository was created and contributed to equally by:

⭐ Acknowledgment

A huge thank you to Florestan Fontaine from HeadMind Partners for his help and valuable advice.

πŸ“„ License

This project is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).

You are permitted to use, share, and adapt the material for non-commercial purposes, provided that appropriate credit is given to the original authors.

Commercial use of this work is strictly prohibited without prior written permission from the authors.

For full license terms, see: https://creativecommons.org/licenses/by-nc/4.0/
