
Add ConvGRU ensemble training pipeline#10

Open
franchg wants to merge 56 commits into mlcast-community:main from franchg:feat/convgru-ensemble-training

Conversation

@franchg
Member

@franchg franchg commented Mar 26, 2026

Summary

New files

  • losses.py — CRPS, afCRPS, MaskedLoss, and build_loss() factory
  • utils.py — rain rate ↔ normalized reflectivity conversions (Marshall-Palmer)
  • models/convgru.py — RadarLightningModel with ensemble support, TensorBoard image logging, and inference API
  • configs.py — @auto_config experiment factory following the weatherduck pattern
  • __main__.py — CLI entry point (python -m mlcast --zarr-path ... --csv-path ...)

Updated files

  • modules/convgru_modules.py — added ensemble generation via noisy decoder inputs
  • data/zarr_dataset.py — SampledRadarDataset loading datacubes from Zarr using CSV (t, x, y) coordinates
  • data/zarr_datamodule.py — RadarDataModule with chronological train/val/test splits and augmentation
  • pyproject.toml — added fiddle, pandas, pytorch-lightning, torchvision dependencies
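The rain rate ↔ normalized reflectivity conversion mentioned above follows the Marshall-Palmer Z-R relationship (Z = a·R^b). A minimal numpy sketch; the constants a = 200, b = 1.6 and the 55 dBZ normalization ceiling are assumptions for illustration and may differ from what mlcast's utils.py actually uses:

```python
import numpy as np

# Assumed Marshall-Palmer constants and normalization ceiling (illustrative).
A, B = 200.0, 1.6
DBZ_MAX = 55.0


def rainrate_to_normalized(rr_mm_h):
    """Rain rate (mm/h) -> reflectivity normalized to [0, 1]."""
    z = A * np.maximum(rr_mm_h, 1e-4) ** B   # linear reflectivity, floor avoids log(0)
    dbz = 10.0 * np.log10(z)                 # to dBZ
    return np.clip(dbz, 0.0, DBZ_MAX) / DBZ_MAX


def normalized_to_rainrate(norm):
    """Inverse conversion: normalized reflectivity -> rain rate (mm/h)."""
    dbz = norm * DBZ_MAX
    z = 10.0 ** (dbz / 10.0)
    return (z / A) ** (1.0 / B)
```

The two functions are symmetric for values inside the clipping range, which is exactly the property the PR's normalization tests check.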

Usage (hackathon)

# 1. Data is already on Leonardo as Zarr
# 2. Run sampler to get CSV:
#    mlcast-dataset-sampler filter-nan ... && mlcast-dataset-sampler sample ...
# 3. Train:
python -m mlcast --zarr-path /path/to/data.zarr --csv-path /path/to/sampled.csv --variable-name RR

# Or programmatically with config overrides:
from mlcast.configs import convgru_experiment
import fiddle as fdl

cfg = convgru_experiment.as_buildable(zarr_path="...", csv_path="...", variable_name="RR")
cfg.data.batch_size = 32
cfg.trainer.max_epochs = 50
fdl.build(cfg).run()

Closes #7 (partially — brings ensemble ConvGRU into mlcast)
Implements #5 (Fiddle configuration)

Test plan

  • Install with uv pip install -e . and verify imports work
  • Run python -m mlcast --help to verify CLI
  • Train on a small Zarr + CSV on Leonardo to verify end-to-end pipeline
  • Verify TensorBoard logging works

franchg added 6 commits March 26, 2026 13:02
Transfer the training-ready ConvGRU ensemble model from ConvGRU-Ensemble
into mlcast, enabling participants at the hackathon to train models on
sampled radar datasets.

New modules:
- losses.py: CRPS, afCRPS, MaskedLoss, and build_loss() factory
- utils.py: rain rate <-> normalized reflectivity conversions
- models/convgru.py: RadarLightningModel with ensemble support
- configs.py: @auto_config experiment factory (weatherduck pattern)
- __main__.py: CLI entry point for training

Updated modules:
- modules/convgru_modules.py: ensemble generation via noisy decoder
- data/zarr_dataset.py: SampledRadarDataset using CSV coordinates
- data/zarr_datamodule.py: RadarDataModule with chronological splits
- pyproject.toml: add fiddle, pandas, pytorch-lightning, torchvision deps
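The chronological train/val/test splitting in RadarDataModule can be sketched as simple index partitioning over time-ordered samples. The split fractions below are assumptions for illustration, not mlcast's actual defaults:

```python
def chronological_split(n_samples, train_frac=0.8, val_frac=0.1):
    """Partition time-ordered sample indices into contiguous splits.

    Because samples are chronological, slicing (rather than shuffling)
    keeps validation and test strictly after training in time.
    """
    n_train = int(n_samples * train_frac)
    n_val = int(n_samples * val_frac)
    train = range(0, n_train)
    val = range(n_train, n_train + n_val)
    test = range(n_train + n_val, n_samples)
    return train, val, test
```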
Replace argparse with Fiddle's absl_flags integration so that any
config parameter can be overridden from the command line using
--config set:key.path=value syntax. This is essential for the
hackathon where participants need to iterate quickly on HPC.

Usage:
  python -m mlcast \
      --config config:convgru_experiment \
      --config set:data.zarr_path=/path/to/data.zarr \
      --config set:data.csv_path=/path/to/sampled.csv \
      --config set:data.batch_size=32 \
      --config set:pl_module.num_blocks=4 \
      --config set:trainer.max_epochs=50

Also adds absl-py, etils, importlib-resources as dependencies.
Restructure CLI to use subcommands so future commands (test, predict)
can be added alongside train.
Adds [project.scripts] so 'uv run mlcast train' works alongside
'python -m mlcast train'.
@franchg franchg requested a review from leifdenby April 1, 2026 12:27
- Create `tests/data/test_normalization.py` to verify the symmetry of
  `rainrate_to_normalized` and `normalized_to_rainrate`.
- Create `tests/test_losses.py` to verify the expected tensor output shapes
  for `CRPS` and `afCRPS` across different reduction modes.

This establishes the initial test coverage required for Phase 1
before refactoring the utilities and loss modules.
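For reference, the ensemble CRPS tested here is commonly estimated as E|X − y| − ½·E|X − X′| over ensemble members X. A framework-agnostic numpy sketch; mlcast's losses.py operates on torch tensors and its exact reduction modes and signatures may differ:

```python
import numpy as np


def crps_ensemble(preds, target, reduction="mean"):
    """Empirical CRPS estimator: E|X - y| - 0.5 * E|X - X'|.

    preds: (ensemble, ...) array of members; target: (...) observation.
    """
    skill = np.abs(preds - target[None]).mean(axis=0)            # accuracy term
    spread = np.abs(preds[None] - preds[:, None]).mean(axis=(0, 1))  # ensemble spread
    crps = skill - 0.5 * spread
    if reduction == "mean":
        return crps.mean()
    return crps  # reduction == "none": per-element CRPS
```

For a single-member ensemble the spread term vanishes and CRPS reduces to the MAE, a useful sanity check for shape/reduction tests like the ones added in this commit.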
- Rename `src/mlcast/utils.py` to `src/mlcast/data/normalization.py`
- Update all references in `zarr_dataset.py`, `convgru.py`, and tests
- Create `NORMALIZATION_REGISTRY` in `normalization.py` to map CF standard names to their normalization functions.
- Add test to verify the registry mapping.
- Create `src/mlcast/visualization.py` to house visualization utilities.
- Move `apply_radar_colormap` and `log_images` out of `convgru.py`.
- Update `RadarLightningModel` to use the extracted `log_images`.
- Rename `afCRPS` to `AFCRPS` to follow naming conventions.
- Update `build_loss` signature to explicitly default to `loss_class="mse"`.
- Add type checking in `build_loss` to ensure `loss_class` is a string or class.
- Use a dedicated `LossClass` variable internally for instantiation.
- Expand docstrings to include explicit expected tensor shapes.
- Update `__main__.py` to use argparse subparsers and default to `training_experiment`.
- Rename `convgru_experiment` to `training_experiment` in `configs.py`.
- Update TensorBoardLogger default name to `mlcast`.
- Defer documentation updates to Phase 7 to align with implementation timeline.
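The string-or-class dispatch described for build_loss can be sketched as below. The registry contents and the stand-in MSELoss class are assumptions to keep the sketch dependency-free; the real factory instantiates torch losses and its signature may differ:

```python
class MSELoss:
    """Stand-in for torch.nn.MSELoss so this sketch has no dependencies."""

    def __call__(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


LOSS_REGISTRY = {"mse": MSELoss}  # hypothetical registry contents


def build_loss(loss_class="mse", **kwargs):
    # Accept either a registry key (str) or a loss class directly,
    # and reject anything else with a clear error.
    if isinstance(loss_class, str):
        LossClass = LOSS_REGISTRY[loss_class.lower()]
    elif isinstance(loss_class, type):
        LossClass = loss_class
    else:
        raise TypeError(f"loss_class must be a string or class, got {type(loss_class)!r}")
    return LossClass(**kwargs)
```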
…dule

- Rename `src/mlcast/models/convgru.py` to `base.py`.
- Refactor `RadarLightningModel` into a generic `NowcastLightningModule` that accepts an injected PyTorch `nn.Module`.
- Extract the core `EncoderDecoder` logic into a separate file (`src/mlcast/modules/convgru_modules.py`) and rename it to `ConvGruModel`.
- Update `src/mlcast/configs.py` to use Fiddle to inject `ConvGruModel` into `NowcastLightningModule`.

This completes Phase 2A of the restructuring plan, establishing a clean separation of concerns between the training orchestrator and the underlying neural network architecture.
- Add `jaxtyping` and `beartype` dependencies for static and runtime type checking.
- Add rigorous NumPy-style docstrings to all methods in `base.py`.
- Decorate PyTorch module `forward()` methods with `@jaxtyped(typechecker=beartype)` to enforce shape constraints at runtime.
- Fix static type errors identified by `mypy`.
- Add `storage_options` to dataset factory classes to support reading Zarr stores anonymously from S3 object storage
- Create `use_anon_s3_dataset` Fiddler in `mlcast.config.fiddlers` to easily configure the dataset factory for remote AWS connection strings
- Revert default experiment configuration to use local dummy paths and `rainfall_rate` rather than hardcoding the Italian S3 dataset
- Implement graceful CF `standard_name` validation inside dataset classes to intercept `cf_xarray` KeyErrors. When a requested variable isn't found, emit a clear `ValueError` listing all available valid CF standard names in the dataset, along with CLI hints on how to select them.
- Add `use_anon_s3_dataset` usage example to dynamic CLI help text.
- Create standalone script at `examples/scripts/download_mlcast_dataset_sample.py` to download temporal slices of remote datasets.
- Use `mlcast_datasets.open_catalog()` and Intake to dynamically resolve remote paths.
- Support dot-notation catalog traversal (e.g. `precipitation.radklim_hourly`).
- Automatically mirror remote dataset directory structure onto the local filesystem, using `mode='w'` to safely overwrite existing local caches.
- Provide a Dask progress bar to show download/writing status.
- Include an optional `--data-stage` argument (defaulting to `source_data`) designed to validate the stage but temporarily bypass traversal until the catalog structure updates to support it as the root node.
Adds optional `gpu-cu128` and `gpu-cu130` extras so users can opt into
CUDA-enabled torch builds via `uv sync --extra gpu-cu128/cu130`, while
keeping CPU torch as the default. Updates README with install instructions.
- Promote nvidia-ml-py and psutil to core dependencies (required for
  MLflow system metrics monitoring)
- Add LogSystemInfoCallback: logs system/git tags as MLflow run tags,
  starts SystemMetricsMonitor manually, prints run URL at train start
- Fix GPU warning suppression: patch gpu_monitor._logger immediately
  before SystemMetricsMonitor.start() to avoid being reset by
  mlflow.__init__ dictConfig call
- Add use_mlflow_logger() fiddler: swaps TensorBoardLogger for
  MLFlowLogger, appends LogSystemInfoCallback to trainer callbacks
- Export use_mlflow_logger from config/__init__.py
- log_images() now dispatches on TensorBoardLogger vs MLFlowLogger,
  converting tensors to PIL for MLflow's log_image API
- Pass self.logger (not self.logger.experiment) to log_images()
Add rainfall_amount_5min_to_normalized and normalized_to_rainfall_amount_5min,
converting 5-minute accumulated rainfall (kg m-2 = mm) to/from normalized
reflectivity via Marshall-Palmer Z-R relationship using a 1/12 h accumulation
factor. Register under 'rainfall_amount' in NORMALIZATION_REGISTRY and
DENORMALIZATION_REGISTRY.
Replace check_unquoted_fiddle_strings() (warn-only) with
auto_quote_fiddle_strings() which automatically wraps bare string values
in single quotes before passing to absl/Fiddle. Fiddle uses
ast.literal_eval to parse override values, so unquoted strings would
otherwise cause a parse error.
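The auto-quoting behaviour can be sketched with ast.literal_eval: if the override value already parses as a Python literal it is passed through unchanged, otherwise it is wrapped in quotes. This is illustrative only; the real auto_quote_fiddle_strings() may handle more cases:

```python
import ast


def auto_quote(value: str) -> str:
    """Wrap bare strings in quotes so Fiddle's literal_eval parse succeeds."""
    try:
        ast.literal_eval(value)  # already a valid literal (int, list, quoted str, ...)
        return value
    except (ValueError, SyntaxError):
        return f"'{value}'"      # bare string: quote it
```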
logging.config.dictConfig (called by mlflow.__init__) resets logger
levels on existing loggers but does not clear their filters. Replacing
setLevel(ERROR) with a logging.Filter subclass makes the suppression
robust to any subsequent dictConfig calls regardless of timing.
The approach of patching mlflow's gpu_monitor logger (both via setLevel
and logging.Filter) did not reliably suppress the warnings in practice.
Removing the dead code for now.
Calling torch.set_float32_matmul_precision('high') at training startup
silences PyTorch's warning about underutilised Tensor Cores and enables
TF32 for matrix multiplications, improving throughput with negligible
impact on training precision.
…g.yaml

Introduces load_yaml_config() (config/loader.py) which deserialises a
Fiddle YAML dump back into a fdl.Config by mirroring the representers
registered by fiddle._src.experimental.yaml_serialization in reverse
(using a custom _FiddleLoader subclass of yaml.SafeLoader).

The CLI (mlcast train) now detects a YAML file passed as --config,
removes it from the remaining argv, loads it, and seeds the internal
state of Fiddle's FiddleFlag directly so that all subsequent set: and
fiddler: overrides are applied by Fiddle's own flag machinery without
any custom override logic.

Also fixes MLflow hyperparameter logging: fdp.as_dict_flattened produces
keys like trainer.callbacks[0].monitor which MLflow rejects. Brackets
are replaced with dot notation (e.g. .0.) before logging.
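The bracket-to-dot key rewrite can be done with a single regex substitution. A sketch of the fix described above; the helper name is hypothetical:

```python
import re


def mlflow_safe_keys(params: dict) -> dict:
    """Replace list indexing like 'callbacks[0]' with MLflow-safe dot notation."""
    return {re.sub(r"\[(\d+)\]", r".\1", k): v for k, v in params.items()}
```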
…g Path

- Rename fixture to follow fp_ convention and return Path instead of
  an open dataset, so each test and dataloader worker can open the
  store independently
- Update all tests using the fixture (test_fixture.py,
  test_cli_training.py, test_source_datasets.py)
- Add test_cli_train_from_yaml_config exercising the new YAML load path
- Update __main__.py module docstring to reflect current CLI usage
- Add AGENTS.md documenting project conventions
…CLI, MLflow/WandB config upload, and YAML-based config loading
… CLI help, add config-diagram pre-commit hook
…riables skips input_channels if not in network signature
…ion with Mermaid diagram, mfai adapter example, and README snippet integration tests
…t classes

- DatasetSample TypedDict introduced; __getitem__ now returns input/target/target_mask
- target_mask computed per-timestep per-channel from target before nan_to_num (fixes collapsed-mask bug)
- forecast_steps removed from NowcastLightningModule; shared_step derives it from future.shape[1]
- forecast_steps moved to dataset_factory in base config and fiddlers
- Contract 3 (steps > forecast_steps) removed from consistency_checks; now enforced by dataset guard
- Old contracts 4/5 renumbered to 3/4; tests updated accordingly
- README: all pl_module.forecast_steps references updated to dataset_factory.forecast_steps
- Regenerate config diagram to reflect moved forecast_steps
- SourceDataDatasetBase extracts all shared logic: __init__, input_steps
  property, ds property, _validate_standard_names, _apply_augmentations,
  and _build_sample
- _build_sample centralises post-isel processing: mask capture, nan_to_num,
  input/target split, augmentations, and DatasetSample assembly
- DatasetSample constructed without mask first; target_mask added
  conditionally, eliminating duplicated if/else branches
- Both subclass __getitem__ reduced to isel slicing + _build_sample call
- _detect_axes extracted as module-private free function taking (ds,
  standard_name); sets t_dim/y_dim/x_dim via return value rather than
  side effects; stacklevel=3 accounts for the extra call frame
… steps becomes a property

- input_steps + forecast_steps replace steps as the two primary constructor
  params on SourceDataDatasetBase and both subclasses; steps is now a
  @property returning input_steps + forecast_steps
- Guards updated: input_steps < 1 and forecast_steps < 1 raise ValueError
- config/base.py: steps=18, forecast_steps=12 -> input_steps=6, forecast_steps=12
- config/fiddlers.py: use_random_sampler forwards input_steps instead of steps
- README: past_steps renamed to input_steps throughout; mfai HalfUNetNowcaster
  example updated to accept input_steps/num_vars and compute in_channels
  internally; einops.rearrange used for channel-stacking with explanatory
  comments; config site reads cfg.data.dataset_factory.input_steps directly
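The constructor/property split described above can be sketched as follows; names follow the commit message, while the guards and full signature are simplified:

```python
class SourceDataDatasetBase:
    """Sketch: input_steps/forecast_steps as primary params, steps derived."""

    def __init__(self, input_steps: int, forecast_steps: int):
        if input_steps < 1 or forecast_steps < 1:
            raise ValueError("input_steps and forecast_steps must each be >= 1")
        self.input_steps = input_steps
        self.forecast_steps = forecast_steps

    @property
    def steps(self) -> int:
        # Total window length is derived, not stored, so it can never
        # drift out of sync with the two primary parameters.
        return self.input_steps + self.forecast_steps
```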
@leifdenby
Member

leifdenby commented Apr 30, 2026

Ok, @franchg I think I'm done with my refactor 🥳 I hope you like what I've done to your code 🎁 ➡️ https://github.com/leifdenby/mlcast/tree/feat/convgru-ensemble-training

I've (with my agent's help) written up this list of changes (below). I went through this by hand and added/removed detail so I think it gives quite a complete overview. As well as reading the list below (hopefully it makes sense...) if you could also read through the updated README that would be great. I tried to take care to make it clear how the CLI now works, what the code structure is, and what ConvGRU (our first architecture) is. In the README I have also added an example of how to couple in a new architecture not yet in the mlcast codebase (by borrowing the HalfUNet from mfai).

If you are happy with these changes, my suggestion is that:

  • we merge my commits into your branch; I've made a PR here: franchg/mlcast#1 (Feat/convgru ensemble training)
  • we squash merge this PR into main
  • we tag this as v0.1.0 of mlcast and I write a changelog entry explaining the functionality that is there now.

We can then build on this when we discover things I have overlooked :)

Changes

What Changed

  • reworked the config/orchestration layer so dataset, model, trainer, logger, and remote-data choices are expressed as explicit Fiddle config nodes plus fiddlers, by splitting the old config surface into config/base.py, fiddlers.py, consistency_checks.py, loader.py, and orchestrator.py.
    • The idea here is to have a "base" config (ConvGRU training to a path that doesn't exist, maybe this should change?), and use --config set: to set specific config parameters, or apply fiddlers where multiple parameters need to be changed together (for example the list of variables to load from Zarr and the number of input channels in the model)
    • once the config is built we then apply consistency checks that try to ensure the config is self-consistent (is this dataset compatible with this model?), and then finally build the Experiment object that actually calls the Trainer.fit() training-loop start method.
    • I also made it so that the base config can be replaced either through 1) naming an alternative config (we might want to create base configs for different models in future) or 2) passing in a path to a .yaml file, which is then loaded and parsed as a Fiddle Config object.
  • separated the generic training wrapper from the ConvGRU architecture by refactoring the old Lightning model into a NowcastLightningModule that accepts an injected network, and by moving ConvGRU-specific logic into the model layer.
  • pushed more of the temporal contract into the dataset layer by moving forecast_steps out of the Lightning module and making input_steps and forecast_steps the primary dataset parameters, with steps derived from them.
  • made the dataset sample interface more explicit by changing dataset outputs from a single time tensor to a structured sample with input, target, and target_mask.
  • on variable selection: generalized variable selection by switching dataset lookup to CF standard_name via cf_xarray, instead of relying on a broader set of source-specific variable names.
    • made normalization follow the same abstraction by introducing normalization registries keyed by standard_name, so preprocessing is aligned with how variables are selected.
    • improved support for multi-variable inputs by deriving channel-related behavior from standard_names and by updating the custom-network examples to compute dimensions from config instead of assuming a single variable.
  • added a second sampling strategy by introducing SourceDataRandomSamplingDataset alongside the precomputed sampler (useful when not wanting to generate a sampler CSV, and makes space for switching other samplers in future), and then made the data module consume a dataset_factory so either sampler can be injected through config.
  • fixed target masking semantics by computing masks only on the forecast target window and per timestep/per channel, before NaNs are filled, instead of collapsing validity across time. (I think there was a bug here before: the PyTorch Dataset reduced over time to create a mask, whereas the nn.Module.forward() call for the forecasting part did time-based indexing.)
  • introduced use of jaxtyping and beartype to communicate what the shapes of tensors mean, and to check that calls are consistent with this. This will help (I hope) to make the interface to our nn.Module.forward(...) methods clear, so that it will be simpler to add new architectures in future.
  • moved some model-adjacent responsibilities out of the core architecture by separating visualization logic and simplifying loss construction and naming.
  • expanded test coverage across config, data, normalization, losses, CLI, orchestrator, and README snippets, so the refactor is checked through the interfaces people actually use.
  • expanded the documentation substantially by rewriting the README, adding architecture diagrams, adding a config diagram generator/hook, documenting custom model integration, and tightening the install instructions.
  • added MLflow as a config-level logging option by introducing a use_mlflow_logger fiddler plus a dedicated callback for MLflow tags, system metrics, and run URLs, rather than wiring MLflow directly into the training loop.
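The corrected masking semantics (capture a per-timestep, per-channel validity mask from the target before NaNs are filled) can be illustrated in a few lines of numpy; the shapes here are illustrative:

```python
import numpy as np

target = np.zeros((12, 1, 8, 8))   # (time, channel, y, x) forecast window
target[3, 0, 2, 2] = np.nan        # one invalid pixel, at timestep 3 only

target_mask = ~np.isnan(target)            # per-timestep, per-channel validity
target = np.nan_to_num(target, nan=0.0)    # fill NaNs only AFTER capturing the mask

# The old behaviour collapsed validity across time, marking the pixel
# invalid at every timestep instead of just timestep 3:
collapsed = target_mask.all(axis=0)
```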

Why

All these changes were motivated by a desire to:

  • make the system easier to reason about as a config-driven training stack, so the important choices live in the config graph instead of being partly hidden in training code.
  • reduce coupling between dataset shape semantics, forecast horizon handling, model architecture, and Lightning orchestration, because those concerns looked like they were bleeding into each other.
  • make custom architectures easier to plug in, especially non-ConvGRU models, without forcing them to conform to assumptions that I think really belong to one specific model family.
  • make the temporal contract clearer by distinguishing “what the model sees” (input_steps) from “what the model predicts” (forecast_steps) at the dataset/config boundary.
  • make variable handling more robust by using CF standard_name as the common key for selection, validation, and normalization.
  • make data access more portable by supporting both local and anonymous S3-backed Zarr stores through the same dataset/config interface.
  • make experimentation easier by allowing sampler choice, logger choice, and some data/model behaviors to be swapped through fiddlers instead of code edits.
  • make masking behavior more correct and less surprising, especially for future probabilistic or multi-step work where per-timestep target validity matters.
  • make the data layer more extensible by giving both samplers the same interface and shared base behavior, which should lower the cost of adding more dataset variants later.
  • make the project easier for other people to pick up and modify by strengthening tests, docs, examples, and generated diagrams around the refactored interfaces.

What This Now Enables

With this, it is now:

  • easier to plug in non-ConvGRU networks, because the Lightning wrapper is more generic, the validation logic is less tied to ConvGRU-only parameters, and input_steps is directly available on the dataset config.
  • more natural to support multiple variables end-to-end, because variable lookup, normalization, and example model wiring all derive from standard_names.
  • possible to swap in different sampling approaches, because the data module now consumes an injected dataset_factory.

In addition, support for MLflow and remote S3 Zarr stores was convenient for my own experimentation, but I thought that enabling this flexibility should make it easier to add more choices here in future.


Development

Successfully merging this pull request may close these issues.

Improving the ConvGUR baseline (i.e. make ConvGRU an ensemble model)
