Add ConvGRU ensemble training pipeline #10

franchg wants to merge 56 commits into mlcast-community:main from
Conversation
Transfer the training-ready ConvGRU ensemble model from ConvGRU-Ensemble into mlcast, enabling participants at the hackathon to train models on sampled radar datasets.

New modules:
- `losses.py`: CRPS, afCRPS, MaskedLoss, and `build_loss()` factory
- `utils.py`: rain rate <-> normalized reflectivity conversions
- `models/convgru.py`: `RadarLightningModel` with ensemble support
- `configs.py`: `@auto_config` experiment factory (weatherduck pattern)
- `__main__.py`: CLI entry point for training

Updated modules:
- `modules/convgru_modules.py`: ensemble generation via noisy decoder
- `data/zarr_dataset.py`: `SampledRadarDataset` using CSV coordinates
- `data/zarr_datamodule.py`: `RadarDataModule` with chronological splits
- `pyproject.toml`: add fiddle, pandas, pytorch-lightning, torchvision deps
Replace argparse with Fiddle's absl_flags integration so that any
config parameter can be overridden from the command line using
--config set:key.path=value syntax. This is essential for the
hackathon where participants need to iterate quickly on HPC.
Usage:
python -m mlcast \
--config config:convgru_experiment \
--config set:data.zarr_path=/path/to/data.zarr \
--config set:data.csv_path=/path/to/sampled.csv \
--config set:data.batch_size=32 \
--config set:pl_module.num_blocks=4 \
--config set:trainer.max_epochs=50
Also adds absl-py, etils, importlib-resources as dependencies.
Restructure CLI to use subcommands so future commands (test, predict) can be added alongside train.
Adds [project.scripts] so 'uv run mlcast train' works alongside 'python -m mlcast train'.
- Create `tests/data/test_normalization.py` to verify the symmetry of `rainrate_to_normalized` and `normalized_to_rainrate`.
- Create `tests/test_losses.py` to verify the expected tensor output shapes for `CRPS` and `afCRPS` across different reduction modes.

This establishes the initial test coverage required for Phase 1 before refactoring the utilities and loss modules.
- Rename `src/mlcast/utils.py` to `src/mlcast/data/normalization.py`
- Update all references in `zarr_dataset.py`, `convgru.py`, and tests
- Create `NORMALIZATION_REGISTRY` in `normalization.py` to map CF standard names to their normalization functions.
- Add test to verify the registry mapping.
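A registry of this shape could look like the following minimal sketch. The function names `rainrate_to_normalized`/`normalized_to_rainrate` come from the commits in this PR, but their bodies here are placeholders (the real conversions use Marshall-Palmer), and the `get_normalizer` lookup helper is hypothetical:

```python
from typing import Callable, Dict

def rainrate_to_normalized(rainrate: float) -> float:
    # Placeholder body; the real function converts rain rate (mm/h)
    # to normalized reflectivity via the Marshall-Palmer relationship.
    return rainrate / 100.0

def normalized_to_rainrate(norm: float) -> float:
    # Placeholder inverse of the function above.
    return norm * 100.0

# Map CF standard names to their normalization / denormalization functions.
NORMALIZATION_REGISTRY: Dict[str, Callable[[float], float]] = {
    "rainfall_rate": rainrate_to_normalized,
}
DENORMALIZATION_REGISTRY: Dict[str, Callable[[float], float]] = {
    "rainfall_rate": normalized_to_rainrate,
}

def get_normalizer(standard_name: str) -> Callable[[float], float]:
    # Hypothetical lookup helper with a clear error for unknown names.
    try:
        return NORMALIZATION_REGISTRY[standard_name]
    except KeyError:
        raise ValueError(
            f"No normalization registered for {standard_name!r}; "
            f"known names: {sorted(NORMALIZATION_REGISTRY)}"
        )
```

The registry test described above would then simply assert that each CF standard name round-trips through its normalize/denormalize pair.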
- Create `src/mlcast/visualization.py` to house visualization utilities.
- Move `apply_radar_colormap` and `log_images` out of `convgru.py`.
- Update `RadarLightningModel` to use the extracted `log_images`.
- Rename `afCRPS` to `AFCRPS` to follow naming conventions.
- Update `build_loss` signature to explicitly default to `loss_class="mse"`.
- Add type checking in `build_loss` to ensure `loss_class` is a string or class.
- Use a dedicated `LossClass` variable internally for instantiation.
- Expand docstrings to include explicit expected tensor shapes.
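The `build_loss` behaviour described above could be sketched as follows. The registry contents and class bodies are illustrative stand-ins (plain classes rather than `nn.Module`s, so the sketch stays dependency-free); only the string-or-class dispatch and the `"mse"` default reflect the commit:

```python
class MSELoss:
    # Illustrative stand-in for torch.nn.MSELoss.
    def __call__(self, pred, target):
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

class AFCRPS:
    # Placeholder; the real AFCRPS operates on ensemble tensors.
    def __call__(self, pred, target):
        return 0.0

LOSS_REGISTRY = {"mse": MSELoss, "afcrps": AFCRPS}

def build_loss(loss_class="mse", **kwargs):
    # Accept either a registry key (string) or a class directly;
    # reject anything else early with a clear TypeError.
    if isinstance(loss_class, str):
        LossClass = LOSS_REGISTRY[loss_class.lower()]
    elif isinstance(loss_class, type):
        LossClass = loss_class
    else:
        raise TypeError(
            f"loss_class must be a string or a class, got {type(loss_class)!r}"
        )
    return LossClass(**kwargs)
```

Using a dedicated `LossClass` variable keeps the two dispatch branches symmetric: both end with a single `LossClass(**kwargs)` instantiation.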
- Update `__main__.py` to use argparse subparsers and default to `training_experiment`.
- Rename `convgru_experiment` to `training_experiment` in `configs.py`.
- Update TensorBoardLogger default name to `mlcast`.
- Defer documentation updates to Phase 7 to align with implementation timeline.
…dule

- Rename `src/mlcast/models/convgru.py` to `base.py`.
- Refactor `RadarLightningModel` into a generic `NowcastLightningModule` that accepts an injected PyTorch `nn.Module`.
- Extract the core `EncoderDecoder` logic into a separate file (`src/mlcast/modules/convgru_modules.py`) and rename it to `ConvGruModel`.
- Update `src/mlcast/configs.py` to use Fiddle to inject `ConvGruModel` into `NowcastLightningModule`.

This completes Phase 2A of the restructuring plan, establishing a clean separation of concerns between the training orchestrator and the underlying neural network architecture.
- Add `jaxtyping` and `beartype` dependencies for static and runtime type checking.
- Add rigorous NumPy-style docstrings to all methods in `base.py`.
- Decorate PyTorch module `forward()` methods with `@jaxtyped(typechecker=beartype)` to enforce shape constraints at runtime.
- Fix static type errors identified by `mypy`.
- Add `storage_options` to dataset factory classes to support reading Zarr stores anonymously from S3 object storage
- Create `use_anon_s3_dataset` fiddler in `mlcast.config.fiddlers` to easily configure the dataset factory for remote AWS connection strings
- Revert default experiment configuration to use local dummy paths and `rainfall_rate` rather than hardcoding the Italian S3 dataset
- Implement graceful CF `standard_name` validation inside dataset classes to intercept `cf_xarray` KeyErrors. When a requested variable isn't found, emit a clear `ValueError` listing all available valid CF standard names in the dataset, along with CLI hints on how to select them.
- Add `use_anon_s3_dataset` usage example to dynamic CLI help text.
- Create standalone script at `examples/scripts/download_mlcast_dataset_sample.py` to download temporal slices of remote datasets.
- Use `mlcast_datasets.open_catalog()` and Intake to dynamically resolve remote paths.
- Support dot-notation catalog traversal (e.g. `precipitation.radklim_hourly`).
- Automatically mirror remote dataset directory structure onto the local filesystem, using `mode='w'` to safely overwrite existing local caches.
- Provide a Dask progress bar to show download/writing status.
- Include an optional `--data-stage` argument (defaulting to `source_data`) designed to validate the stage but temporarily bypass traversal until the catalog structure updates to support it as the root node.
Adds optional `gpu-cu128` and `gpu-cu130` extras so users can opt into CUDA-enabled torch builds via `uv sync --extra gpu-cu128/cu130`, while keeping CPU torch as the default. Updates README with install instructions.
- Promote nvidia-ml-py and psutil to core dependencies (required for MLflow system metrics monitoring)
- Add `LogSystemInfoCallback`: logs system/git tags as MLflow run tags, starts `SystemMetricsMonitor` manually, prints run URL at train start
- Fix GPU warning suppression: patch `gpu_monitor._logger` immediately before `SystemMetricsMonitor.start()` to avoid being reset by the `mlflow.__init__` dictConfig call
- Add `use_mlflow_logger()` fiddler: swaps `TensorBoardLogger` for `MLFlowLogger`, appends `LogSystemInfoCallback` to trainer callbacks
- Export `use_mlflow_logger` from `config/__init__.py`
- `log_images()` now dispatches on `TensorBoardLogger` vs `MLFlowLogger`, converting tensors to PIL for MLflow's `log_image` API
- Pass `self.logger` (not `self.logger.experiment`) to `log_images()`
Add rainfall_amount_5min_to_normalized and normalized_to_rainfall_amount_5min, converting 5-minute accumulated rainfall (kg m-2 = mm) to/from normalized reflectivity via Marshall-Palmer Z-R relationship using a 1/12 h accumulation factor. Register under 'rainfall_amount' in NORMALIZATION_REGISTRY and DENORMALIZATION_REGISTRY.
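The conversion chain described above could be sketched like this. The Marshall-Palmer coefficients (a=200, b=1.6) and the 1/12 h accumulation factor follow standard practice; the dBZ normalization range (0 to 55 dBZ) and the floor on the rain rate are assumptions for illustration, not taken from the PR:

```python
import math

# Marshall-Palmer Z-R relationship: Z = a * R**b
MP_A, MP_B = 200.0, 1.6
# A 5-minute accumulation covers 1/12 of an hour.
ACCUMULATION_HOURS = 1.0 / 12.0
# Assumed dBZ range used for normalization (illustrative).
DBZ_MIN, DBZ_MAX = 0.0, 55.0

def rainfall_amount_5min_to_normalized(amount_mm: float) -> float:
    """Convert a 5-min rainfall accumulation (kg m-2 == mm) to
    normalized reflectivity in [0, 1]."""
    rate_mm_h = amount_mm / ACCUMULATION_HOURS  # mm over 1/12 h -> mm/h
    rate_mm_h = max(rate_mm_h, 1e-4)            # avoid log10(0) for dry pixels
    z = MP_A * rate_mm_h ** MP_B                # Marshall-Palmer Z-R
    dbz = 10.0 * math.log10(z)
    return min(max((dbz - DBZ_MIN) / (DBZ_MAX - DBZ_MIN), 0.0), 1.0)

def normalized_to_rainfall_amount_5min(norm: float) -> float:
    """Inverse conversion: normalized reflectivity -> 5-min accumulation."""
    dbz = norm * (DBZ_MAX - DBZ_MIN) + DBZ_MIN
    z = 10.0 ** (dbz / 10.0)
    rate_mm_h = (z / MP_A) ** (1.0 / MP_B)
    return rate_mm_h * ACCUMULATION_HOURS
```

For values inside the assumed dBZ range the two functions are exact inverses, which is the symmetry property the Phase 1 tests check.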
Replace check_unquoted_fiddle_strings() (warn-only) with auto_quote_fiddle_strings() which automatically wraps bare string values in single quotes before passing to absl/Fiddle. Fiddle uses ast.literal_eval to parse override values, so unquoted strings would otherwise cause a parse error.
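The core of this auto-quoting behaviour can be sketched with `ast.literal_eval` itself acting as the validity check. This is a simplified, per-value version (the real implementation operates on full `--config set:` arguments):

```python
import ast

def auto_quote_fiddle_strings(value: str) -> str:
    """Wrap bare strings in single quotes so that ast.literal_eval-based
    override parsing (as used by Fiddle) accepts them."""
    try:
        # Valid Python literals (numbers, lists, quoted strings, booleans,
        # ...) parse cleanly and are passed through unchanged.
        ast.literal_eval(value)
        return value
    except (ValueError, SyntaxError):
        # A bare word like /path/to/data.zarr or relu raises here;
        # quoting turns it into a valid string literal.
        return f"'{value}'"
```

So `set:data.zarr_path=/path/to/data.zarr` would be rewritten to `set:data.zarr_path='/path/to/data.zarr'` before reaching absl/Fiddle, while numeric overrides like `set:data.batch_size=32` pass through untouched.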
logging.config.dictConfig (called by mlflow.__init__) resets logger levels on existing loggers but does not clear their filters. Replacing setLevel(ERROR) with a logging.Filter subclass makes the suppression robust to any subsequent dictConfig calls regardless of timing.
The approach of patching mlflow's gpu_monitor logger (both via setLevel and logging.Filter) did not reliably suppress the warnings in practice. Removing the dead code for now.
Calling torch.set_float32_matmul_precision('high') at training startup
silences PyTorch's warning about underutilised Tensor Cores and enables
TF32 for matrix multiplications, improving throughput with negligible
impact on training precision.
…owcastLightningModule
…g.yaml

Introduces `load_yaml_config()` (`config/loader.py`), which deserialises a Fiddle YAML dump back into an `fdl.Config` by mirroring, in reverse, the representers registered by `fiddle._src.experimental.yaml_serialization` (using a custom `_FiddleLoader` subclass of `yaml.SafeLoader`).

The CLI (`mlcast train`) now detects a YAML file passed as `--config`, removes it from the remaining argv, loads it, and seeds the internal state of Fiddle's `FiddleFlag` directly, so that all subsequent `set:` and `fiddler:` overrides are applied by Fiddle's own flag machinery without any custom override logic.

Also fixes MLflow hyperparameter logging: `fdp.as_dict_flattened` produces keys like `trainer.callbacks[0].monitor`, which MLflow rejects. Brackets are replaced with dot notation (e.g. `.0.`) before logging.
…g Path

- Rename fixture to follow `fp_` convention and return `Path` instead of an open dataset, so each test and dataloader worker can open the store independently
- Update all tests using the fixture (`test_fixture.py`, `test_cli_training.py`, `test_source_datasets.py`)
- Add `test_cli_train_from_yaml_config` exercising the new YAML load path
- Update `__main__.py` module docstring to reflect current CLI usage
- Add AGENTS.md documenting project conventions
…CLI, MLflow/WandB config upload, and YAML-based config loading
… CLI help, add config-diagram pre-commit hook
…riables skips input_channels if not in network signature
…ion with Mermaid diagram, mfai adapter example, and README snippet integration tests
…diagram with PNGs
…t classes

- `DatasetSample` TypedDict introduced; `__getitem__` now returns input/target/target_mask
- `target_mask` computed per-timestep per-channel from target before `nan_to_num` (fixes collapsed-mask bug)
- `forecast_steps` removed from `NowcastLightningModule`; `shared_step` derives it from `future.shape[1]`
- `forecast_steps` moved to `dataset_factory` in base config and fiddlers
- Contract 3 (`steps > forecast_steps`) removed from consistency_checks; now enforced by dataset guard
- Old contracts 4/5 renumbered to 3/4; tests updated accordingly
- README: all `pl_module.forecast_steps` references updated to `dataset_factory.forecast_steps`
- Regenerate config diagram to reflect moved `forecast_steps`
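The collapsed-mask fix hinges on ordering: the mask must be captured from the raw target before NaNs are replaced. A minimal sketch (the function name `build_sample` and the numpy-based layout are assumptions; the real logic lives in the dataset classes):

```python
import numpy as np

def build_sample(frames: np.ndarray, input_steps: int) -> dict:
    """Split a (time, channel, y, x) datacube into input/target and compute
    target_mask per-timestep, per-channel from the *raw* target."""
    inputs, target = frames[:input_steps], frames[input_steps:]
    # The mask must come from the raw target: computing it after
    # nan_to_num would mark every cell valid (the collapsed-mask bug).
    target_mask = ~np.isnan(target)
    return {
        "input": np.nan_to_num(inputs),
        "target": np.nan_to_num(target),
        "target_mask": target_mask,
    }
```

With this layout `shared_step` no longer needs a `forecast_steps` hyperparameter at all: it can read the forecast length straight off `sample["target"].shape[0]` (or `future.shape[1]` once batched).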
- `SourceDataDatasetBase` extracts all shared logic: `__init__`, `input_steps` property, `ds` property, `_validate_standard_names`, `_apply_augmentations`, and `_build_sample`
- `_build_sample` centralises post-isel processing: mask capture, `nan_to_num`, input/target split, augmentations, and `DatasetSample` assembly
- `DatasetSample` constructed without mask first; `target_mask` added conditionally, eliminating duplicated if/else branches
- Both subclass `__getitem__` reduced to isel slicing + `_build_sample` call
- `_detect_axes` extracted as module-private free function taking `(ds, standard_name)`; sets `t_dim`/`y_dim`/`x_dim` via return value rather than side effects; `stacklevel=3` accounts for the extra call frame
… steps becomes a property

- `input_steps` + `forecast_steps` replace `steps` as the two primary constructor params on `SourceDataDatasetBase` and both subclasses; `steps` is now a `@property` returning `input_steps + forecast_steps`
- Guards updated: `input_steps < 1` and `forecast_steps < 1` raise `ValueError`
- `config/base.py`: `steps=18, forecast_steps=12` -> `input_steps=6, forecast_steps=12`
- `config/fiddlers.py`: `use_random_sampler` forwards `input_steps` instead of `steps`
- README: `past_steps` renamed to `input_steps` throughout; mfai `HalfUNetNowcaster` example updated to accept `input_steps`/`num_vars` and compute `in_channels` internally; `einops.rearrange` used for channel-stacking with explanatory comments; config site reads `cfg.data.dataset_factory.input_steps` directly
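The constructor contract described above boils down to a small pattern (class name here is a stand-in for `SourceDataDatasetBase`; only the parameter names, guards, and derived property reflect the commit):

```python
class SourceDataDatasetSketch:
    """input_steps and forecast_steps are the primary parameters;
    steps is derived, never stored."""

    def __init__(self, input_steps: int, forecast_steps: int):
        # Guards from the commit: both windows must be at least one step.
        if input_steps < 1:
            raise ValueError("input_steps must be >= 1")
        if forecast_steps < 1:
            raise ValueError("forecast_steps must be >= 1")
        self.input_steps = input_steps
        self.forecast_steps = forecast_steps

    @property
    def steps(self) -> int:
        # Total window length read from the datacube per sample.
        return self.input_steps + self.forecast_steps
```

Deriving `steps` rather than storing it removes the old consistency contract (`steps > forecast_steps`) entirely: the invariant holds by construction.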
Ok, @franchg I think I'm done with my refactor 🥳 I hope you like what I've done to your code 🎁

➡️ https://github.com/leifdenby/mlcast/tree/feat/convgru-ensemble-training

I've (with my agent's help) written up this list of changes (below). I went through this by hand and added/removed detail, so I think it gives quite a complete overview. As well as reading the list below (hopefully it makes sense...), if you could also read through the updated README that would be great. I tried to take care to make it clear how the CLI now works, what the code structure is, and what ConvGRU (our first architecture) is. In the README I have also added an example of how to couple in a new architecture not yet in the

If you are happy with these changes my suggestion is that:
We can then build on this when we discover things I have overlooked :)

**Changes**

**What Changed**
**Why**

All these changes were motivated by a desire to:
**What This Now Enables**

With this, it is now:
In addition, support for MLflow and remote S3 Zarr stores was convenient for my own experimentation, but I thought that enabling this flexibility should make it easier to add more choices here in future.
Feat/convgru ensemble training
Summary
- `@auto_config` pattern (as discussed in "Designing configuration infrastructure for mlcast python package" #5 and used in weatherduck)

New files
- `losses.py`: CRPS, afCRPS, MaskedLoss, and `build_loss()` factory
- `utils.py`: rain rate ↔ normalized reflectivity conversions (Marshall-Palmer)
- `models/convgru.py`: `RadarLightningModel` with ensemble support, TensorBoard image logging, and inference API
- `configs.py`: `@auto_config` experiment factory following the weatherduck pattern
- `__main__.py`: CLI entry point (`python -m mlcast --zarr-path ... --csv-path ...`)

Updated files
- `modules/convgru_modules.py`: added ensemble generation via noisy decoder inputs
- `data/zarr_dataset.py`: `SampledRadarDataset` loading datacubes from Zarr using CSV `(t, x, y)` coordinates
- `data/zarr_datamodule.py`: `RadarDataModule` with chronological train/val/test splits and augmentation
- `pyproject.toml`: added `fiddle`, `pandas`, `pytorch-lightning`, `torchvision` dependencies

Usage (hackathon)
Closes #7 (partially — brings ensemble ConvGRU into mlcast)
Implements #5 (Fiddle configuration)
Test plan
- `uv pip install -e .` and verify imports work
- `python -m mlcast --help` to verify CLI