From bc6b15327be46da79306438dfffb19ed09588da3 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Wed, 11 Feb 2026 10:31:27 +0100 Subject: [PATCH 01/13] First commit: reworked the code to remove some parts which were not used, and to make the three main components (the autoencoder, the conditioner and the diffuser) distinct entities. Thanks to this, the way the data makes its way in these three components is clearer (see LDCast.ipynb). The files which were taken from the original code without change are: ldcast/utils.py (from ldcast/models/utils.py) ldcast/distributions.py (from ldcast/models/distributions.py) autoenc/auteonc.py autoenc/encoder.py blocks/afno.py blocks/attention.py blocks/resnet.py diffusion/ema.py diffusion/utils.py diffusion/unet.py (which was in the genforecast folder, even though unet is used only in the denoiser) diffusion/utils.py The changes I made are essentially: diffusion/diffusion.py: the LatentDiffusion was given the three parts: model (=denoiser), autoencoder and the context_encoder (= conditioner) and the interactions of these three was not clear. The main class is now DiffusionModel, which needs only the denoiser to be instantiated (the forward call still needs the context given by the conditioner). The interaction between DiffusionModel and the PLMSSampler could be improved (merge the two ?). I removed the ema scope for now, but it should taken care of. diffusion/plms.py: changed only the way the model (denoiser) is called I removed the genforecast folder of the original code: unet.py is now in diffusion, and analysis.py is now context/context.py. The context folder contains nowcast.py (which comes from nowcast/nowcast.py), so that the context folder contains everything to build the conditioner. I reworked the conext/nowcast.py file: I removed the Nowcaster, AFNONowcastNetBasic and the AFNONowcastNet classes (which were not used), and simplified a little the code of two remaining classes (some parts were not used either). The AFNONowcastNetBase class was also taking the autoencoder as input to build the conditioner, which I find very weird (this is why the data seemed to be decoded but not encoded in forecast.Forecast.__call__...). Now, the conditioner is built without the autoencoder. ldcast.py is a new file which will contain the classes subclassing the base classes of mlcast. --- LDCast.ipynb | 243 +++++++++ .../distributions-checkpoint.py | 29 ++ .../.ipynb_checkpoints/ldcast-checkpoint.py | 319 ++++++++++++ .../.ipynb_checkpoints/utils-checkpoint.py | 28 + .../.ipynb_checkpoints/autoenc-checkpoint.py | 96 ++++ .../.ipynb_checkpoints/encoder-checkpoint.py | 57 ++ src/mlcast/models/ldcast/autoenc/autoenc.py | 96 ++++ src/mlcast/models/ldcast/autoenc/encoder.py | 57 ++ .../.ipynb_checkpoints/afno-checkpoint.py | 348 +++++++++++++ .../attention-checkpoint.py | 104 ++++ .../.ipynb_checkpoints/resnet-checkpoint.py | 89 ++++ src/mlcast/models/ldcast/blocks/afno.py | 348 +++++++++++++ src/mlcast/models/ldcast/blocks/attention.py | 104 ++++ src/mlcast/models/ldcast/blocks/resnet.py | 89 ++++ .../.ipynb_checkpoints/context-checkpoint.py | 33 ++ .../.ipynb_checkpoints/nowcast-checkpoint.py | 125 +++++ src/mlcast/models/ldcast/context/context.py | 33 ++ src/mlcast/models/ldcast/context/nowcast.py | 125 +++++ .../diffusion-checkpoint.py | 220 ++++++++ .../.ipynb_checkpoints/ema-checkpoint.py | 76 +++ .../.ipynb_checkpoints/plms-checkpoint.py | 367 +++++++++++++ .../.ipynb_checkpoints/unet-checkpoint.py | 490 ++++++++++++++++++ .../.ipynb_checkpoints/utils-checkpoint.py | 247 +++++++++ .../models/ldcast/diffusion/diffusion.py | 220 ++++++++ src/mlcast/models/ldcast/diffusion/ema.py | 76 +++ src/mlcast/models/ldcast/diffusion/plms.py | 367 +++++++++++++ src/mlcast/models/ldcast/diffusion/unet.py | 490 ++++++++++++++++++ src/mlcast/models/ldcast/diffusion/utils.py | 247 +++++++++ src/mlcast/models/ldcast/distributions.py | 29 ++ src/mlcast/models/ldcast/ldcast.py | 317 +++++++++++ src/mlcast/models/ldcast/utils.py | 28 + 31 files changed, 5497 insertions(+) create mode 100644 LDCast.ipynb create mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py create mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py create mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py create mode 100644 src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py create mode 100644 src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py create mode 100644 src/mlcast/models/ldcast/autoenc/autoenc.py create mode 100644 src/mlcast/models/ldcast/autoenc/encoder.py create mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py create mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py create mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py create mode 100644 src/mlcast/models/ldcast/blocks/afno.py create mode 100644 src/mlcast/models/ldcast/blocks/attention.py create mode 100644 src/mlcast/models/ldcast/blocks/resnet.py create mode 100644 src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py create mode 100644 src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py create mode 100644 src/mlcast/models/ldcast/context/context.py create mode 100644 src/mlcast/models/ldcast/context/nowcast.py create mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py create mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py create mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py create mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py create mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py create mode 100644 src/mlcast/models/ldcast/diffusion/diffusion.py create mode 100644 src/mlcast/models/ldcast/diffusion/ema.py create mode 100644 src/mlcast/models/ldcast/diffusion/plms.py create mode 100644 src/mlcast/models/ldcast/diffusion/unet.py create mode 100644 src/mlcast/models/ldcast/diffusion/utils.py create mode 100644 src/mlcast/models/ldcast/distributions.py create mode 100644 src/mlcast/models/ldcast/ldcast.py create mode 100644 src/mlcast/models/ldcast/utils.py diff --git a/LDCast.ipynb b/LDCast.ipynb new file mode 100644 index 0000000..6ebf429 --- /dev/null +++ b/LDCast.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "aa7c06c7-6229-46bb-a06f-8aa8b03ab250", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "83a840b0-0705-4ece-a93d-6496ec075931", + "metadata": {}, + "outputs": [], + "source": [ + "#from src.mlcast.models.ldcast.ldcast import LDCast, LDCastLightningModule" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5e248d5a-84b4-4a61-9fef-ed4a9c613d5a", + "metadata": {}, + "outputs": [], + "source": [ + "#LDCastLightningModule(nn.Module(), nn.Module())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2a5fbc98-56de-48b5-9afa-3a46882e0a8b", + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e7639166-2162-4f71-9954-0e0d7e01dde8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "take care of ema scope\n" + ] + } + ], + "source": [ + "from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKL\n", + "from src.mlcast.models.ldcast.autoenc.encoder import SimpleConvEncoder, SimpleConvDecoder\n", + "from src.mlcast.models.ldcast.context.context import AFNONowcastNetCascade\n", + "from src.mlcast.models.ldcast.diffusion.diffusion import DiffusionModel" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f0d7ac7d-0672-4a08-b08a-f06bcbb2d28e", + "metadata": {}, + "outputs": [], + "source": [ + "future_timesteps = 20\n", + "autoenc_time_ratio = 4 # number of timesteps encoded in the autoencoder" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "51608d56-c711-4da0-ab92-854be45d12ed", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# setup the different parts of LDCast\n", + "\n", + "# setup forecaster\n", + "conditioner = AFNONowcastNetCascade(\n", + " 32,\n", + " train_autoenc=False,\n", + " output_patches=future_timesteps//autoenc_time_ratio,\n", + " cascade_depth=3,\n", + " embed_dim=128,\n", + " analysis_depth=4\n", + ").to('cuda')\n", + "\n", + "enc = SimpleConvEncoder()\n", + "dec = SimpleConvDecoder()\n", + "autoencoder = AutoencoderKL(enc, dec).to('cuda')\n", + "\n", + "# setup denoiser\n", + "from src.mlcast.models.ldcast.diffusion.unet import UNetModel\n", + "denoiser = UNetModel(in_channels=autoencoder.hidden_width,\n", + " model_channels=256, out_channels=autoencoder.hidden_width,\n", + " num_res_blocks=2, attention_resolutions=(1,2), \n", + " dims=3, channel_mult=(1, 2, 4), num_heads=8,\n", + " num_timesteps=future_timesteps//autoenc_time_ratio,\n", + " # context channels (= analysis_net.cascade_dims)\n", + " context_ch=[128, 256, 512]).to('cuda')\n", + "\n", + "diffuser = DiffusionModel(denoiser).to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d68fe5a1-3c0b-4d83-bac5-e3d64e08d719", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "aea2ed01-350c-48aa-8e43-464e63fd5d6b", + "metadata": {}, + "outputs": [], + "source": [ + "class LDCastLightningModule(nn.ModuleDict):\n", + " def __init__(self, autoencoder, conditioner, diffuser):\n", + " super().__init__({'autoencoder': autoencoder, 'conditioner': conditioner, 'diffuser': diffuser})\n", + "\n", + " def forward(self, x, timesteps):\n", + " \n", + " # encoded is tuple of 3 tensors, but only the first one is used !!\n", + " encoded = self.autoencoder.encode(x) \n", + "\n", + " # condition is a dict of tensors\n", + " condition = conditioner(encoded[0], timesteps)\n", + "\n", + " latent_diffused = diffuser(condition) # tensor\n", + "\n", + " prediction = self.autoencoder.decode(latent_diffused) # tensor\n", + " \n", + " return prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8bd80398-d1c8-4a45-82a5-a31b80e5f02d", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "ldcast = LDCastLightningModule(autoencoder, conditioner, denoiser)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "55930f7c-ef89-4761-b053-39cdd52aba38", + "metadata": {}, + "outputs": [], + "source": [ + "# create fake data\n", + "timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32)\n", + "timesteps = timesteps.unsqueeze(0).expand(1,-1) # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why\n", + "x = torch.randn(1, 1, 4, 256, 256, device = 'cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "43867cb4-3c8a-4851-879c-e62dc7ed96d1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "PLMS Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.84it/s]\n" + ] + } + ], + "source": [ + "prediction = ldcast(x, timesteps)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "3cb247ac-4b87-4013-a591-0e97e8f66413", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 20, 256, 256])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prediction.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "700453df-61b2-4f04-93fc-f631e50d5ec0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py new file mode 100644 index 0000000..b7f68c2 --- /dev/null +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py @@ -0,0 +1,29 @@ +import numpy as np +import torch + + +def kl_from_standard_normal(mean, log_var): + kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var) + return kl.mean() + + +def sample_from_standard_normal(mean, log_var, num=None): + std = (0.5 * log_var).exp() + shape = mean.shape + if num is not None: + # expand channel 1 to create several samples + shape = shape[:1] + (num,) + shape[1:] + mean = mean[:, None, ...] + std = std[:, None, ...] + return mean + std * torch.randn(shape, device=mean.device) + + +def ensemble_nll_normal(ensemble, sample, epsilon=1e-5): + mean = ensemble.mean(dim=1) + var = ensemble.var(dim=1, unbiased=True) + epsilon + logvar = var.log() + + diff = sample[:, None, ...] - mean + logtwopi = np.log(2 * np.pi) + nll = (logtwopi + logvar + diff.square() / var).mean() + return nll \ No newline at end of file diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py new file mode 100644 index 0000000..b49c551 --- /dev/null +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py @@ -0,0 +1,319 @@ +"""LDCast model implementation compliant with mlcast-ldcast structure.""" + +import abc +from pathlib import Path +from typing import Any + +import numpy as np +import pytorch_lightning as L +import torch +import xarray as xr +from torch import nn + +from ..base import NowcastingModelBase, NowcastingLightningModule + + +class LDCastLightningModule(NowcastingLightningModule): + """PyTorch Lightning module for LDCast diffusion model.""" + + def __init__( + self, + net: nn.Module, + loss: nn.Module, + optimizer_class: type | None = None, + optimizer_kwargs: dict | None = None, + **kwargs: Any, + ): + super().__init__( + net=net, + loss=loss, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + **kwargs, + ) + + +class LDCast(NowcastingModelBase): + """LDCast precipitation nowcasting model. + + This model implements a latent diffusion approach for precipitation forecasting, + combining an autoencoder for dimensionality reduction with a diffusion model + for temporal prediction. + + Attributes: + timestep_length: Time resolution of predictions (e.g., 5 minutes) + PLModuleClass: The Lightning module class used for training + """ + + timestep_length: np.timedelta64 | None = None + #PLModuleClass = LDCastLightningModule + + def __init__(self, config: dict | None = None): + """Initialize LDCast model. + + Args: + config: Configuration dictionary with model parameters + """ + #super().__init__() + self.pl_module = LDCastLightningModule(nn.Module(), nn.Module()) + self.config = config or {} + self.autoencoder = None + self.diffusion_model = None + self.scaler = None + + def save(self, path: str, **kwargs: Any) -> None: + """Save the trained LDCast model to disk. + + Args: + path: File path where the model should be saved + **kwargs: Additional arguments for model saving + """ + model_path = Path(path) + model_path.mkdir(parents=True, exist_ok=True) + + # Save autoencoder weights + if self.autoencoder is not None: + torch.save( + self.autoencoder.state_dict(), + model_path / "autoencoder.pt" + ) + + # Save diffusion model weights + if self.diffusion_model is not None: + torch.save( + self.diffusion_model.state_dict(), + model_path / "diffusion_model.pt" + ) + + # Save scaler parameters if present + if self.scaler is not None: + import pickle + with open(model_path / "scaler.pkl", "wb") as f: + pickle.dump(self.scaler, f) + + # Save configuration + import json + with open(model_path / "config.json", "w") as f: + json.dump(self.config, f) + + def load(self, path: str, **kwargs: Any) -> None: + """Load a pre-trained LDCast model from disk. + + Args: + path: File path to the saved model + **kwargs: Additional arguments for model loading + """ + model_path = Path(path) + + # Load configuration + import json + with open(model_path / "config.json", "r") as f: + self.config = json.load(f) + + # Load autoencoder weights if available + autoenc_path = model_path / "autoencoder.pt" + if autoenc_path.exists(): + # Initialize autoencoder architecture from config + self.autoencoder = self._build_autoencoder() + self.autoencoder.load_state_dict(torch.load(autoenc_path)) + + # Load diffusion model weights if available + diffusion_path = model_path / "diffusion_model.pt" + if diffusion_path.exists(): + # Initialize diffusion model architecture from config + self.diffusion_model = self._build_diffusion_model() + self.diffusion_model.load_state_dict(torch.load(diffusion_path)) + + # Load scaler parameters if available + scaler_path = model_path / "scaler.pkl" + if scaler_path.exists(): + import pickle + with open(scaler_path, "rb") as f: + self.scaler = pickle.load(f) + + def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: + """Train the LDCast model on precipitation data. + + Args: + da_rr: xarray DataArray containing precipitation radar data + with time, latitude, and longitude dimensions + **kwargs: Additional arguments: + - epochs: Number of training epochs + - batch_size: Batch size for training + - val_split: Validation split ratio + - num_timesteps: Number of input timesteps + """ + # Extract configuration from kwargs + epochs = kwargs.get('epochs', self.config.get('max_epochs', 100)) + batch_size = kwargs.get('batch_size', self.config.get('batch_size', 32)) + num_timesteps = kwargs.get('num_timesteps', self.config.get('timesteps', 12)) + + # Step 1: Data preprocessing and scaling + self._preprocess_data(da_rr, **kwargs) + + # Step 2: Train autoencoder + self._train_autoencoder( + da_rr, + epochs=epochs, + batch_size=batch_size, + **kwargs + ) + + # Step 3: Train diffusion model + self._train_diffusion_model( + da_rr, + num_timesteps=num_timesteps, + epochs=epochs, + batch_size=batch_size, + **kwargs + ) + + # Store timestep length + if 'time' in da_rr.dims: + time_coords = da_rr.coords['time'].values + if len(time_coords) > 1: + self.timestep_length = np.timedelta64( + int(np.diff(time_coords[:2])[0]), 'ns' + ) + + def predict( + self, + da_rr: xr.DataArray, + duration: str, + **kwargs: Any + ) -> xr.DataArray: + """Generate precipitation forecasts. + + Args: + da_rr: xarray DataArray containing initial precipitation conditions + duration: ISO 8601 duration string (e.g., "PT1H" for 1 hour) + **kwargs: Additional arguments: + - num_samples: Number of ensemble samples to generate + - num_diffusion_steps: Number of diffusion steps + + Returns: + xarray DataArray containing precipitation predictions with + original spatial dimensions plus an "elapsed_time" dimension + """ + from isodate import parse_duration + + # Parse duration string + duration_obj = parse_duration(duration) + num_forecasts = int(duration_obj.total_seconds() / + self.timestep_length.astype(int)) + + # Extract configuration from kwargs + num_samples = kwargs.get('num_samples', 1) + num_diffusion_steps = kwargs.get('num_diffusion_steps', 50) + + # Preprocess input using stored scaler + processed_input = self._preprocess_input(da_rr) + + # Encode to latent space using autoencoder + with torch.no_grad(): + latent = self.autoencoder.encode(processed_input) + + # Generate predictions using diffusion model + predictions = [] + for _ in range(num_samples): + pred = self._diffusion_predict( + latent, + num_forecasts, + num_diffusion_steps + ) + predictions.append(pred) + + # Stack and average predictions + predictions = torch.stack(predictions, dim=0).mean(dim=0) + + # Decode from latent space + with torch.no_grad(): + forecasted = self.autoencoder.decode(predictions) + + # Postprocess and convert back to original scale + output = self._postprocess_output(forecasted, da_rr) + + # Create output DataArray with elapsed_time dimension + time_coords = da_rr.coords['time'].values[-1] + elapsed_times = [ + np.timedelta64(i, 'm') * 5 # Assuming 5-minute steps + for i in range(1, num_forecasts + 1) + ] + + output_da = xr.DataArray( + output, + dims=['elapsed_time', 'latitude', 'longitude'], + coords={ + 'elapsed_time': ('elapsed_time', elapsed_times), + 'latitude': ('latitude', da_rr.coords['latitude'].values), + 'longitude': ('longitude', da_rr.coords['longitude'].values), + }, + name='precipitation' + ) + + return output_da + + def _preprocess_data(self, da_rr: xr.DataArray, **kwargs: Any) -> None: + """Preprocess precipitation data and fit scaler.""" + # Implement data scaling/normalization + # Store scaling parameters in self.scaler + pass + + def _train_autoencoder( + self, + da_rr: xr.DataArray, + epochs: int, + batch_size: int, + **kwargs: Any + ) -> None: + """Train the autoencoder component.""" + # Import and use ldcast autoencoder training + from ldcast.models.autoenc import setup_and_train + # Implementation details + pass + + def _train_diffusion_model( + self, + da_rr: xr.DataArray, + num_timesteps: int, + epochs: int, + batch_size: int, + **kwargs: Any + ) -> None: + """Train the diffusion model component.""" + # Import and use ldcast genforecast training + from ldcast.models.genforecast import setup_and_train + # Implementation details + pass + + def _preprocess_input(self, da_rr: xr.DataArray) -> torch.Tensor: + """Convert input xarray to scaled tensor.""" + # Apply stored scaler + pass + + def _postprocess_output( + self, + output: torch.Tensor, + reference_da: xr.DataArray + ) -> np.ndarray: + """Convert predictions back to original scale and format.""" + # Reverse scaling using stored scaler + pass + + def _diffusion_predict( + self, + latent: torch.Tensor, + num_forecasts: int, + num_steps: int + ) -> torch.Tensor: + """Generate predictions using the diffusion model.""" + # Use ldcast diffusion inference + pass + + def _build_autoencoder(self) -> nn.Module: + """Build autoencoder architecture from config.""" + pass + + def _build_diffusion_model(self) -> nn.Module: + """Build diffusion model architecture from config.""" + pass \ No newline at end of file diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000..65cada8 --- /dev/null +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,28 @@ +import torch +from torch import nn + + +def normalization(channels, norm_type="group", num_groups=32): + if norm_type == "batch": + return nn.BatchNorm3d(channels) + elif norm_type == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=channels) + elif (not norm_type) or (norm_type.tolower() == "none"): + return nn.Identity() + else: + raise NotImplementedError(norm) + + +def activation(act_type="swish"): + if act_type == "swish": + return nn.SiLU() + elif act_type == "gelu": + return nn.GELU() + elif act_type == "relu": + return nn.ReLU() + elif act_type == "tanh": + return nn.Tanh() + elif not act_type: + return nn.Identity() + else: + raise NotImplementedError(act_type) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py new file mode 100644 index 0000000..51cc18b --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py @@ -0,0 +1,96 @@ +import pytorch_lightning as pl +import torch +from torch import nn + +from ..distributions import ( + ensemble_nll_normal, + kl_from_standard_normal, + sample_from_standard_normal, +) + + +class AutoencoderKL(pl.LightningModule): + def __init__( + self, + encoder, + decoder, + kl_weight=0.01, + encoded_channels=64, + hidden_width=32, + **kwargs, + ): + super().__init__(**kwargs) + self.encoder = encoder + self.decoder = decoder + self.hidden_width = hidden_width + self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) + self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) + self.log_var = nn.Parameter(torch.zeros(size=())) + self.kl_weight = kl_weight + + def encode(self, x): + h = self.encoder(x) + (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) + return (mean, log_var) + + def decode(self, z): + z = self.to_decoder(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + (mean, log_var) = self.encode(input) + if sample_posterior: + z = sample_from_standard_normal(mean, log_var) + else: + z = mean + dec = self.decode(z) + return (dec, mean, log_var) + + def _loss(self, batch): + (x, y) = batch + while isinstance(x, list) or isinstance(x, tuple): + x = x[0][0] + (y_pred, mean, log_var) = self.forward(x) + + rec_loss = (y - y_pred).abs().mean() + kl_loss = kl_from_standard_normal(mean, log_var) + + total_loss = rec_loss + self.kl_weight * kl_loss + + return (total_loss, rec_loss, kl_loss) + + def training_step(self, batch, batch_idx): + loss = self._loss(batch)[0] + self.log("train_loss", loss, on_step=True) + return loss + + @torch.no_grad() + def val_test_step(self, batch, batch_idx, split="val"): + (total_loss, rec_loss, kl_loss) = self._loss(batch) + log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} + self.log(f"{split}_loss", total_loss, **log_params, sync_dist=True) + self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params, sync_dist=True) + self.log(f"{split}_kl_loss", kl_loss, **log_params, sync_dist=True) + + def validation_step(self, batch, batch_idx): + self.val_test_step(batch, batch_idx, split="val") + + def test_step(self, batch, batch_idx): + self.val_test_step(batch, batch_idx, split="test") + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3 + ) + reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=3, factor=0.25, verbose=True + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": reduce_lr, + "monitor": "val_rec_loss", + "frequency": 1, + }, + } \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py new file mode 100644 index 0000000..aab9f7a --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py @@ -0,0 +1,57 @@ +import numpy as np +import torch.nn as nn + +from ..blocks.resnet import ResBlock3D +from ..utils import activation, normalization + + +class SimpleConvEncoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in range(levels): + in_channels = int(channels[i]) + out_channels = int(channels[i+1]) + res_kernel_size = (3,3,3) if i == 0 else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + downsample = nn.Conv3d(out_channels, out_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(downsample) + in_channels = out_channels + + super().__init__(*sequence) + + +class SimpleConvDecoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in reversed(list(range(levels))): + in_channels = int(channels[i+1]) + out_channels = int(channels[i]) + upsample = nn.ConvTranspose3d(in_channels, in_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(upsample) + res_kernel_size = (3,3,3) if (i == 0) else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + in_channels = out_channels + + super().__init__(*sequence) diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py new file mode 100644 index 0000000..51cc18b --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -0,0 +1,96 @@ +import pytorch_lightning as pl +import torch +from torch import nn + +from ..distributions import ( + ensemble_nll_normal, + kl_from_standard_normal, + sample_from_standard_normal, +) + + +class AutoencoderKL(pl.LightningModule): + def __init__( + self, + encoder, + decoder, + kl_weight=0.01, + encoded_channels=64, + hidden_width=32, + **kwargs, + ): + super().__init__(**kwargs) + self.encoder = encoder + self.decoder = decoder + self.hidden_width = hidden_width + self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) + self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) + self.log_var = nn.Parameter(torch.zeros(size=())) + self.kl_weight = kl_weight + + def encode(self, x): + h = self.encoder(x) + (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) + return (mean, log_var) + + def decode(self, z): + z = self.to_decoder(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + (mean, log_var) = self.encode(input) + if sample_posterior: + z = sample_from_standard_normal(mean, log_var) + else: + z = mean + dec = self.decode(z) + return (dec, mean, log_var) + + def _loss(self, batch): + (x, y) = batch + while isinstance(x, list) or isinstance(x, tuple): + x = x[0][0] + (y_pred, mean, log_var) = self.forward(x) + + rec_loss = (y - y_pred).abs().mean() + kl_loss = kl_from_standard_normal(mean, log_var) + + total_loss = rec_loss + self.kl_weight * kl_loss + + return (total_loss, rec_loss, kl_loss) + + def training_step(self, batch, batch_idx): + loss = self._loss(batch)[0] + self.log("train_loss", loss, on_step=True) + return loss + + @torch.no_grad() + def val_test_step(self, batch, batch_idx, split="val"): + (total_loss, rec_loss, kl_loss) = self._loss(batch) + log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} + self.log(f"{split}_loss", total_loss, **log_params, sync_dist=True) + self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params, sync_dist=True) + self.log(f"{split}_kl_loss", kl_loss, **log_params, sync_dist=True) + + def validation_step(self, batch, batch_idx): + self.val_test_step(batch, batch_idx, split="val") + + def test_step(self, batch, batch_idx): + self.val_test_step(batch, batch_idx, split="test") + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3 + ) + reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=3, factor=0.25, verbose=True + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": reduce_lr, + "monitor": "val_rec_loss", + "frequency": 1, + }, + } \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/encoder.py b/src/mlcast/models/ldcast/autoenc/encoder.py new file mode 100644 index 0000000..aab9f7a --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/encoder.py @@ -0,0 +1,57 @@ +import numpy as np +import torch.nn as nn + +from ..blocks.resnet import ResBlock3D +from ..utils import activation, normalization + + +class SimpleConvEncoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in range(levels): + in_channels = int(channels[i]) + out_channels = int(channels[i+1]) + res_kernel_size = (3,3,3) if i == 0 else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + downsample = nn.Conv3d(out_channels, out_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(downsample) + in_channels = out_channels + + super().__init__(*sequence) + + +class SimpleConvDecoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in reversed(list(range(levels))): + in_channels = int(channels[i+1]) + out_channels = int(channels[i]) + upsample = nn.ConvTranspose3d(in_channels, in_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(upsample) + res_kernel_size = (3,3,3) if (i == 0) else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + in_channels = out_channels + + super().__init__(*sequence) diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py new file mode 100644 index 0000000..3e7f801 --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py @@ -0,0 +1,348 @@ +#reference: https://github.com/NVlabs/AFNO-transformer +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .attention import TemporalAttention + +class Mlp(nn.Module): + def __init__( + self, + in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, H, W, C = x.shape + + x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") + x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + return x + + + +class AFNO3D(nn.Module): + def __init__( + self, hidden_size, num_blocks=8, sparsity_threshold=0.01, + hard_thresholding_fraction=1, hidden_size_factor=1 + ): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, D, H, W, C = x.shape + + x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho") + x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, D, H, W // 2 + 1, C) + x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class AFNOBlock3d(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + mlp_out_features=None, + ): + super().__init__() + self.norm_layer = norm_layer + self.norm1 = norm_layer(dim) + self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, out_features=mlp_out_features, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.double_skip = double_skip + self.channels_first = (data_format == "channels_first") + + def forward(self, x): + if self.channels_first: + # AFNO natively uses a channels-last data format + x = x.permute(0,2,3,4,1) + + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x + + +class PatchEmbed3d(nn.Module): + def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.proj(x) + x = x.permute(0,2,3,4,1) # convert to BHWC + return x + + +class PatchExpand3d(nn.Module): + def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size)) + + def forward(self, x): + x = self.proj(x) + x = rearrange( + x, + "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)", + p0=self.patch_size[0], + p1=self.patch_size[1], + p2=self.patch_size[2], + d=x.shape[1], + h=x.shape[2], + w=x.shape[3], + ) + return x + + +class AFNOCrossAttentionBlock3d(nn.Module): + """ AFNO 3D Block with channel mixing from two sources, used (only?) in the Unet denoiser + """ + def __init__( + self, + dim, + context_dim, + mlp_ratio=2., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.Identity, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + timesteps=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim+context_dim) + mlp_hidden_dim = int((dim+context_dim) * mlp_ratio) + self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim) + self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.mlp = Mlp( + in_features=dim+context_dim, + out_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.channels_first = (data_format == "channels_first") + + def forward(self, x, y): + if self.channels_first: + # AFNO natively uses a channels-last order + x = x.permute(0,2,3,4,1) + y = y.permute(0,2,3,4,1) + + xy = torch.concat((self.norm1(x),y), axis=-1) + xy = self.pre_proj(xy) + xy + xy = self.filter(self.norm2(xy)) + xy # AFNO filter + x = self.mlp(xy) + x # feed-forward + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py new file mode 100644 index 0000000..c3b791e --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py @@ -0,0 +1,104 @@ +import math + +import torch +from torch import nn +import torch.nn.functional as F + + +class TemporalAttention(nn.Module): + def __init__( + self, channels, context_channels=None, + head_dim=32, num_heads=8 + ): + super().__init__() + self.channels = channels + if context_channels is None: + context_channels = channels + self.context_channels = context_channels + self.head_dim = head_dim + self.num_heads = num_heads + self.inner_dim = head_dim * num_heads + self.attn_scale = self.head_dim ** -0.5 + if channels % num_heads: + raise ValueError("channels must be divisible by num_heads") + self.KV = nn.Linear(context_channels, self.inner_dim*2) + self.Q = nn.Linear(channels, self.inner_dim) + self.proj = nn.Linear(self.inner_dim, channels) + + def forward(self, x, y=None): + if y is None: + y = x + + (K,V) = self.KV(y).chunk(2, dim=-1) + (B, Dk, H, W, C) = K.shape + shape = (B, Dk, H, W, self.num_heads, self.head_dim) + K = K.reshape(shape) + V = V.reshape(shape) + + Q = self.Q(x) + (B, Dq, H, W, C) = Q.shape + shape = (B, Dq, H, W, self.num_heads, self.head_dim) + Q = Q.reshape(shape) + + K = K.permute((0,2,3,4,5,1)) # K^T + V = V.permute((0,2,3,4,1,5)) + Q = Q.permute((0,2,3,4,1,5)) + + attn = torch.matmul(Q, K) * self.attn_scale + attn = F.softmax(attn, dim=-1) + y = torch.matmul(attn, V) + y = y.permute((0,4,1,2,3,5)) + y = y.reshape((B,Dq,H,W,C)) + y = self.proj(y) + return y + + +class TemporalTransformer(nn.Module): + def __init__(self, + channels, + mlp_dim_mul=1, + **kwargs + ): + super().__init__() + self.attn1 = TemporalAttention(channels, **kwargs) + self.attn2 = TemporalAttention(channels, **kwargs) + self.norm1 = nn.LayerNorm(channels) + self.norm2 = nn.LayerNorm(channels) + self.norm3 = nn.LayerNorm(channels) + self.mlp = MLP(channels, dim_mul=mlp_dim_mul) + + def forward(self, x, y): + x = self.attn1(self.norm1(x)) + x # self attention + x = self.attn2(self.norm2(x), y) + x # cross attention + return self.mlp(self.norm3(x)) + x # feed-forward + + +class MLP(nn.Sequential): + def __init__(self, dim, dim_mul=4): + inner_dim = dim * dim_mul + sequence = [ + nn.Linear(dim, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, dim) + ] + super().__init__(*sequence) + + +def positional_encoding(position, dims, add_dims=()): + div_term = torch.exp( + torch.arange(0, dims, 2, device=position.device) * + (-math.log(10000.0) / dims) + ) + if position.ndim == 1: + arg = position[:,None] * div_term[None,:] + else: + arg = position[:,:,None] * div_term[None,None,:] + + pos_enc = torch.concat( + [torch.sin(arg), torch.cos(arg)], + dim=-1 + ) + if add_dims: + for dim in add_dims: + pos_enc = pos_enc.unsqueeze(dim) + return pos_enc \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py new file mode 100644 index 0000000..90dacbc --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py @@ -0,0 +1,89 @@ +from torch import nn +from torch.nn.utils.parametrizations import spectral_norm as sn + +from ..utils import activation, normalization + + +class ResBlock3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + resample=None, + resample_factor=(1, 1, 1), + kernel_size=(3, 3, 3), + act="swish", + norm="group", + norm_kwargs=None, + spectral_norm=False, + **kwargs, + ): + super().__init__(**kwargs) + if in_channels != out_channels: + self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1) + else: + self.proj = nn.Identity() + + padding = tuple(k // 2 for k in kernel_size) + if resample == "down": + self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + elif resample == "up": + self.resample = nn.Upsample(scale_factor=resample_factor, mode="trilinear") + self.conv1 = nn.ConvTranspose3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + output_padding = tuple( + 2 * p + s - k + for (p, s, k) in zip(padding, resample_factor, kernel_size) + ) + self.conv2 = nn.ConvTranspose3d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + output_padding=output_padding, + ) + else: + self.resample = nn.Identity() + self.conv1 = nn.Conv3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + + if isinstance(act, str): + act = (act, act) + self.act1 = activation(act_type=act[0]) + self.act2 = activation(act_type=act[1]) + + if norm_kwargs is None: + norm_kwargs = {} + self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) + self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) + if spectral_norm: + self.conv1 = sn(self.conv1) + self.conv2 = sn(self.conv2) + if not isinstance(self.proj, nn.Identity): + self.proj = sn(self.proj) + + def forward(self, x): + x_in = self.resample(self.proj(x)) + x = self.norm1(x) + x = self.act1(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.act2(x) + x = self.conv2(x) + return x + x_in \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/afno.py b/src/mlcast/models/ldcast/blocks/afno.py new file mode 100644 index 0000000..3e7f801 --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/afno.py @@ -0,0 +1,348 @@ +#reference: https://github.com/NVlabs/AFNO-transformer +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .attention import TemporalAttention + +class Mlp(nn.Module): + def __init__( + self, + in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, H, W, C = x.shape + + x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") + x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + return x + + + +class AFNO3D(nn.Module): + def __init__( + self, hidden_size, num_blocks=8, sparsity_threshold=0.01, + hard_thresholding_fraction=1, hidden_size_factor=1 + ): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, D, H, W, C = x.shape + + x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho") + x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, D, H, W // 2 + 1, C) + x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class AFNOBlock3d(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + mlp_out_features=None, + ): + super().__init__() + self.norm_layer = norm_layer + self.norm1 = norm_layer(dim) + self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, out_features=mlp_out_features, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.double_skip = double_skip + self.channels_first = (data_format == "channels_first") + + def forward(self, x): + if self.channels_first: + # AFNO natively uses a channels-last data format + x = x.permute(0,2,3,4,1) + + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x + + +class PatchEmbed3d(nn.Module): + def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.proj(x) + x = x.permute(0,2,3,4,1) # convert to BHWC + return x + + +class PatchExpand3d(nn.Module): + def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size)) + + def forward(self, x): + x = self.proj(x) + x = rearrange( + x, + "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)", + p0=self.patch_size[0], + p1=self.patch_size[1], + p2=self.patch_size[2], + d=x.shape[1], + h=x.shape[2], + w=x.shape[3], + ) + return x + + +class AFNOCrossAttentionBlock3d(nn.Module): + """ AFNO 3D Block with channel mixing from two sources, used (only?) in the Unet denoiser + """ + def __init__( + self, + dim, + context_dim, + mlp_ratio=2., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.Identity, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + timesteps=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim+context_dim) + mlp_hidden_dim = int((dim+context_dim) * mlp_ratio) + self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim) + self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.mlp = Mlp( + in_features=dim+context_dim, + out_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.channels_first = (data_format == "channels_first") + + def forward(self, x, y): + if self.channels_first: + # AFNO natively uses a channels-last order + x = x.permute(0,2,3,4,1) + y = y.permute(0,2,3,4,1) + + xy = torch.concat((self.norm1(x),y), axis=-1) + xy = self.pre_proj(xy) + xy + xy = self.filter(self.norm2(xy)) + xy # AFNO filter + x = self.mlp(xy) + x # feed-forward + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/attention.py b/src/mlcast/models/ldcast/blocks/attention.py new file mode 100644 index 0000000..c3b791e --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/attention.py @@ -0,0 +1,104 @@ +import math + +import torch +from torch import nn +import torch.nn.functional as F + + +class TemporalAttention(nn.Module): + def __init__( + self, channels, context_channels=None, + head_dim=32, num_heads=8 + ): + super().__init__() + self.channels = channels + if context_channels is None: + context_channels = channels + self.context_channels = context_channels + self.head_dim = head_dim + self.num_heads = num_heads + self.inner_dim = head_dim * num_heads + self.attn_scale = self.head_dim ** -0.5 + if channels % num_heads: + raise ValueError("channels must be divisible by num_heads") + self.KV = nn.Linear(context_channels, self.inner_dim*2) + self.Q = nn.Linear(channels, self.inner_dim) + self.proj = nn.Linear(self.inner_dim, channels) + + def forward(self, x, y=None): + if y is None: + y = x + + (K,V) = self.KV(y).chunk(2, dim=-1) + (B, Dk, H, W, C) = K.shape + shape = (B, Dk, H, W, self.num_heads, self.head_dim) + K = K.reshape(shape) + V = V.reshape(shape) + + Q = self.Q(x) + (B, Dq, H, W, C) = Q.shape + shape = (B, Dq, H, W, self.num_heads, self.head_dim) + Q = Q.reshape(shape) + + K = K.permute((0,2,3,4,5,1)) # K^T + V = V.permute((0,2,3,4,1,5)) + Q = Q.permute((0,2,3,4,1,5)) + + attn = torch.matmul(Q, K) * self.attn_scale + attn = F.softmax(attn, dim=-1) + y = torch.matmul(attn, V) + y = y.permute((0,4,1,2,3,5)) + y = y.reshape((B,Dq,H,W,C)) + y = self.proj(y) + return y + + +class TemporalTransformer(nn.Module): + def __init__(self, + channels, + mlp_dim_mul=1, + **kwargs + ): + super().__init__() + self.attn1 = TemporalAttention(channels, **kwargs) + self.attn2 = TemporalAttention(channels, **kwargs) + self.norm1 = nn.LayerNorm(channels) + self.norm2 = nn.LayerNorm(channels) + self.norm3 = nn.LayerNorm(channels) + self.mlp = MLP(channels, dim_mul=mlp_dim_mul) + + def forward(self, x, y): + x = self.attn1(self.norm1(x)) + x # self attention + x = self.attn2(self.norm2(x), y) + x # cross attention + return self.mlp(self.norm3(x)) + x # feed-forward + + +class MLP(nn.Sequential): + def __init__(self, dim, dim_mul=4): + inner_dim = dim * dim_mul + sequence = [ + nn.Linear(dim, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, dim) + ] + super().__init__(*sequence) + + +def positional_encoding(position, dims, add_dims=()): + div_term = torch.exp( + torch.arange(0, dims, 2, device=position.device) * + (-math.log(10000.0) / dims) + ) + if position.ndim == 1: + arg = position[:,None] * div_term[None,:] + else: + arg = position[:,:,None] * div_term[None,None,:] + + pos_enc = torch.concat( + [torch.sin(arg), torch.cos(arg)], + dim=-1 + ) + if add_dims: + for dim in add_dims: + pos_enc = pos_enc.unsqueeze(dim) + return pos_enc \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/resnet.py b/src/mlcast/models/ldcast/blocks/resnet.py new file mode 100644 index 0000000..90dacbc --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/resnet.py @@ -0,0 +1,89 @@ +from torch import nn +from torch.nn.utils.parametrizations import spectral_norm as sn + +from ..utils import activation, normalization + + +class ResBlock3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + resample=None, + resample_factor=(1, 1, 1), + kernel_size=(3, 3, 3), + act="swish", + norm="group", + norm_kwargs=None, + spectral_norm=False, + **kwargs, + ): + super().__init__(**kwargs) + if in_channels != out_channels: + self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1) + else: + self.proj = nn.Identity() + + padding = tuple(k // 2 for k in kernel_size) + if resample == "down": + self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + elif resample == "up": + self.resample = nn.Upsample(scale_factor=resample_factor, mode="trilinear") + self.conv1 = nn.ConvTranspose3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + output_padding = tuple( + 2 * p + s - k + for (p, s, k) in zip(padding, resample_factor, kernel_size) + ) + self.conv2 = nn.ConvTranspose3d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + output_padding=output_padding, + ) + else: + self.resample = nn.Identity() + self.conv1 = nn.Conv3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + + if isinstance(act, str): + act = (act, act) + self.act1 = activation(act_type=act[0]) + self.act2 = activation(act_type=act[1]) + + if norm_kwargs is None: + norm_kwargs = {} + self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) + self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) + if spectral_norm: + self.conv1 = sn(self.conv1) + self.conv2 = sn(self.conv2) + if not isinstance(self.proj, nn.Identity): + self.proj = sn(self.proj) + + def forward(self, x): + x_in = self.resample(self.proj(x)) + x = self.norm1(x) + x = self.act1(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.act2(x) + x = self.conv2(x) + return x + x_in \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py new file mode 100644 index 0000000..ac76a7d --- /dev/null +++ b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .nowcast import AFNONowcastNetBase +from ..blocks.resnet import ResBlock3D + + +class AFNONowcastNetCascade(AFNONowcastNetBase): + def __init__(self, *args, cascade_depth=4, **kwargs): + super().__init__(*args, **kwargs) + self.cascade_depth = cascade_depth + self.resnet = nn.ModuleList() + ch = self.embed_dim_out + self.cascade_dims = [ch] + for i in range(cascade_depth-1): + ch_out = 2*ch + self.cascade_dims.append(ch_out) + self.resnet.append( + ResBlock3D(ch, ch_out, kernel_size=(1,3,3), norm=None) + ) + ch = ch_out + + def forward(self, x, timesteps): + x = super().forward(x, timesteps) + img_shape = tuple(x.shape[-2:]) + cascade = {img_shape: x} + for i in range(self.cascade_depth-1): + x = F.avg_pool3d(x, (1,2,2)) + x = self.resnet[i](x) + img_shape = tuple(x.shape[-2:]) + cascade[img_shape] = x + return cascade \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py new file mode 100644 index 0000000..54cc43b --- /dev/null +++ b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py @@ -0,0 +1,125 @@ +import collections + +import torch +from torch import nn +from torch.nn import functional as F +import pytorch_lightning as pl + +from ..blocks.afno import AFNOBlock3d +from ..blocks.attention import positional_encoding, TemporalTransformer + +class FusionBlock3d(nn.Module): + def __init__(self, dim, size_ratio, dim_out=None, afno_fusion=False): + super().__init__() + + if dim_out is None: + dim_out = dim + + if size_ratio == 1: + scale = nn.Identity() + else: + scale = [] + while size_ratio > 1: + scale.append(nn.ConvTranspose3d( + dim[i], dim_out if size_ratio==2 else dim[i], + kernel_size=(1,3,3), stride=(1,2,2), + padding=(0,1,1), output_padding=(0,1,1) + )) + size_ratio //= 2 + scale = nn.Sequential(*scale) + self.scale = scale + + self.afno_fusion = afno_fusion + + if self.afno_fusion: + self.fusion = nn.Identity() + + def resize_proj(self, x): + x = x.permute(0,4,1,2,3) + x = self.scale(x) + x = x.permute(0,2,3,4,1) + return x + + def forward(self, x): + x = self.resize_proj(x) + return x + + +class AFNONowcastNetBase(nn.Module): + def __init__( + self, + autoencoder_dim, + embed_dim=128, + embed_dim_out=None, + analysis_depth=4, + forecast_depth=4, + input_patches=1, + input_size_ratios=1, + output_patches=2, + train_autoenc=False, + afno_fusion=False + ): + super().__init__() + + if embed_dim_out is None: + embed_dim_out = embed_dim + self.embed_dim = embed_dim + self.embed_dim_out = embed_dim_out + self.output_patches = output_patches + + self.proj = nn.Conv3d(autoencoder_dim, embed_dim, kernel_size=1) + + self.analysis = nn.Sequential( + *(AFNOBlock3d(embed_dim) for _ in range(analysis_depth)) + ) + + # temporal transformer + self.use_temporal_transformer = input_patches != output_patches + if self.use_temporal_transformer: + self.temporal_transformer = TemporalTransformer(embed_dim) + + # data fusion + self.fusion = FusionBlock3d(embed_dim, input_size_ratios, + afno_fusion=afno_fusion, dim_out=embed_dim_out) + + # forecast + self.forecast = nn.Sequential( + *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth)) + ) + + def add_pos_enc(self, x, t): + '''not sure the this does what it was supposed to do in the original LDCast code''' + if t.shape[1] != x.shape[1]: + # this can happen if x has been compressed + # by the autoencoder in the time dimension + ds_factor = t.shape[1] // x.shape[1] + t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:] + + pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3)) + return x + pos_enc + + def forward(self, z, timesteps): + '''z is the latent representation of the conditioning and timesteps is contains the timesteps of the input frames (it is [-3, -2, -1, 0])''' + z = self.proj(z) + z = z.permute(0,2,3,4,1) + z = self.analysis(z) + + if self.use_temporal_transformer: + # add positional encoding + z = self.add_pos_enc(z, timesteps) + + # transform to output shape and coordinates + expand_shape = z.shape[:1] + (-1,) + z.shape[2:] + pos_enc_output = positional_encoding( + torch.arange(1,self.output_patches+1, device=z.device), + self.embed_dim, add_dims=(0,2,3) + ) + pe_out = pos_enc_output.expand(*expand_shape) + z = self.temporal_transformer(pe_out, z) + + + # merge inputs + z = self.fusion(z) + # produce prediction + z = self.forecast(z) + return z.permute(0,4,1,2,3) # to channels-first order \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/context.py b/src/mlcast/models/ldcast/context/context.py new file mode 100644 index 0000000..ac76a7d --- /dev/null +++ b/src/mlcast/models/ldcast/context/context.py @@ -0,0 +1,33 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .nowcast import AFNONowcastNetBase +from ..blocks.resnet import ResBlock3D + + +class AFNONowcastNetCascade(AFNONowcastNetBase): + def __init__(self, *args, cascade_depth=4, **kwargs): + super().__init__(*args, **kwargs) + self.cascade_depth = cascade_depth + self.resnet = nn.ModuleList() + ch = self.embed_dim_out + self.cascade_dims = [ch] + for i in range(cascade_depth-1): + ch_out = 2*ch + self.cascade_dims.append(ch_out) + self.resnet.append( + ResBlock3D(ch, ch_out, kernel_size=(1,3,3), norm=None) + ) + ch = ch_out + + def forward(self, x, timesteps): + x = super().forward(x, timesteps) + img_shape = tuple(x.shape[-2:]) + cascade = {img_shape: x} + for i in range(self.cascade_depth-1): + x = F.avg_pool3d(x, (1,2,2)) + x = self.resnet[i](x) + img_shape = tuple(x.shape[-2:]) + cascade[img_shape] = x + return cascade \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/nowcast.py b/src/mlcast/models/ldcast/context/nowcast.py new file mode 100644 index 0000000..54cc43b --- /dev/null +++ b/src/mlcast/models/ldcast/context/nowcast.py @@ -0,0 +1,125 @@ +import collections + +import torch +from torch import nn +from torch.nn import functional as F +import pytorch_lightning as pl + +from ..blocks.afno import AFNOBlock3d +from ..blocks.attention import positional_encoding, TemporalTransformer + +class FusionBlock3d(nn.Module): + def __init__(self, dim, size_ratio, dim_out=None, afno_fusion=False): + super().__init__() + + if dim_out is None: + dim_out = dim + + if size_ratio == 1: + scale = nn.Identity() + else: + scale = [] + while size_ratio > 1: + scale.append(nn.ConvTranspose3d( + dim[i], dim_out if size_ratio==2 else dim[i], + kernel_size=(1,3,3), stride=(1,2,2), + padding=(0,1,1), output_padding=(0,1,1) + )) + size_ratio //= 2 + scale = nn.Sequential(*scale) + self.scale = scale + + self.afno_fusion = afno_fusion + + if self.afno_fusion: + self.fusion = nn.Identity() + + def resize_proj(self, x): + x = x.permute(0,4,1,2,3) + x = self.scale(x) + x = x.permute(0,2,3,4,1) + return x + + def forward(self, x): + x = self.resize_proj(x) + return x + + +class AFNONowcastNetBase(nn.Module): + def __init__( + self, + autoencoder_dim, + embed_dim=128, + embed_dim_out=None, + analysis_depth=4, + forecast_depth=4, + input_patches=1, + input_size_ratios=1, + output_patches=2, + train_autoenc=False, + afno_fusion=False + ): + super().__init__() + + if embed_dim_out is None: + embed_dim_out = embed_dim + self.embed_dim = embed_dim + self.embed_dim_out = embed_dim_out + self.output_patches = output_patches + + self.proj = nn.Conv3d(autoencoder_dim, embed_dim, kernel_size=1) + + self.analysis = nn.Sequential( + *(AFNOBlock3d(embed_dim) for _ in range(analysis_depth)) + ) + + # temporal transformer + self.use_temporal_transformer = input_patches != output_patches + if self.use_temporal_transformer: + self.temporal_transformer = TemporalTransformer(embed_dim) + + # data fusion + self.fusion = FusionBlock3d(embed_dim, input_size_ratios, + afno_fusion=afno_fusion, dim_out=embed_dim_out) + + # forecast + self.forecast = nn.Sequential( + *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth)) + ) + + def add_pos_enc(self, x, t): + '''not sure the this does what it was supposed to do in the original LDCast code''' + if t.shape[1] != x.shape[1]: + # this can happen if x has been compressed + # by the autoencoder in the time dimension + ds_factor = t.shape[1] // x.shape[1] + t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:] + + pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3)) + return x + pos_enc + + def forward(self, z, timesteps): + '''z is the latent representation of the conditioning and timesteps is contains the timesteps of the input frames (it is [-3, -2, -1, 0])''' + z = self.proj(z) + z = z.permute(0,2,3,4,1) + z = self.analysis(z) + + if self.use_temporal_transformer: + # add positional encoding + z = self.add_pos_enc(z, timesteps) + + # transform to output shape and coordinates + expand_shape = z.shape[:1] + (-1,) + z.shape[2:] + pos_enc_output = positional_encoding( + torch.arange(1,self.output_patches+1, device=z.device), + self.embed_dim, add_dims=(0,2,3) + ) + pe_out = pos_enc_output.expand(*expand_shape) + z = self.temporal_transformer(pe_out, z) + + + # merge inputs + z = self.fusion(z) + # produce prediction + z = self.forecast(z) + return z.permute(0,4,1,2,3) # to channels-first order \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py new file mode 100644 index 0000000..301e7dd --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py @@ -0,0 +1,220 @@ +""" +From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py +Pared down to simplify code. + +The original file acknowledges: +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from contextlib import contextmanager +from functools import partial + +import contextlib + +from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding +from .ema import LitEma +from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d +from .plms import PLMSSampler + +print('take care of ema scope') + +class DiffusionModel(pl.LightningModule): # replaces LatentDiffusion + def __init__(self, + denoiser, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + use_ema=True, + lr=1e-4, + lr_warmup=0, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + parameterization="eps", # all assuming fixed variance schedules + ): + super().__init__() + self.denoiser = denoiser + self.lr = lr + self.lr_warmup = lr_warmup + + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + + self.use_ema = use_ema + if self.use_ema: + self.denoiser_ema = LitEma(self.denoiser) + + self.register_schedule( + beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s + ) + + self.loss_type = loss_type + + self.sampler = PLMSSampler(self.denoiser, timesteps) + + def forward(self, conditioning, num_diffusion_iters = 50, verbose = True): + gen_shape = (32, 5, 256//4, 256//4) + with contextlib.redirect_stdout(None): + (s, intermediates) = self.sampler.sample( + num_diffusion_iters, + 1, # batch_size + gen_shape, + self.q_sample, + conditioning, + progbar=verbose + ) + return s + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + + betas = make_beta_schedule( + beta_schedule, timesteps, + linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.denoiser.register_buffer('betas', to_torch(betas)) + self.denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.denoiser_ema.store(self.denoiser.parameters()) + self.denoiser_ema.copy_to(self.denoiser) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.denoiser_ema.restore(self.denoiser.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + return ( + extract_into_tensor(self.denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None, context=None): + if noise is None: + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + denoised = self.denoiser(x_noisy, t, context=context) + + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + return self.get_loss(denoised, target, mean=False).mean() + ''' + def forward(self, x, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + ''' + ''' + def shared_step(self, batch): + (x,y) = batch + y = self.autoencoder.encode(y)[0] + context = self.context_encoder(x) if self.conditional else None + return self(y, context=context) + ''' + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log("train_loss", loss) + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + with self.ema_scope(): + loss_ema = self.shared_step(batch) + log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} + self.log("val_loss", loss, **log_params) + self.log("val_loss_ema", loss, **log_params) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.denoiser_ema(self.denoiser) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, + betas=(0.5, 0.9), weight_decay=1e-3) + reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=3, factor=0.25, verbose=True + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": reduce_lr, + "monitor": "val_loss_ema", + "frequency": 1, + }, + } + + def optimizer_step( + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + **kwargs + ): + if self.trainer.global_step < self.lr_warmup: + lr_scale = (self.trainer.global_step+1) / self.lr_warmup + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * self.lr + + super().optimizer_step( + epoch, batch_idx, optimizer, + optimizer_idx, optimizer_closure, + **kwargs + ) + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py new file mode 100644 index 0000000..cd2f8e3 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py new file mode 100644 index 0000000..bc87241 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py @@ -0,0 +1,367 @@ +""" +From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +""" + +"""SAMPLING ONLY.""" + +import numpy as np +import torch +from tqdm import tqdm + +from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler: + def __init__(self, model, timesteps, schedule="linear", **kwargs): + self.model = model + self.ddpm_num_timesteps = timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + # if type(attr) == torch.Tensor: + # if attr.device != torch.device("cuda"): + # attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(torch.sqrt(alphas_cumprod)) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(torch.sqrt(1.0 - alphas_cumprod)), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(torch.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(torch.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(torch.sqrt(1.0 / alphas_cumprod - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod, + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer( + "ddim_sqrt_one_minus_alphas", torch.sqrt(1.0 - ddim_alphas) + ) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + q_sample_func, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + """ + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + """ + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + size = (batch_size,) + shape + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + q_sample_func, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + progbar=progbar, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + q_sample_func, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = time_range + if progbar: + iterator = tqdm(iterator, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) + + if mask is not None: + assert x0 is not None + img_orig = q_sample_func( + x0, ts + ) # TODO: deterministic forward pass? + print('after q_sample 1', img_orig.shape) + img = img_orig * mask + (1.0 - mask) * img + print('after q_sample 2', img.shape) + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms( + self, + x, + condition, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model(x, t, condition) + else: + pass + '''is never used + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, condition]) + e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + ''' + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, condition, **corrector_kwargs + ) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + param_shape = (b,) + (1,) * (x.ndim - 1) + a_t = torch.full(param_shape, alphas[index], device=device) + a_prev = torch.full(param_shape, alphas_prev[index], device=device) + sigma_t = torch.full(param_shape, sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + param_shape, sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py new file mode 100644 index 0000000..1eb22d7 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py @@ -0,0 +1,490 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl + +from .utils import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ..blocks.afno import AFNOCrossAttentionBlock3d +SpatialTransformer = type(None) +#from ldm.modules.attention import SpatialTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AFNOCrossAttentionBlock3d): + img_shape = tuple(x.shape[-2:]) + x = layer(x, context[img_shape]) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + + """ + + def __init__( + self, + model_channels, + in_channels=1, + out_channels=1, + num_res_blocks=2, + attention_resolutions=(1,2,4), + context_ch=128, + dropout=0, + channel_mult=(1, 2, 4, 4), + conv_resample=True, + dims=3, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + legacy=True, + num_timesteps=1 + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + timesteps = th.arange(1, num_timesteps+1) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[-1], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + self.device = next(self.parameters()).device + def forward(self, x, timesteps=None, context=None): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :return: an [N x C x ...] Tensor of outputs. + """ + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + return self.out(h) + diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000..ab90f9e --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,247 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps].cpu() + numpied = [alphacums[0].cpu()] + alphacums[ddim_timesteps[:-1]].cpu().tolist() + alphas_prev = np.asarray(numpied) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.Identity() #GroupNorm32(32, channels) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py new file mode 100644 index 0000000..301e7dd --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -0,0 +1,220 @@ +""" +From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py +Pared down to simplify code. + +The original file acknowledges: +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from contextlib import contextmanager +from functools import partial + +import contextlib + +from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding +from .ema import LitEma +from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d +from .plms import PLMSSampler + +print('take care of ema scope') + +class DiffusionModel(pl.LightningModule): # replaces LatentDiffusion + def __init__(self, + denoiser, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + use_ema=True, + lr=1e-4, + lr_warmup=0, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + parameterization="eps", # all assuming fixed variance schedules + ): + super().__init__() + self.denoiser = denoiser + self.lr = lr + self.lr_warmup = lr_warmup + + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + + self.use_ema = use_ema + if self.use_ema: + self.denoiser_ema = LitEma(self.denoiser) + + self.register_schedule( + beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s + ) + + self.loss_type = loss_type + + self.sampler = PLMSSampler(self.denoiser, timesteps) + + def forward(self, conditioning, num_diffusion_iters = 50, verbose = True): + gen_shape = (32, 5, 256//4, 256//4) + with contextlib.redirect_stdout(None): + (s, intermediates) = self.sampler.sample( + num_diffusion_iters, + 1, # batch_size + gen_shape, + self.q_sample, + conditioning, + progbar=verbose + ) + return s + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + + betas = make_beta_schedule( + beta_schedule, timesteps, + linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.denoiser.register_buffer('betas', to_torch(betas)) + self.denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.denoiser_ema.store(self.denoiser.parameters()) + self.denoiser_ema.copy_to(self.denoiser) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.denoiser_ema.restore(self.denoiser.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + return ( + extract_into_tensor(self.denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None, context=None): + if noise is None: + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + denoised = self.denoiser(x_noisy, t, context=context) + + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + return self.get_loss(denoised, target, mean=False).mean() + ''' + def forward(self, x, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + ''' + ''' + def shared_step(self, batch): + (x,y) = batch + y = self.autoencoder.encode(y)[0] + context = self.context_encoder(x) if self.conditional else None + return self(y, context=context) + ''' + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + self.log("train_loss", loss) + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + with self.ema_scope(): + loss_ema = self.shared_step(batch) + log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} + self.log("val_loss", loss, **log_params) + self.log("val_loss_ema", loss, **log_params) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.denoiser_ema(self.denoiser) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, + betas=(0.5, 0.9), weight_decay=1e-3) + reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, patience=3, factor=0.25, verbose=True + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": reduce_lr, + "monitor": "val_loss_ema", + "frequency": 1, + }, + } + + def optimizer_step( + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + **kwargs + ): + if self.trainer.global_step < self.lr_warmup: + lr_scale = (self.trainer.global_step+1) / self.lr_warmup + for pg in optimizer.param_groups: + pg['lr'] = lr_scale * self.lr + + super().optimizer_step( + epoch, batch_idx, optimizer, + optimizer_idx, optimizer_closure, + **kwargs + ) + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py new file mode 100644 index 0000000..cd2f8e3 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/plms.py b/src/mlcast/models/ldcast/diffusion/plms.py new file mode 100644 index 0000000..bc87241 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/plms.py @@ -0,0 +1,367 @@ +""" +From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +""" + +"""SAMPLING ONLY.""" + +import numpy as np +import torch +from tqdm import tqdm + +from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler: + def __init__(self, model, timesteps, schedule="linear", **kwargs): + self.model = model + self.ddpm_num_timesteps = timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + # if type(attr) == torch.Tensor: + # if attr.device != torch.device("cuda"): + # attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(torch.sqrt(alphas_cumprod)) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(torch.sqrt(1.0 - alphas_cumprod)), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(torch.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(torch.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(torch.sqrt(1.0 / alphas_cumprod - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod, + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer( + "ddim_sqrt_one_minus_alphas", torch.sqrt(1.0 - ddim_alphas) + ) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + q_sample_func, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + """ + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + """ + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + size = (batch_size,) + shape + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + q_sample_func, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + progbar=progbar, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + q_sample_func, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = time_range + if progbar: + iterator = tqdm(iterator, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) + + if mask is not None: + assert x0 is not None + img_orig = q_sample_func( + x0, ts + ) # TODO: deterministic forward pass? + print('after q_sample 1', img_orig.shape) + img = img_orig * mask + (1.0 - mask) * img + print('after q_sample 2', img.shape) + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms( + self, + x, + condition, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model(x, t, condition) + else: + pass + '''is never used + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, condition]) + e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + ''' + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, condition, **corrector_kwargs + ) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + param_shape = (b,) + (1,) * (x.ndim - 1) + a_t = torch.full(param_shape, alphas[index], device=device) + a_prev = torch.full(param_shape, alphas_prev[index], device=device) + sigma_t = torch.full(param_shape, sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + param_shape, sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/unet.py b/src/mlcast/models/ldcast/diffusion/unet.py new file mode 100644 index 0000000..1eb22d7 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/unet.py @@ -0,0 +1,490 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl + +from .utils import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ..blocks.afno import AFNOCrossAttentionBlock3d +SpatialTransformer = type(None) +#from ldm.modules.attention import SpatialTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AFNOCrossAttentionBlock3d): + img_shape = tuple(x.shape[-2:]) + x = layer(x, context[img_shape]) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + + """ + + def __init__( + self, + model_channels, + in_channels=1, + out_channels=1, + num_res_blocks=2, + attention_resolutions=(1,2,4), + context_ch=128, + dropout=0, + channel_mult=(1, 2, 4, 4), + conv_resample=True, + dims=3, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + legacy=True, + num_timesteps=1 + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + timesteps = th.arange(1, num_timesteps+1) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[-1], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + self.device = next(self.parameters()).device + def forward(self, x, timesteps=None, context=None): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :return: an [N x C x ...] Tensor of outputs. + """ + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + return self.out(h) + diff --git a/src/mlcast/models/ldcast/diffusion/utils.py b/src/mlcast/models/ldcast/diffusion/utils.py new file mode 100644 index 0000000..ab90f9e --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/utils.py @@ -0,0 +1,247 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps].cpu() + numpied = [alphacums[0].cpu()] + alphacums[ddim_timesteps[:-1]].cpu().tolist() + alphas_prev = np.asarray(numpied) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.Identity() #GroupNorm32(32, channels) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") \ No newline at end of file diff --git a/src/mlcast/models/ldcast/distributions.py b/src/mlcast/models/ldcast/distributions.py new file mode 100644 index 0000000..b7f68c2 --- /dev/null +++ b/src/mlcast/models/ldcast/distributions.py @@ -0,0 +1,29 @@ +import numpy as np +import torch + + +def kl_from_standard_normal(mean, log_var): + kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var) + return kl.mean() + + +def sample_from_standard_normal(mean, log_var, num=None): + std = (0.5 * log_var).exp() + shape = mean.shape + if num is not None: + # expand channel 1 to create several samples + shape = shape[:1] + (num,) + shape[1:] + mean = mean[:, None, ...] + std = std[:, None, ...] + return mean + std * torch.randn(shape, device=mean.device) + + +def ensemble_nll_normal(ensemble, sample, epsilon=1e-5): + mean = ensemble.mean(dim=1) + var = ensemble.var(dim=1, unbiased=True) + epsilon + logvar = var.log() + + diff = sample[:, None, ...] - mean + logtwopi = np.log(2 * np.pi) + nll = (logtwopi + logvar + diff.square() / var).mean() + return nll \ No newline at end of file diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py new file mode 100644 index 0000000..f64f772 --- /dev/null +++ b/src/mlcast/models/ldcast/ldcast.py @@ -0,0 +1,317 @@ +import abc +from pathlib import Path +from typing import Any + +import numpy as np +import pytorch_lightning as L +import torch +import xarray as xr +from torch import nn + +from ..base import NowcastingModelBase, NowcastingLightningModule + + +class LDCastLightningModule(NowcastingLightningModule): + """PyTorch Lightning module for LDCast diffusion model.""" + + def __init__( + self, + net: nn.Module, + loss: nn.Module, + optimizer_class: type | None = None, + optimizer_kwargs: dict | None = None, + **kwargs: Any, + ): + super().__init__( + net=net, + loss=loss, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + **kwargs, + ) + + +class LDCast(NowcastingModelBase): + """LDCast precipitation nowcasting model. + + This model implements a latent diffusion approach for precipitation forecasting, + combining an autoencoder for dimensionality reduction with a diffusion model + for temporal prediction. + + Attributes: + timestep_length: Time resolution of predictions (e.g., 5 minutes) + PLModuleClass: The Lightning module class used for training + """ + + timestep_length: np.timedelta64 | None = None + #PLModuleClass = LDCastLightningModule + + def __init__(self, config: dict | None = None): + """Initialize LDCast model. + + Args: + config: Configuration dictionary with model parameters + """ + #super().__init__() + self.pl_module = LDCastLightningModule(nn.Module(), nn.Module()) + self.config = config or {} + self.autoencoder = None + self.diffusion_model = None + self.scaler = None + + def save(self, path: str, **kwargs: Any) -> None: + """Save the trained LDCast model to disk. + + Args: + path: File path where the model should be saved + **kwargs: Additional arguments for model saving + """ + model_path = Path(path) + model_path.mkdir(parents=True, exist_ok=True) + + # Save autoencoder weights + if self.autoencoder is not None: + torch.save( + self.autoencoder.state_dict(), + model_path / "autoencoder.pt" + ) + + # Save diffusion model weights + if self.diffusion_model is not None: + torch.save( + self.diffusion_model.state_dict(), + model_path / "diffusion_model.pt" + ) + + # Save scaler parameters if present + if self.scaler is not None: + import pickle + with open(model_path / "scaler.pkl", "wb") as f: + pickle.dump(self.scaler, f) + + # Save configuration + import json + with open(model_path / "config.json", "w") as f: + json.dump(self.config, f) + + def load(self, path: str, **kwargs: Any) -> None: + """Load a pre-trained LDCast model from disk. + + Args: + path: File path to the saved model + **kwargs: Additional arguments for model loading + """ + model_path = Path(path) + + # Load configuration + import json + with open(model_path / "config.json", "r") as f: + self.config = json.load(f) + + # Load autoencoder weights if available + autoenc_path = model_path / "autoencoder.pt" + if autoenc_path.exists(): + # Initialize autoencoder architecture from config + self.autoencoder = self._build_autoencoder() + self.autoencoder.load_state_dict(torch.load(autoenc_path)) + + # Load diffusion model weights if available + diffusion_path = model_path / "diffusion_model.pt" + if diffusion_path.exists(): + # Initialize diffusion model architecture from config + self.diffusion_model = self._build_diffusion_model() + self.diffusion_model.load_state_dict(torch.load(diffusion_path)) + + # Load scaler parameters if available + scaler_path = model_path / "scaler.pkl" + if scaler_path.exists(): + import pickle + with open(scaler_path, "rb") as f: + self.scaler = pickle.load(f) + + def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: + """Train the LDCast model on precipitation data. + + Args: + da_rr: xarray DataArray containing precipitation radar data + with time, latitude, and longitude dimensions + **kwargs: Additional arguments: + - epochs: Number of training epochs + - batch_size: Batch size for training + - val_split: Validation split ratio + - num_timesteps: Number of input timesteps + """ + # Extract configuration from kwargs + epochs = kwargs.get('epochs', self.config.get('max_epochs', 100)) + batch_size = kwargs.get('batch_size', self.config.get('batch_size', 32)) + num_timesteps = kwargs.get('num_timesteps', self.config.get('timesteps', 12)) + + # Step 1: Data preprocessing and scaling + self._preprocess_data(da_rr, **kwargs) + + # Step 2: Train autoencoder + self._train_autoencoder( + da_rr, + epochs=epochs, + batch_size=batch_size, + **kwargs + ) + + # Step 3: Train diffusion model + self._train_diffusion_model( + da_rr, + num_timesteps=num_timesteps, + epochs=epochs, + batch_size=batch_size, + **kwargs + ) + + # Store timestep length + if 'time' in da_rr.dims: + time_coords = da_rr.coords['time'].values + if len(time_coords) > 1: + self.timestep_length = np.timedelta64( + int(np.diff(time_coords[:2])[0]), 'ns' + ) + + def predict( + self, + da_rr: xr.DataArray, + duration: str, + **kwargs: Any + ) -> xr.DataArray: + """Generate precipitation forecasts. + + Args: + da_rr: xarray DataArray containing initial precipitation conditions + duration: ISO 8601 duration string (e.g., "PT1H" for 1 hour) + **kwargs: Additional arguments: + - num_samples: Number of ensemble samples to generate + - num_diffusion_steps: Number of diffusion steps + + Returns: + xarray DataArray containing precipitation predictions with + original spatial dimensions plus an "elapsed_time" dimension + """ + from isodate import parse_duration + + # Parse duration string + duration_obj = parse_duration(duration) + num_forecasts = int(duration_obj.total_seconds() / + self.timestep_length.astype(int)) + + # Extract configuration from kwargs + num_samples = kwargs.get('num_samples', 1) + num_diffusion_steps = kwargs.get('num_diffusion_steps', 50) + + # Preprocess input using stored scaler + processed_input = self._preprocess_input(da_rr) + + # Encode to latent space using autoencoder + with torch.no_grad(): + latent = self.autoencoder.encode(processed_input) + + # Generate predictions using diffusion model + predictions = [] + for _ in range(num_samples): + pred = self._diffusion_predict( + latent, + num_forecasts, + num_diffusion_steps + ) + predictions.append(pred) + + # Stack and average predictions + predictions = torch.stack(predictions, dim=0).mean(dim=0) + + # Decode from latent space + with torch.no_grad(): + forecasted = self.autoencoder.decode(predictions) + + # Postprocess and convert back to original scale + output = self._postprocess_output(forecasted, da_rr) + + # Create output DataArray with elapsed_time dimension + time_coords = da_rr.coords['time'].values[-1] + elapsed_times = [ + np.timedelta64(i, 'm') * 5 # Assuming 5-minute steps + for i in range(1, num_forecasts + 1) + ] + + output_da = xr.DataArray( + output, + dims=['elapsed_time', 'latitude', 'longitude'], + coords={ + 'elapsed_time': ('elapsed_time', elapsed_times), + 'latitude': ('latitude', da_rr.coords['latitude'].values), + 'longitude': ('longitude', da_rr.coords['longitude'].values), + }, + name='precipitation' + ) + + return output_da + + def _preprocess_data(self, da_rr: xr.DataArray, **kwargs: Any) -> None: + """Preprocess precipitation data and fit scaler.""" + # Implement data scaling/normalization + # Store scaling parameters in self.scaler + pass + + def _train_autoencoder( + self, + da_rr: xr.DataArray, + epochs: int, + batch_size: int, + **kwargs: Any + ) -> None: + """Train the autoencoder component.""" + # Import and use ldcast autoencoder training + from ldcast.models.autoenc import setup_and_train + # Implementation details + pass + + def _train_diffusion_model( + self, + da_rr: xr.DataArray, + num_timesteps: int, + epochs: int, + batch_size: int, + **kwargs: Any + ) -> None: + """Train the diffusion model component.""" + # Import and use ldcast genforecast training + from ldcast.models.genforecast import setup_and_train + # Implementation details + pass + + def _preprocess_input(self, da_rr: xr.DataArray) -> torch.Tensor: + """Convert input xarray to scaled tensor.""" + # Apply stored scaler + pass + + def _postprocess_output( + self, + output: torch.Tensor, + reference_da: xr.DataArray + ) -> np.ndarray: + """Convert predictions back to original scale and format.""" + # Reverse scaling using stored scaler + pass + + def _diffusion_predict( + self, + latent: torch.Tensor, + num_forecasts: int, + num_steps: int + ) -> torch.Tensor: + """Generate predictions using the diffusion model.""" + # Use ldcast diffusion inference + pass + + def _build_autoencoder(self) -> nn.Module: + """Build autoencoder architecture from config.""" + pass + + def _build_diffusion_model(self) -> nn.Module: + """Build diffusion model architecture from config.""" + pass \ No newline at end of file diff --git a/src/mlcast/models/ldcast/utils.py b/src/mlcast/models/ldcast/utils.py new file mode 100644 index 0000000..65cada8 --- /dev/null +++ b/src/mlcast/models/ldcast/utils.py @@ -0,0 +1,28 @@ +import torch +from torch import nn + + +def normalization(channels, norm_type="group", num_groups=32): + if norm_type == "batch": + return nn.BatchNorm3d(channels) + elif norm_type == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=channels) + elif (not norm_type) or (norm_type.tolower() == "none"): + return nn.Identity() + else: + raise NotImplementedError(norm) + + +def activation(act_type="swish"): + if act_type == "swish": + return nn.SiLU() + elif act_type == "gelu": + return nn.GELU() + elif act_type == "relu": + return nn.ReLU() + elif act_type == "tanh": + return nn.Tanh() + elif not act_type: + return nn.Identity() + else: + raise NotImplementedError(act_type) \ No newline at end of file From 7c9c9a7bcc133ae4deaecf8291ce979df4b1ab75 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Fri, 13 Feb 2026 13:46:39 +0100 Subject: [PATCH 02/13] changes with respect to previous commit: - some changes in /src/mlcast/models/base.py (attribute problem if the loss is a dict) - in /src/mlcast/models/ldcast/autoenc/autoenc.py: added the loss for the autoencoder; renamed the 'AutoencoderKL' class in 'AutoencoderKLNet' and set the encoder and decoder to default configurations (which were the ones used in the original code). Removed all the training logic from that class (it will be handled by the trainer) - in src/mlcast/models/ldcast/context/context.py: the timesteps passed to AFNONetCascade.forward were always [-3, -2, -1, 0], so I included them in this function - in /src/mlcast/models/ldcast/diffusion/diffusion.py: I tried to separate as much as possible what concerns the samplers from the rest. I replaced the LatentDiffusion class by the LatentNowcaster one. I could not manage to make this a subclass of NowcastingLightningModule but it would be nice to do so. I removed all the training logic which was contained in the LatentDiffusion class (it will be handled by the trainer) - /src/mlcast/models/ldcast/diffusion/plms.py: removed the score_corrector, corrector_kwargs and mask keywords, which were not used - /src/mlcast/models/ldcast/ldcast.py now contains the main LDCast class subclassing the NowcastingModelBase (only the predict method is implemented, partially) I did not manage to make LatentNowcaster a sublcass of NowcastingLightningModule because LatentNowcaster needs two nets (denoiser and conditioner) and because the training logic is not as straightforward as it is for the moment in NowcastingLightningModule. One should also take into account the fact that two different samplers are used for training and inference, so that the forward method can not just be self.net(x) It would be nice to have cleaner and consistent APIs for the samplers. For the moment, the PLMSSampler and the SimpleSampler are not totally consistent in their APIs, because the SimpleSampler (better/more common name for this one?) was only used during training, while the PLMSSampler was used during inference. The handling of the schedule of each sampler with respect to the schedule saved in the denoiser could also be clearer During training, an EMA scope was used for the weights of the denoiser, I removed this for the moment, but it should reincluded in some way. The variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. An AutoencoderKLNet instance can now be passed to the NowcastingLightningModule with the autoenc_loss to handle the training In /src/mlcast/models/ldcast/diffusion/diffusion.py, one has to choose which sampler to use for testing --- .../models/ldcast/diffusion/simple_sampler.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 src/mlcast/models/ldcast/diffusion/simple_sampler.py diff --git a/src/mlcast/models/ldcast/diffusion/simple_sampler.py b/src/mlcast/models/ldcast/diffusion/simple_sampler.py new file mode 100644 index 0000000..ad7dd45 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/simple_sampler.py @@ -0,0 +1,84 @@ +# from https://github.com/mfroelund/ldcast-dmi-public/blob/master/ldcast/models/diffusion/diffusion.py, but reworked + +""" +From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py +Pared down to simplify code. + +The original file acknowledges: +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +""" + +import torch +import torch.nn as nn +import pytorch_lightning as L +from functools import partial +import numpy as np + +from .utils import make_beta_schedule, extract_into_tensor + + +class SimpleSampler(L.LightningModule): + '''Sampler used for training (the PLMSSampler is used for inference). The sample method is not consistent with the sample method of PLMSSampler''' + def __init__(self, + timesteps=1000, + beta_schedule="linear", + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + parameterization="eps", # all assuming fixed variance schedules + ): + super().__init__() + + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.beta_schedule = beta_schedule + self.timesteps = timesteps + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + def register_schedule(self, denoiser): + + # check if the denoiser has already some saved buffers + buffer_names = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod'] + already_saved = [n for n in buffer_names if n in dict(denoiser.named_buffers()).keys()] + if len(already_saved) > 0: + raise AttributeError(f'The denoiser has already some saved values for {already_saved}') + + betas = make_beta_schedule( + self.beta_schedule, self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, + cosine_s=self.cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + assert alphas_cumprod.shape[0] == self.timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32, device = next(denoiser.parameters()).device) + + denoiser.register_buffer('betas', to_torch(betas)) + denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + + + def q_sample(self, denoiser, x_start, noise=None): + '''generate noise target for training''' + if noise is None: + noise = torch.randn_like(x_start) + t = torch.randint(0, self.timesteps, (x_start.shape[0],), device=x_start.device).long() + x_noisy = extract_into_tensor(denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + \ + extract_into_tensor(denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + return t, noise, x_noisy + + def sample(self, denoiser, conditioning, num_diffusion_iters = 50): + '''sampling for inference, should maybe be implemented to be consistent with the PLMSSampler class''' + pass + \ No newline at end of file From 09a21a34e58e44c456d42768932d493d3fdcf66f Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Fri, 13 Feb 2026 13:51:43 +0100 Subject: [PATCH 03/13] not added the files to the previous commit, here there are --- README.md | 148 +++++---- src/mlcast/models/base.py | 7 +- .../distributions-checkpoint.py | 2 + .../.ipynb_checkpoints/ldcast-checkpoint.py | 308 +----------------- .../.ipynb_checkpoints/utils-checkpoint.py | 2 + .../.ipynb_checkpoints/autoenc-checkpoint.py | 80 ++--- .../.ipynb_checkpoints/encoder-checkpoint.py | 2 + src/mlcast/models/ldcast/autoenc/autoenc.py | 80 ++--- src/mlcast/models/ldcast/autoenc/encoder.py | 2 + .../.ipynb_checkpoints/afno-checkpoint.py | 4 +- .../attention-checkpoint.py | 2 + .../.ipynb_checkpoints/resnet-checkpoint.py | 2 + src/mlcast/models/ldcast/blocks/afno.py | 4 +- src/mlcast/models/ldcast/blocks/attention.py | 2 + src/mlcast/models/ldcast/blocks/resnet.py | 2 + .../.ipynb_checkpoints/context-checkpoint.py | 9 +- .../.ipynb_checkpoints/nowcast-checkpoint.py | 2 + src/mlcast/models/ldcast/context/context.py | 9 +- src/mlcast/models/ldcast/context/nowcast.py | 2 + .../diffusion-checkpoint.py | 304 ++++++----------- .../.ipynb_checkpoints/ema-checkpoint.py | 2 + .../.ipynb_checkpoints/plms-checkpoint.py | 34 +- .../.ipynb_checkpoints/unet-checkpoint.py | 4 +- .../.ipynb_checkpoints/utils-checkpoint.py | 2 + .../models/ldcast/diffusion/diffusion.py | 304 ++++++----------- src/mlcast/models/ldcast/diffusion/ema.py | 2 + src/mlcast/models/ldcast/diffusion/plms.py | 34 +- src/mlcast/models/ldcast/diffusion/unet.py | 4 +- src/mlcast/models/ldcast/diffusion/utils.py | 2 + src/mlcast/models/ldcast/distributions.py | 2 + src/mlcast/models/ldcast/ldcast.py | 308 +----------------- src/mlcast/models/ldcast/utils.py | 2 + 32 files changed, 452 insertions(+), 1221 deletions(-) diff --git a/README.md b/README.md index d6e9b47..a54367e 100644 --- a/README.md +++ b/README.md @@ -1,83 +1,101 @@ -# mlcast +# MLCast implementation of LDCast - +see main branch ([https://github.com/mlcast-community/mlcast]) for details. -The MLCast Community is a collaborative effort bringing together meteorological services, research institutions, and academia across Europe to develop a unified Python package for AI-based nowcasting. This is an initiative of the E-AI WG6 (Nowcasting) of EUMETNET. +# Main LDCast class -This repo contains the `mlcast` package for machine learning-based weather nowcasting. +The main class is LDCast and takes an autoencoder and a latent_nowcaster modules. Only the predict method is implemented, to show the encode-latent_nowcasting-decode pattern. -## Project Status - -⚠️ **Under Development** - This package is currently in early development stages and not usable by end users. The API and functionality are subject to change. - -## Installation -```bash -# Install from pypi -pip install mlcast ``` - -or -```bash -# Install from source -git clone https://github.com/mlcast-community/mlcast -cd mlcast -uv pip install -e . - -# For development -uv pip install -e ".[dev]" +from src.mlcast.models.ldcast.ldcast import LDCast +ldcast = LDCast(autoencoder, latent_nowcaster) ``` -## Project Structure +# Autoencoder ``` -mlcast/ -├── src/mlcast/ # Main package source code -│ ├── __init__.py # Package initialization and version -│ ├── data/ # Data loading and preprocessing -│ │ ├── zarr_datamodule.py # PyTorch Lightning data module for Zarr -│ │ └── zarr_dataset.py # PyTorch dataset for Zarr arrays -│ ├── models/ # Lightning model implementations -│ │ └── base.py # Abstract base classes for nowcasting models -│ └── modules/ # Pure PyTorch neural network modules -│ └── convgru_modules.py # ConvGRU encoder-decoder modules -├── examples/ # Example scripts and notebooks -│ └── scripts/ -│ └── simple_train.py # Basic training example -├── pyproject.toml # Project metadata and dependencies -├── LICENSE # Apache 2.0 license -└── README.md # This file +from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKLNet, autoenc_loss +from src.mlcast.models.base import NowcastingLightningModule +autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to('cuda') ``` - -## Development - -This project uses `uv` for dependency management. To set up the development environment: - -```bash -# Install uv if not already installed -curl -LsSf https://astral.sh/uv/install.sh | sh - -# Install dependencies -uv sync - -# Run pre-commit hooks -uv run pre-commit install +The autoencoder is an instance of the NowcastingLightningModule. Training the autoencoder: +``` +# create fake data +x = torch.randn(2, 1, 4, 256, 256, device = 'cuda', requires_grad = False) +y = autoencoder(x, 4)[0] +y = y.detach() +batch = (x, y) + +import pytorch_lightning as L +trainer = L.Trainer() +trainer.fit(autoencoder, batch) ``` -## Contributing - -Please feel free to raise issues or PRs if you have any suggestions or questions. - -## Links to presentations for discussion about the API +# Latent nowcaster (= conditioner + denoiser + samplers) +The latent nowcaster manages the conditioner, the denoiser and the samplers. There can be two different samplers for training and for inference. +``` +# setup forecaster +conditioner = AFNONowcastNetCascade( + 32, + train_autoenc=False, + output_patches=future_timesteps//autoenc_time_ratio, + cascade_depth=3, + embed_dim=128, + analysis_depth=4 +).to('cuda') + +# setup denoiser +from src.mlcast.models.ldcast.diffusion.unet import UNetModel +denoiser = UNetModel(in_channels=autoencoder.net.hidden_width, + model_channels=256, out_channels=autoencoder.net.hidden_width, + num_res_blocks=2, attention_resolutions=(1,2), + dims=3, channel_mult=(1, 2, 4), num_heads=8, + num_timesteps=future_timesteps//autoenc_time_ratio, + context_ch=[128, 256, 512] # context channels (= analysis_net.cascade_dims) + ).to('cuda') + +# define the training and inference samplers +training_sampler = SimpleSampler() +inference_sampler = PLMSSampler(denoiser, 1000) + +# define the latent_nowcaster +from torch.nn import L1Loss +from src.mlcast.models.ldcast.diffusion.diffusion import LatentNowcaster +latent_nowcaster = LatentNowcaster(conditioner, denoiser, L1Loss(), training_sampler, inference_sampler) +``` +Create fake data for inference and training: +``` +inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') +target = torch.randn(2, 1, 20, 256, 256, device = 'cuda') +loss = nn.L1Loss() +autoencoder.eval() +latent_inputs = autoencoder.net.encode(inputs)[0].detach() +latent_target = autoencoder.net.encode(target)[0].detach() +``` +Inference with the latent_nowcaster (PLMSSampler is used during inference) +``` +latent_nowcaster.infer(latent_inputs) +``` +Training the latent_nowcaster (SimpleSampler is used during training) +``` +from torch.utils.data import DataLoader, TensorDataset +dataset = TensorDataset(latent_inputs, latent_target) +dataloader = DataLoader(dataset, batch_size=2) + +latent_batch = (latent_inputs, latent_target) +import pytorch_lightning as L +trainer = L.Trainer() +trainer.fit(latent_nowcaster, dataloader) +``` -- [2025/02/04 first design discussions](https://docs.google.com/presentation/d/1oWmnyxOfUMWgeQi0XyX4fX9YDMX1vl6h/edit?usp=drive_link&rtpof=true&sd=true) +# Notes -## License +I did not manage to make LatentNowcaster a sublcass of NowcastingLightningModule because I would basically have to overwrite everything... LatentNowcaster needs two nets (denoiser and conditioner) and the training logic is not as straightforward as it is for the moment in NowcastingLightningModule. One should also take into account the fact that two different samplers are used for training and inference, so that the forward method can not just be self.net(x) -This project is dual-licensed under either: +It would be nice to have cleaner and consistent APIs for the samplers. For the moment, the PLMSSampler and the SimpleSampler are not totally consistent in their APIs, because the SimpleSampler (better/more common name for this one?) was only used during training, while the PLMSSampler was used during inference. The handling of the schedule of each sampler with respect to the schedule saved in the denoiser could also be clearer. -* Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* BSD 3-Clause License ([LICENSE-BSD](LICENSE-BSD) or https://opensource.org/licenses/BSD-3-Clause) +During training, an EMA scope was used for the weights of the denoiser, I removed this for the moment, but it should reincluded in some way. -at your option. +The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. -See [LICENSE](LICENSE) for more details. +In /src/mlcast/models/ldcast/diffusion/diffusion.py, one has to choose which sampler to use for testing \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index d526d63..bb4783d 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -142,17 +142,18 @@ def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> to Loss value for the current batch """ x, y = batch - predictions = self.forward(x, n_timesteps=y.shape[1]) + predictions = self.forward(x, n_timesteps=4) loss = self.loss(predictions, y) if isinstance(loss, dict): # append step name to loss keys for logging - loss = {f"{step_name}/{k}": v.item() for k, v in loss} + loss = {f"{step_name}/{k}": v for k, v in loss.items()} self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - loss = loss.get("loss", loss.get("total_loss", None)) + loss = loss.get(f"{step_name}/total_loss", None) if loss is None: raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") else: self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py index b7f68c2..3dcb183 100644 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/distributions.py + import numpy as np import torch diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py index b49c551..444d6ce 100644 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py @@ -1,4 +1,4 @@ -"""LDCast model implementation compliant with mlcast-ldcast structure.""" +# new file with respect to original code import abc from pathlib import Path @@ -12,308 +12,26 @@ from ..base import NowcastingModelBase, NowcastingLightningModule - -class LDCastLightningModule(NowcastingLightningModule): - """PyTorch Lightning module for LDCast diffusion model.""" - - def __init__( - self, - net: nn.Module, - loss: nn.Module, - optimizer_class: type | None = None, - optimizer_kwargs: dict | None = None, - **kwargs: Any, - ): - super().__init__( - net=net, - loss=loss, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - **kwargs, - ) - - class LDCast(NowcastingModelBase): - """LDCast precipitation nowcasting model. - - This model implements a latent diffusion approach for precipitation forecasting, - combining an autoencoder for dimensionality reduction with a diffusion model - for temporal prediction. - - Attributes: - timestep_length: Time resolution of predictions (e.g., 5 minutes) - PLModuleClass: The Lightning module class used for training - """ - - timestep_length: np.timedelta64 | None = None - #PLModuleClass = LDCastLightningModule - def __init__(self, config: dict | None = None): - """Initialize LDCast model. - - Args: - config: Configuration dictionary with model parameters - """ + def __init__(self, autoencoder, latent_nowcaster): #super().__init__() - self.pl_module = LDCastLightningModule(nn.Module(), nn.Module()) - self.config = config or {} - self.autoencoder = None - self.diffusion_model = None - self.scaler = None - - def save(self, path: str, **kwargs: Any) -> None: - """Save the trained LDCast model to disk. - - Args: - path: File path where the model should be saved - **kwargs: Additional arguments for model saving - """ - model_path = Path(path) - model_path.mkdir(parents=True, exist_ok=True) - - # Save autoencoder weights - if self.autoencoder is not None: - torch.save( - self.autoencoder.state_dict(), - model_path / "autoencoder.pt" - ) - - # Save diffusion model weights - if self.diffusion_model is not None: - torch.save( - self.diffusion_model.state_dict(), - model_path / "diffusion_model.pt" - ) - - # Save scaler parameters if present - if self.scaler is not None: - import pickle - with open(model_path / "scaler.pkl", "wb") as f: - pickle.dump(self.scaler, f) - - # Save configuration - import json - with open(model_path / "config.json", "w") as f: - json.dump(self.config, f) - - def load(self, path: str, **kwargs: Any) -> None: - """Load a pre-trained LDCast model from disk. - - Args: - path: File path to the saved model - **kwargs: Additional arguments for model loading - """ - model_path = Path(path) - - # Load configuration - import json - with open(model_path / "config.json", "r") as f: - self.config = json.load(f) - - # Load autoencoder weights if available - autoenc_path = model_path / "autoencoder.pt" - if autoenc_path.exists(): - # Initialize autoencoder architecture from config - self.autoencoder = self._build_autoencoder() - self.autoencoder.load_state_dict(torch.load(autoenc_path)) - - # Load diffusion model weights if available - diffusion_path = model_path / "diffusion_model.pt" - if diffusion_path.exists(): - # Initialize diffusion model architecture from config - self.diffusion_model = self._build_diffusion_model() - self.diffusion_model.load_state_dict(torch.load(diffusion_path)) - - # Load scaler parameters if available - scaler_path = model_path / "scaler.pkl" - if scaler_path.exists(): - import pickle - with open(scaler_path, "rb") as f: - self.scaler = pickle.load(f) + self.autoencoder = autoencoder + self.latent_nowcaster = latent_nowcaster def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - """Train the LDCast model on precipitation data. - - Args: - da_rr: xarray DataArray containing precipitation radar data - with time, latitude, and longitude dimensions - **kwargs: Additional arguments: - - epochs: Number of training epochs - - batch_size: Batch size for training - - val_split: Validation split ratio - - num_timesteps: Number of input timesteps - """ - # Extract configuration from kwargs - epochs = kwargs.get('epochs', self.config.get('max_epochs', 100)) - batch_size = kwargs.get('batch_size', self.config.get('batch_size', 32)) - num_timesteps = kwargs.get('num_timesteps', self.config.get('timesteps', 12)) - - # Step 1: Data preprocessing and scaling - self._preprocess_data(da_rr, **kwargs) - - # Step 2: Train autoencoder - self._train_autoencoder( - da_rr, - epochs=epochs, - batch_size=batch_size, - **kwargs - ) - - # Step 3: Train diffusion model - self._train_diffusion_model( - da_rr, - num_timesteps=num_timesteps, - epochs=epochs, - batch_size=batch_size, - **kwargs - ) - - # Store timestep length - if 'time' in da_rr.dims: - time_coords = da_rr.coords['time'].values - if len(time_coords) > 1: - self.timestep_length = np.timedelta64( - int(np.diff(time_coords[:2])[0]), 'ns' - ) - - def predict( - self, - da_rr: xr.DataArray, - duration: str, - **kwargs: Any - ) -> xr.DataArray: - """Generate precipitation forecasts. - - Args: - da_rr: xarray DataArray containing initial precipitation conditions - duration: ISO 8601 duration string (e.g., "PT1H" for 1 hour) - **kwargs: Additional arguments: - - num_samples: Number of ensemble samples to generate - - num_diffusion_steps: Number of diffusion steps - - Returns: - xarray DataArray containing precipitation predictions with - original spatial dimensions plus an "elapsed_time" dimension - """ - from isodate import parse_duration - - # Parse duration string - duration_obj = parse_duration(duration) - num_forecasts = int(duration_obj.total_seconds() / - self.timestep_length.astype(int)) - - # Extract configuration from kwargs - num_samples = kwargs.get('num_samples', 1) - num_diffusion_steps = kwargs.get('num_diffusion_steps', 50) - - # Preprocess input using stored scaler - processed_input = self._preprocess_input(da_rr) - - # Encode to latent space using autoencoder - with torch.no_grad(): - latent = self.autoencoder.encode(processed_input) - - # Generate predictions using diffusion model - predictions = [] - for _ in range(num_samples): - pred = self._diffusion_predict( - latent, - num_forecasts, - num_diffusion_steps - ) - predictions.append(pred) - - # Stack and average predictions - predictions = torch.stack(predictions, dim=0).mean(dim=0) - - # Decode from latent space - with torch.no_grad(): - forecasted = self.autoencoder.decode(predictions) - - # Postprocess and convert back to original scale - output = self._postprocess_output(forecasted, da_rr) - - # Create output DataArray with elapsed_time dimension - time_coords = da_rr.coords['time'].values[-1] - elapsed_times = [ - np.timedelta64(i, 'm') * 5 # Assuming 5-minute steps - for i in range(1, num_forecasts + 1) - ] - - output_da = xr.DataArray( - output, - dims=['elapsed_time', 'latitude', 'longitude'], - coords={ - 'elapsed_time': ('elapsed_time', elapsed_times), - 'latitude': ('latitude', da_rr.coords['latitude'].values), - 'longitude': ('longitude', da_rr.coords['longitude'].values), - }, - name='precipitation' - ) - - return output_da - - def _preprocess_data(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - """Preprocess precipitation data and fit scaler.""" - # Implement data scaling/normalization - # Store scaling parameters in self.scaler - pass - - def _train_autoencoder( - self, - da_rr: xr.DataArray, - epochs: int, - batch_size: int, - **kwargs: Any - ) -> None: - """Train the autoencoder component.""" - # Import and use ldcast autoencoder training - from ldcast.models.autoenc import setup_and_train - # Implementation details - pass - - def _train_diffusion_model( - self, - da_rr: xr.DataArray, - num_timesteps: int, - epochs: int, - batch_size: int, - **kwargs: Any - ) -> None: - """Train the diffusion model component.""" - # Import and use ldcast genforecast training - from ldcast.models.genforecast import setup_and_train - # Implementation details - pass - - def _preprocess_input(self, da_rr: xr.DataArray) -> torch.Tensor: - """Convert input xarray to scaled tensor.""" - # Apply stored scaler pass + - def _postprocess_output( - self, - output: torch.Tensor, - reference_da: xr.DataArray - ) -> np.ndarray: - """Convert predictions back to original scale and format.""" - # Reverse scaling using stored scaler - pass + def predict(self, inputs): + '''inputs is of shape (batch_size, 1, 4) + spatial_shape''' + latent_inputs = self.autoencoder.net.encode(inputs) + latent_pred = self.latent_nowcaster(latent_inputs) + return self.autoencoder.net.decode(latent_pred) - def _diffusion_predict( - self, - latent: torch.Tensor, - num_forecasts: int, - num_steps: int - ) -> torch.Tensor: - """Generate predictions using the diffusion model.""" - # Use ldcast diffusion inference + def _train_autoencoder(self, da_rr: xr.DataArray, epochs: int, batch_size: int, **kwargs: Any) -> None: pass - def _build_autoencoder(self) -> nn.Module: - """Build autoencoder architecture from config.""" + def _train_latent_nowcaster(self, da_rr: xr.DataArray, num_timesteps: int, epochs: int, batch_size: int, **kwargs: Any) -> None: pass - - def _build_diffusion_model(self) -> nn.Module: - """Build diffusion model architecture from config.""" - pass \ No newline at end of file + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py index 65cada8..f38bd29 100644 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py +++ b/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/utils.py + import torch from torch import nn diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py index 51cc18b..c7f43a6 100644 --- a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py +++ b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py @@ -1,6 +1,9 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + import pytorch_lightning as pl import torch from torch import nn +from .encoder import SimpleConvEncoder, SimpleConvDecoder from ..distributions import ( ensemble_nll_normal, @@ -8,12 +11,27 @@ sample_from_standard_normal, ) +class autoenc_loss(nn.Module): + def __init__(self, kl_weight = 0.01): + super().__init__() + self.kl_weight = kl_weight + + def forward(self, predictions, y): + (y_pred, mean, log_var) = predictions + + rec_loss = (y - y_pred).abs().mean() + kl_loss = kl_from_standard_normal(mean, log_var) + + total_loss = rec_loss + self.kl_weight * kl_loss + + return {'total_loss': total_loss, 'rec_loss': rec_loss, 'kl_loss': kl_loss} -class AutoencoderKL(pl.LightningModule): + +class AutoencoderKLNet(pl.LightningModule): def __init__( self, - encoder, - decoder, + encoder = SimpleConvEncoder(), + decoder = SimpleConvDecoder(), kl_weight=0.01, encoded_channels=64, hidden_width=32, @@ -29,6 +47,8 @@ def __init__( self.kl_weight = kl_weight def encode(self, x): + if len(x.shape) < 5: + x = x[None] h = self.encoder(x) (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) return (mean, log_var) @@ -38,59 +58,11 @@ def decode(self, z): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - (mean, log_var) = self.encode(input) + def forward(self, x, n_timesteps, sample_posterior=True): + (mean, log_var) = self.encode(x) if sample_posterior: z = sample_from_standard_normal(mean, log_var) else: z = mean dec = self.decode(z) - return (dec, mean, log_var) - - def _loss(self, batch): - (x, y) = batch - while isinstance(x, list) or isinstance(x, tuple): - x = x[0][0] - (y_pred, mean, log_var) = self.forward(x) - - rec_loss = (y - y_pred).abs().mean() - kl_loss = kl_from_standard_normal(mean, log_var) - - total_loss = rec_loss + self.kl_weight * kl_loss - - return (total_loss, rec_loss, kl_loss) - - def training_step(self, batch, batch_idx): - loss = self._loss(batch)[0] - self.log("train_loss", loss, on_step=True) - return loss - - @torch.no_grad() - def val_test_step(self, batch, batch_idx, split="val"): - (total_loss, rec_loss, kl_loss) = self._loss(batch) - log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} - self.log(f"{split}_loss", total_loss, **log_params, sync_dist=True) - self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params, sync_dist=True) - self.log(f"{split}_kl_loss", kl_loss, **log_params, sync_dist=True) - - def validation_step(self, batch, batch_idx): - self.val_test_step(batch, batch_idx, split="val") - - def test_step(self, batch, batch_idx): - self.val_test_step(batch, batch_idx, split="test") - - def configure_optimizers(self): - optimizer = torch.optim.AdamW( - self.parameters(), lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3 - ) - reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, patience=3, factor=0.25, verbose=True - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": reduce_lr, - "monitor": "val_rec_loss", - "frequency": 1, - }, - } \ No newline at end of file + return (dec, mean, log_var) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py index aab9f7a..157af11 100644 --- a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py +++ b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + import numpy as np import torch.nn as nn diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py index 51cc18b..c7f43a6 100644 --- a/src/mlcast/models/ldcast/autoenc/autoenc.py +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -1,6 +1,9 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + import pytorch_lightning as pl import torch from torch import nn +from .encoder import SimpleConvEncoder, SimpleConvDecoder from ..distributions import ( ensemble_nll_normal, @@ -8,12 +11,27 @@ sample_from_standard_normal, ) +class autoenc_loss(nn.Module): + def __init__(self, kl_weight = 0.01): + super().__init__() + self.kl_weight = kl_weight + + def forward(self, predictions, y): + (y_pred, mean, log_var) = predictions + + rec_loss = (y - y_pred).abs().mean() + kl_loss = kl_from_standard_normal(mean, log_var) + + total_loss = rec_loss + self.kl_weight * kl_loss + + return {'total_loss': total_loss, 'rec_loss': rec_loss, 'kl_loss': kl_loss} -class AutoencoderKL(pl.LightningModule): + +class AutoencoderKLNet(pl.LightningModule): def __init__( self, - encoder, - decoder, + encoder = SimpleConvEncoder(), + decoder = SimpleConvDecoder(), kl_weight=0.01, encoded_channels=64, hidden_width=32, @@ -29,6 +47,8 @@ def __init__( self.kl_weight = kl_weight def encode(self, x): + if len(x.shape) < 5: + x = x[None] h = self.encoder(x) (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) return (mean, log_var) @@ -38,59 +58,11 @@ def decode(self, z): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - (mean, log_var) = self.encode(input) + def forward(self, x, n_timesteps, sample_posterior=True): + (mean, log_var) = self.encode(x) if sample_posterior: z = sample_from_standard_normal(mean, log_var) else: z = mean dec = self.decode(z) - return (dec, mean, log_var) - - def _loss(self, batch): - (x, y) = batch - while isinstance(x, list) or isinstance(x, tuple): - x = x[0][0] - (y_pred, mean, log_var) = self.forward(x) - - rec_loss = (y - y_pred).abs().mean() - kl_loss = kl_from_standard_normal(mean, log_var) - - total_loss = rec_loss + self.kl_weight * kl_loss - - return (total_loss, rec_loss, kl_loss) - - def training_step(self, batch, batch_idx): - loss = self._loss(batch)[0] - self.log("train_loss", loss, on_step=True) - return loss - - @torch.no_grad() - def val_test_step(self, batch, batch_idx, split="val"): - (total_loss, rec_loss, kl_loss) = self._loss(batch) - log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} - self.log(f"{split}_loss", total_loss, **log_params, sync_dist=True) - self.log(f"{split}_rec_loss", rec_loss.mean(), **log_params, sync_dist=True) - self.log(f"{split}_kl_loss", kl_loss, **log_params, sync_dist=True) - - def validation_step(self, batch, batch_idx): - self.val_test_step(batch, batch_idx, split="val") - - def test_step(self, batch, batch_idx): - self.val_test_step(batch, batch_idx, split="test") - - def configure_optimizers(self): - optimizer = torch.optim.AdamW( - self.parameters(), lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3 - ) - reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, patience=3, factor=0.25, verbose=True - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": reduce_lr, - "monitor": "val_rec_loss", - "frequency": 1, - }, - } \ No newline at end of file + return (dec, mean, log_var) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/encoder.py b/src/mlcast/models/ldcast/autoenc/encoder.py index aab9f7a..157af11 100644 --- a/src/mlcast/models/ldcast/autoenc/encoder.py +++ b/src/mlcast/models/ldcast/autoenc/encoder.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + import numpy as np import torch.nn as nn diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py index 3e7f801..84c73d0 100644 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/afno.py + #reference: https://github.com/NVlabs/AFNO-transformer import numpy as np import torch @@ -336,7 +338,7 @@ def forward(self, x, y): # AFNO natively uses a channels-last order x = x.permute(0,2,3,4,1) y = y.permute(0,2,3,4,1) - + xy = torch.concat((self.norm1(x),y), axis=-1) xy = self.pre_proj(xy) + xy xy = self.filter(self.norm2(xy)) + xy # AFNO filter diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py index c3b791e..b8b3149 100644 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/attention.py + import math import torch diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py index 90dacbc..983092d 100644 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py +++ b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/resnet.py + from torch import nn from torch.nn.utils.parametrizations import spectral_norm as sn diff --git a/src/mlcast/models/ldcast/blocks/afno.py b/src/mlcast/models/ldcast/blocks/afno.py index 3e7f801..84c73d0 100644 --- a/src/mlcast/models/ldcast/blocks/afno.py +++ b/src/mlcast/models/ldcast/blocks/afno.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/afno.py + #reference: https://github.com/NVlabs/AFNO-transformer import numpy as np import torch @@ -336,7 +338,7 @@ def forward(self, x, y): # AFNO natively uses a channels-last order x = x.permute(0,2,3,4,1) y = y.permute(0,2,3,4,1) - + xy = torch.concat((self.norm1(x),y), axis=-1) xy = self.pre_proj(xy) + xy xy = self.filter(self.norm2(xy)) + xy # AFNO filter diff --git a/src/mlcast/models/ldcast/blocks/attention.py b/src/mlcast/models/ldcast/blocks/attention.py index c3b791e..b8b3149 100644 --- a/src/mlcast/models/ldcast/blocks/attention.py +++ b/src/mlcast/models/ldcast/blocks/attention.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/attention.py + import math import torch diff --git a/src/mlcast/models/ldcast/blocks/resnet.py b/src/mlcast/models/ldcast/blocks/resnet.py index 90dacbc..983092d 100644 --- a/src/mlcast/models/ldcast/blocks/resnet.py +++ b/src/mlcast/models/ldcast/blocks/resnet.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/resnet.py + from torch import nn from torch.nn.utils.parametrizations import spectral_norm as sn diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py index ac76a7d..caedbb6 100644 --- a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py +++ b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/analysis.py + import torch from torch import nn from torch.nn import functional as F @@ -21,8 +23,11 @@ def __init__(self, *args, cascade_depth=4, **kwargs): ) ch = ch_out - def forward(self, x, timesteps): - x = super().forward(x, timesteps) + def forward(self, x): + # the past timesteps are needed here, but they are always the same... + # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why + past_timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32).unsqueeze(0).expand(1,-1) + x = super().forward(x, past_timesteps) img_shape = tuple(x.shape[-2:]) cascade = {img_shape: x} for i in range(self.cascade_depth-1): diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py index 54cc43b..f13b994 100644 --- a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py +++ b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/nowcast/nowcast.py, but removed the Nowcaster, AFNONowcastNetBasic and AFNONowcastNet classes because they were not used. Reworked also the two remaining classes (FusionBlock3D and AFNONowcastNetBase) to simplify the code by removing the unused parts. + import collections import torch diff --git a/src/mlcast/models/ldcast/context/context.py b/src/mlcast/models/ldcast/context/context.py index ac76a7d..caedbb6 100644 --- a/src/mlcast/models/ldcast/context/context.py +++ b/src/mlcast/models/ldcast/context/context.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/analysis.py + import torch from torch import nn from torch.nn import functional as F @@ -21,8 +23,11 @@ def __init__(self, *args, cascade_depth=4, **kwargs): ) ch = ch_out - def forward(self, x, timesteps): - x = super().forward(x, timesteps) + def forward(self, x): + # the past timesteps are needed here, but they are always the same... + # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why + past_timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32).unsqueeze(0).expand(1,-1) + x = super().forward(x, past_timesteps) img_shape = tuple(x.shape[-2:]) cascade = {img_shape: x} for i in range(self.cascade_depth-1): diff --git a/src/mlcast/models/ldcast/context/nowcast.py b/src/mlcast/models/ldcast/context/nowcast.py index 54cc43b..f13b994 100644 --- a/src/mlcast/models/ldcast/context/nowcast.py +++ b/src/mlcast/models/ldcast/context/nowcast.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/nowcast/nowcast.py, but removed the Nowcaster, AFNONowcastNetBasic and AFNONowcastNet classes because they were not used. Reworked also the two remaining classes (FusionBlock3D and AFNONowcastNetBase) to simplify the code by removing the unused parts. + import collections import torch diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py index 301e7dd..480a623 100644 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py @@ -1,220 +1,132 @@ -""" -From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py -Pared down to simplify code. - -The original file acknowledges: -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers -""" - import torch import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from contextlib import contextmanager -from functools import partial - +import pytorch_lightning as L +from typing import Any import contextlib -from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding -from .ema import LitEma -from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d -from .plms import PLMSSampler - -print('take care of ema scope') - -class DiffusionModel(pl.LightningModule): # replaces LatentDiffusion - def __init__(self, - denoiser, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - use_ema=True, - lr=1e-4, - lr_warmup=0, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - parameterization="eps", # all assuming fixed variance schedules +print('take care of ema scope, which was used as context manager each exactly when denoiser.forward was called, so it should be a taken care of in the code code about the denoiser or about the diffuser (nothing to do with samplers)') + +import pytorch_lightning as L +class LatentNowcaster(L.LightningModule): + """Base class for PyTorch Lightning modules used in nowcasting models. + + This class provides a standard interface for training and validation + steps, as well as optimizer configuration. + """ + + def __init__( + self, + conditioner: nn.Module, + denoiser: nn.Module, + loss: nn.Module, + training_sampler: nn.Module, + inference_sampler: nn.Module, + optimizer_class: Any | None = None, + optimizer_kwargs: dict | None = None, + **kwargs: Any, ): super().__init__() + self.save_hyperparameters(ignore=["denoiser", "conditioner", "training_sampler", "inference_sampler", "loss"]) + self.conditioner = conditioner self.denoiser = denoiser - self.lr = lr - self.lr_warmup = lr_warmup + self.loss = loss + self.training_sampler = training_sampler + self.inference_sampler = inference_sampler + self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - - self.use_ema = use_ema - if self.use_ema: - self.denoiser_ema = LitEma(self.denoiser) - - self.register_schedule( - beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s - ) + training_sampler.register_schedule(denoiser) - self.loss_type = loss_type + def infer(self, latent_inputs, num_diffusion_iters = 50, verbose = True): - self.sampler = PLMSSampler(self.denoiser, timesteps) + condition = self.conditioner(latent_inputs) - def forward(self, conditioning, num_diffusion_iters = 50, verbose = True): gen_shape = (32, 5, 256//4, 256//4) + batch_size = len(list(condition.values())[0]) with contextlib.redirect_stdout(None): - (s, intermediates) = self.sampler.sample( + (s, intermediates) = self.inference_sampler.sample( num_diffusion_iters, - 1, # batch_size + batch_size, gen_shape, - self.q_sample, - conditioning, + condition, progbar=verbose ) return s - - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - - betas = make_beta_schedule( - beta_schedule, timesteps, - linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s - ) - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.denoiser.register_buffer('betas', to_torch(betas)) - self.denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.denoiser_ema.store(self.denoiser.parameters()) - self.denoiser_ema.copy_to(self.denoiser) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.denoiser_ema.restore(self.denoiser.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def q_sample(self, x_start, t, noise=None): - if noise is None: - noise = torch.randn_like(x_start) - return ( - extract_into_tensor(self.denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - return loss + def model_step(self, latent_batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: + """Generic model step for training or validation. - def p_losses(self, x_start, t, noise=None, context=None): - if noise is None: - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - denoised = self.denoiser(x_noisy, t, context=context) + Args: + batch: Input batch of data + batch_idx: Index of the current batch - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start + Returns: + Loss value for the current batch + """ + latent_inputs, latent_targets = latent_batch + + condition = self.conditioner(latent_inputs) + t, noise, latent_target_noisy = self.training_sampler.q_sample(self.denoiser, latent_targets) + guessed_noise = self.denoiser(latent_target_noisy, t, context = condition) + loss = self.loss(guessed_noise, noise) + + if isinstance(loss, dict): + # append step name to loss keys for logging + loss = {f"{step_name}/{k}": v for k, v in loss.items()} + self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + loss = loss.get(f"{step_name}/total_loss", None) + if loss is None: + raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") else: - raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") - - return self.get_loss(denoised, target, mean=False).mean() - ''' - def forward(self, x, *args, **kwargs): - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - return self.p_losses(x, t, *args, **kwargs) - ''' - ''' - def shared_step(self, batch): - (x,y) = batch - y = self.autoencoder.encode(y)[0] - context = self.context_encoder(x) if self.conditional else None - return self(y, context=context) - ''' - def training_step(self, batch, batch_idx): - loss = self.shared_step(batch) - self.log("train_loss", loss) + self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss - @torch.no_grad() - def validation_step(self, batch, batch_idx): - loss = self.shared_step(batch) - with self.ema_scope(): - loss_ema = self.shared_step(batch) - log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} - self.log("val_loss", loss, **log_params) - self.log("val_loss_ema", loss, **log_params) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.denoiser_ema(self.denoiser) - - def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, - betas=(0.5, 0.9), weight_decay=1e-3) - reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, patience=3, factor=0.25, verbose=True - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": reduce_lr, - "monitor": "val_loss_ema", - "frequency": 1, - }, - } - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - **kwargs - ): - if self.trainer.global_step < self.lr_warmup: - lr_scale = (self.trainer.global_step+1) / self.lr_warmup - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * self.lr - - super().optimizer_step( - epoch, batch_idx, optimizer, - optimizer_idx, optimizer_closure, - **kwargs - ) - \ No newline at end of file + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: + """Training step for a single batch. + + Args: + batch: Input batch of data + batch_idx: Index of the current batch + + Returns: + Loss value for the current batch + """ + return self.model_step(batch, batch_idx, step_name="train") + + def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: + """Validation step for a single batch. + + Args: + batch: Input batch of data + batch_idx: Index of the current batch + + Returns: + Loss value for the current batch + """ + return self.model_step(batch, batch_idx, step_name="val") + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure the optimizer for training. + + Returns: + Optimizer instance to use for training + """ + return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + + + def on_train_start(self): + self._current_sampler = self.training_sampler + super().on_train_start() + + def on_validation_start(self): + self._current_sampler_mode = self.training_sampler + super().on_validation_start() + + def on_predict_start(self): + self._current_sampler_mode = self.inference_sampler + super().on_predict_start() + + def on_test_start(self): + # training or inference sampler ??? + self._current_sampler_mode = self.training_sampler + super().on_test_start() \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py index cd2f8e3..296c8a3 100644 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/ema.py + import torch from torch import nn diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py index bc87241..3b733d8 100644 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py @@ -1,3 +1,6 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/plms.py, but changed model.apply_model into model.forward + + """ From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py """ @@ -38,7 +41,8 @@ def make_schedule( assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + device = next(self.model.parameters()).device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) self.register_buffer("betas", to_torch(self.model.betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) @@ -93,19 +97,15 @@ def sample( S, batch_size, shape, - q_sample_func, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0.0, - mask=None, x0=None, temperature=1.0, noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, @@ -133,17 +133,13 @@ def sample( samples, intermediates = self.plms_sampling( conditioning, size, - q_sample_func, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, - mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, @@ -157,7 +153,6 @@ def plms_sampling( self, cond, shape, - q_sample_func, x_T=None, ddim_use_original_steps=False, callback=None, @@ -170,7 +165,6 @@ def plms_sampling( temperature=1.0, noise_dropout=0.0, score_corrector=None, - corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, progbar=True, @@ -222,14 +216,6 @@ def plms_sampling( dtype=torch.long, ) - if mask is not None: - assert x0 is not None - img_orig = q_sample_func( - x0, ts - ) # TODO: deterministic forward pass? - print('after q_sample 1', img_orig.shape) - img = img_orig * mask + (1.0 - mask) * img - print('after q_sample 2', img.shape) outs = self.p_sample_plms( img, cond, @@ -239,8 +225,6 @@ def plms_sampling( quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, @@ -273,8 +257,6 @@ def p_sample_plms( quantize_denoised=False, temperature=1.0, noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, old_eps=None, @@ -297,11 +279,7 @@ def get_model_output(x, t): e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) ''' - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, condition, **corrector_kwargs - ) + return e_t diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py index 1eb22d7..f63d2d4 100644 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/unet.py: weirdly, this part (which is the denoiser) was in not in the diffusion folder in the original code + from abc import abstractmethod from functools import partial import math @@ -464,7 +466,7 @@ def __init__( nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) - self.device = next(self.parameters()).device + def forward(self, x, timesteps=None, context=None): """ Apply the model to an input batch. diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py index ab90f9e..e908cd1 100644 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py +++ b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/utils.py + # adopted from # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py # and diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index 301e7dd..480a623 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -1,220 +1,132 @@ -""" -From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py -Pared down to simplify code. - -The original file acknowledges: -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers -""" - import torch import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from contextlib import contextmanager -from functools import partial - +import pytorch_lightning as L +from typing import Any import contextlib -from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding -from .ema import LitEma -from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d -from .plms import PLMSSampler - -print('take care of ema scope') - -class DiffusionModel(pl.LightningModule): # replaces LatentDiffusion - def __init__(self, - denoiser, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - use_ema=True, - lr=1e-4, - lr_warmup=0, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - parameterization="eps", # all assuming fixed variance schedules +print('take care of ema scope, which was used as context manager each exactly when denoiser.forward was called, so it should be a taken care of in the code code about the denoiser or about the diffuser (nothing to do with samplers)') + +import pytorch_lightning as L +class LatentNowcaster(L.LightningModule): + """Base class for PyTorch Lightning modules used in nowcasting models. + + This class provides a standard interface for training and validation + steps, as well as optimizer configuration. + """ + + def __init__( + self, + conditioner: nn.Module, + denoiser: nn.Module, + loss: nn.Module, + training_sampler: nn.Module, + inference_sampler: nn.Module, + optimizer_class: Any | None = None, + optimizer_kwargs: dict | None = None, + **kwargs: Any, ): super().__init__() + self.save_hyperparameters(ignore=["denoiser", "conditioner", "training_sampler", "inference_sampler", "loss"]) + self.conditioner = conditioner self.denoiser = denoiser - self.lr = lr - self.lr_warmup = lr_warmup + self.loss = loss + self.training_sampler = training_sampler + self.inference_sampler = inference_sampler + self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - - self.use_ema = use_ema - if self.use_ema: - self.denoiser_ema = LitEma(self.denoiser) - - self.register_schedule( - beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s - ) + training_sampler.register_schedule(denoiser) - self.loss_type = loss_type + def infer(self, latent_inputs, num_diffusion_iters = 50, verbose = True): - self.sampler = PLMSSampler(self.denoiser, timesteps) + condition = self.conditioner(latent_inputs) - def forward(self, conditioning, num_diffusion_iters = 50, verbose = True): gen_shape = (32, 5, 256//4, 256//4) + batch_size = len(list(condition.values())[0]) with contextlib.redirect_stdout(None): - (s, intermediates) = self.sampler.sample( + (s, intermediates) = self.inference_sampler.sample( num_diffusion_iters, - 1, # batch_size + batch_size, gen_shape, - self.q_sample, - conditioning, + condition, progbar=verbose ) return s - - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - - betas = make_beta_schedule( - beta_schedule, timesteps, - linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s - ) - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.denoiser.register_buffer('betas', to_torch(betas)) - self.denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.denoiser_ema.store(self.denoiser.parameters()) - self.denoiser_ema.copy_to(self.denoiser) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.denoiser_ema.restore(self.denoiser.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def q_sample(self, x_start, t, noise=None): - if noise is None: - noise = torch.randn_like(x_start) - return ( - extract_into_tensor(self.denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - return loss + def model_step(self, latent_batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: + """Generic model step for training or validation. - def p_losses(self, x_start, t, noise=None, context=None): - if noise is None: - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - denoised = self.denoiser(x_noisy, t, context=context) + Args: + batch: Input batch of data + batch_idx: Index of the current batch - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start + Returns: + Loss value for the current batch + """ + latent_inputs, latent_targets = latent_batch + + condition = self.conditioner(latent_inputs) + t, noise, latent_target_noisy = self.training_sampler.q_sample(self.denoiser, latent_targets) + guessed_noise = self.denoiser(latent_target_noisy, t, context = condition) + loss = self.loss(guessed_noise, noise) + + if isinstance(loss, dict): + # append step name to loss keys for logging + loss = {f"{step_name}/{k}": v for k, v in loss.items()} + self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + loss = loss.get(f"{step_name}/total_loss", None) + if loss is None: + raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") else: - raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") - - return self.get_loss(denoised, target, mean=False).mean() - ''' - def forward(self, x, *args, **kwargs): - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - return self.p_losses(x, t, *args, **kwargs) - ''' - ''' - def shared_step(self, batch): - (x,y) = batch - y = self.autoencoder.encode(y)[0] - context = self.context_encoder(x) if self.conditional else None - return self(y, context=context) - ''' - def training_step(self, batch, batch_idx): - loss = self.shared_step(batch) - self.log("train_loss", loss) + self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss - @torch.no_grad() - def validation_step(self, batch, batch_idx): - loss = self.shared_step(batch) - with self.ema_scope(): - loss_ema = self.shared_step(batch) - log_params = {"on_step": False, "on_epoch": True, "prog_bar": True} - self.log("val_loss", loss, **log_params) - self.log("val_loss_ema", loss, **log_params) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.denoiser_ema(self.denoiser) - - def configure_optimizers(self): - optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, - betas=(0.5, 0.9), weight_decay=1e-3) - reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, patience=3, factor=0.25, verbose=True - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": reduce_lr, - "monitor": "val_loss_ema", - "frequency": 1, - }, - } - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - **kwargs - ): - if self.trainer.global_step < self.lr_warmup: - lr_scale = (self.trainer.global_step+1) / self.lr_warmup - for pg in optimizer.param_groups: - pg['lr'] = lr_scale * self.lr - - super().optimizer_step( - epoch, batch_idx, optimizer, - optimizer_idx, optimizer_closure, - **kwargs - ) - \ No newline at end of file + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: + """Training step for a single batch. + + Args: + batch: Input batch of data + batch_idx: Index of the current batch + + Returns: + Loss value for the current batch + """ + return self.model_step(batch, batch_idx, step_name="train") + + def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: + """Validation step for a single batch. + + Args: + batch: Input batch of data + batch_idx: Index of the current batch + + Returns: + Loss value for the current batch + """ + return self.model_step(batch, batch_idx, step_name="val") + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure the optimizer for training. + + Returns: + Optimizer instance to use for training + """ + return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + + + def on_train_start(self): + self._current_sampler = self.training_sampler + super().on_train_start() + + def on_validation_start(self): + self._current_sampler_mode = self.training_sampler + super().on_validation_start() + + def on_predict_start(self): + self._current_sampler_mode = self.inference_sampler + super().on_predict_start() + + def on_test_start(self): + # training or inference sampler ??? + self._current_sampler_mode = self.training_sampler + super().on_test_start() \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py index cd2f8e3..296c8a3 100644 --- a/src/mlcast/models/ldcast/diffusion/ema.py +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/ema.py + import torch from torch import nn diff --git a/src/mlcast/models/ldcast/diffusion/plms.py b/src/mlcast/models/ldcast/diffusion/plms.py index bc87241..3b733d8 100644 --- a/src/mlcast/models/ldcast/diffusion/plms.py +++ b/src/mlcast/models/ldcast/diffusion/plms.py @@ -1,3 +1,6 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/plms.py, but changed model.apply_model into model.forward + + """ From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py """ @@ -38,7 +41,8 @@ def make_schedule( assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), "alphas have to be defined for each timestep" - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + device = next(self.model.parameters()).device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) self.register_buffer("betas", to_torch(self.model.betas)) self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) @@ -93,19 +97,15 @@ def sample( S, batch_size, shape, - q_sample_func, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0.0, - mask=None, x0=None, temperature=1.0, noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, @@ -133,17 +133,13 @@ def sample( samples, intermediates = self.plms_sampling( conditioning, size, - q_sample_func, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, - mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, @@ -157,7 +153,6 @@ def plms_sampling( self, cond, shape, - q_sample_func, x_T=None, ddim_use_original_steps=False, callback=None, @@ -170,7 +165,6 @@ def plms_sampling( temperature=1.0, noise_dropout=0.0, score_corrector=None, - corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, progbar=True, @@ -222,14 +216,6 @@ def plms_sampling( dtype=torch.long, ) - if mask is not None: - assert x0 is not None - img_orig = q_sample_func( - x0, ts - ) # TODO: deterministic forward pass? - print('after q_sample 1', img_orig.shape) - img = img_orig * mask + (1.0 - mask) * img - print('after q_sample 2', img.shape) outs = self.p_sample_plms( img, cond, @@ -239,8 +225,6 @@ def plms_sampling( quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, @@ -273,8 +257,6 @@ def p_sample_plms( quantize_denoised=False, temperature=1.0, noise_dropout=0.0, - score_corrector=None, - corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, old_eps=None, @@ -297,11 +279,7 @@ def get_model_output(x, t): e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) ''' - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score( - self.model, e_t, x, t, condition, **corrector_kwargs - ) + return e_t diff --git a/src/mlcast/models/ldcast/diffusion/unet.py b/src/mlcast/models/ldcast/diffusion/unet.py index 1eb22d7..f63d2d4 100644 --- a/src/mlcast/models/ldcast/diffusion/unet.py +++ b/src/mlcast/models/ldcast/diffusion/unet.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/unet.py: weirdly, this part (which is the denoiser) was in not in the diffusion folder in the original code + from abc import abstractmethod from functools import partial import math @@ -464,7 +466,7 @@ def __init__( nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) - self.device = next(self.parameters()).device + def forward(self, x, timesteps=None, context=None): """ Apply the model to an input batch. diff --git a/src/mlcast/models/ldcast/diffusion/utils.py b/src/mlcast/models/ldcast/diffusion/utils.py index ab90f9e..e908cd1 100644 --- a/src/mlcast/models/ldcast/diffusion/utils.py +++ b/src/mlcast/models/ldcast/diffusion/utils.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/utils.py + # adopted from # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py # and diff --git a/src/mlcast/models/ldcast/distributions.py b/src/mlcast/models/ldcast/distributions.py index b7f68c2..3dcb183 100644 --- a/src/mlcast/models/ldcast/distributions.py +++ b/src/mlcast/models/ldcast/distributions.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/distributions.py + import numpy as np import torch diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index f64f772..444d6ce 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -1,3 +1,5 @@ +# new file with respect to original code + import abc from pathlib import Path from typing import Any @@ -10,308 +12,26 @@ from ..base import NowcastingModelBase, NowcastingLightningModule - -class LDCastLightningModule(NowcastingLightningModule): - """PyTorch Lightning module for LDCast diffusion model.""" - - def __init__( - self, - net: nn.Module, - loss: nn.Module, - optimizer_class: type | None = None, - optimizer_kwargs: dict | None = None, - **kwargs: Any, - ): - super().__init__( - net=net, - loss=loss, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - **kwargs, - ) - - class LDCast(NowcastingModelBase): - """LDCast precipitation nowcasting model. - - This model implements a latent diffusion approach for precipitation forecasting, - combining an autoencoder for dimensionality reduction with a diffusion model - for temporal prediction. - - Attributes: - timestep_length: Time resolution of predictions (e.g., 5 minutes) - PLModuleClass: The Lightning module class used for training - """ - timestep_length: np.timedelta64 | None = None - #PLModuleClass = LDCastLightningModule - - def __init__(self, config: dict | None = None): - """Initialize LDCast model. - - Args: - config: Configuration dictionary with model parameters - """ + def __init__(self, autoencoder, latent_nowcaster): #super().__init__() - self.pl_module = LDCastLightningModule(nn.Module(), nn.Module()) - self.config = config or {} - self.autoencoder = None - self.diffusion_model = None - self.scaler = None - - def save(self, path: str, **kwargs: Any) -> None: - """Save the trained LDCast model to disk. - - Args: - path: File path where the model should be saved - **kwargs: Additional arguments for model saving - """ - model_path = Path(path) - model_path.mkdir(parents=True, exist_ok=True) - - # Save autoencoder weights - if self.autoencoder is not None: - torch.save( - self.autoencoder.state_dict(), - model_path / "autoencoder.pt" - ) - - # Save diffusion model weights - if self.diffusion_model is not None: - torch.save( - self.diffusion_model.state_dict(), - model_path / "diffusion_model.pt" - ) - - # Save scaler parameters if present - if self.scaler is not None: - import pickle - with open(model_path / "scaler.pkl", "wb") as f: - pickle.dump(self.scaler, f) - - # Save configuration - import json - with open(model_path / "config.json", "w") as f: - json.dump(self.config, f) - - def load(self, path: str, **kwargs: Any) -> None: - """Load a pre-trained LDCast model from disk. - - Args: - path: File path to the saved model - **kwargs: Additional arguments for model loading - """ - model_path = Path(path) - - # Load configuration - import json - with open(model_path / "config.json", "r") as f: - self.config = json.load(f) - - # Load autoencoder weights if available - autoenc_path = model_path / "autoencoder.pt" - if autoenc_path.exists(): - # Initialize autoencoder architecture from config - self.autoencoder = self._build_autoencoder() - self.autoencoder.load_state_dict(torch.load(autoenc_path)) - - # Load diffusion model weights if available - diffusion_path = model_path / "diffusion_model.pt" - if diffusion_path.exists(): - # Initialize diffusion model architecture from config - self.diffusion_model = self._build_diffusion_model() - self.diffusion_model.load_state_dict(torch.load(diffusion_path)) - - # Load scaler parameters if available - scaler_path = model_path / "scaler.pkl" - if scaler_path.exists(): - import pickle - with open(scaler_path, "rb") as f: - self.scaler = pickle.load(f) + self.autoencoder = autoencoder + self.latent_nowcaster = latent_nowcaster def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - """Train the LDCast model on precipitation data. - - Args: - da_rr: xarray DataArray containing precipitation radar data - with time, latitude, and longitude dimensions - **kwargs: Additional arguments: - - epochs: Number of training epochs - - batch_size: Batch size for training - - val_split: Validation split ratio - - num_timesteps: Number of input timesteps - """ - # Extract configuration from kwargs - epochs = kwargs.get('epochs', self.config.get('max_epochs', 100)) - batch_size = kwargs.get('batch_size', self.config.get('batch_size', 32)) - num_timesteps = kwargs.get('num_timesteps', self.config.get('timesteps', 12)) - - # Step 1: Data preprocessing and scaling - self._preprocess_data(da_rr, **kwargs) - - # Step 2: Train autoencoder - self._train_autoencoder( - da_rr, - epochs=epochs, - batch_size=batch_size, - **kwargs - ) - - # Step 3: Train diffusion model - self._train_diffusion_model( - da_rr, - num_timesteps=num_timesteps, - epochs=epochs, - batch_size=batch_size, - **kwargs - ) - - # Store timestep length - if 'time' in da_rr.dims: - time_coords = da_rr.coords['time'].values - if len(time_coords) > 1: - self.timestep_length = np.timedelta64( - int(np.diff(time_coords[:2])[0]), 'ns' - ) - - def predict( - self, - da_rr: xr.DataArray, - duration: str, - **kwargs: Any - ) -> xr.DataArray: - """Generate precipitation forecasts. - - Args: - da_rr: xarray DataArray containing initial precipitation conditions - duration: ISO 8601 duration string (e.g., "PT1H" for 1 hour) - **kwargs: Additional arguments: - - num_samples: Number of ensemble samples to generate - - num_diffusion_steps: Number of diffusion steps - - Returns: - xarray DataArray containing precipitation predictions with - original spatial dimensions plus an "elapsed_time" dimension - """ - from isodate import parse_duration - - # Parse duration string - duration_obj = parse_duration(duration) - num_forecasts = int(duration_obj.total_seconds() / - self.timestep_length.astype(int)) - - # Extract configuration from kwargs - num_samples = kwargs.get('num_samples', 1) - num_diffusion_steps = kwargs.get('num_diffusion_steps', 50) - - # Preprocess input using stored scaler - processed_input = self._preprocess_input(da_rr) - - # Encode to latent space using autoencoder - with torch.no_grad(): - latent = self.autoencoder.encode(processed_input) - - # Generate predictions using diffusion model - predictions = [] - for _ in range(num_samples): - pred = self._diffusion_predict( - latent, - num_forecasts, - num_diffusion_steps - ) - predictions.append(pred) - - # Stack and average predictions - predictions = torch.stack(predictions, dim=0).mean(dim=0) - - # Decode from latent space - with torch.no_grad(): - forecasted = self.autoencoder.decode(predictions) - - # Postprocess and convert back to original scale - output = self._postprocess_output(forecasted, da_rr) - - # Create output DataArray with elapsed_time dimension - time_coords = da_rr.coords['time'].values[-1] - elapsed_times = [ - np.timedelta64(i, 'm') * 5 # Assuming 5-minute steps - for i in range(1, num_forecasts + 1) - ] - - output_da = xr.DataArray( - output, - dims=['elapsed_time', 'latitude', 'longitude'], - coords={ - 'elapsed_time': ('elapsed_time', elapsed_times), - 'latitude': ('latitude', da_rr.coords['latitude'].values), - 'longitude': ('longitude', da_rr.coords['longitude'].values), - }, - name='precipitation' - ) - - return output_da - - def _preprocess_data(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - """Preprocess precipitation data and fit scaler.""" - # Implement data scaling/normalization - # Store scaling parameters in self.scaler - pass - - def _train_autoencoder( - self, - da_rr: xr.DataArray, - epochs: int, - batch_size: int, - **kwargs: Any - ) -> None: - """Train the autoencoder component.""" - # Import and use ldcast autoencoder training - from ldcast.models.autoenc import setup_and_train - # Implementation details - pass - - def _train_diffusion_model( - self, - da_rr: xr.DataArray, - num_timesteps: int, - epochs: int, - batch_size: int, - **kwargs: Any - ) -> None: - """Train the diffusion model component.""" - # Import and use ldcast genforecast training - from ldcast.models.genforecast import setup_and_train - # Implementation details - pass - - def _preprocess_input(self, da_rr: xr.DataArray) -> torch.Tensor: - """Convert input xarray to scaled tensor.""" - # Apply stored scaler pass + - def _postprocess_output( - self, - output: torch.Tensor, - reference_da: xr.DataArray - ) -> np.ndarray: - """Convert predictions back to original scale and format.""" - # Reverse scaling using stored scaler - pass + def predict(self, inputs): + '''inputs is of shape (batch_size, 1, 4) + spatial_shape''' + latent_inputs = self.autoencoder.net.encode(inputs) + latent_pred = self.latent_nowcaster(latent_inputs) + return self.autoencoder.net.decode(latent_pred) - def _diffusion_predict( - self, - latent: torch.Tensor, - num_forecasts: int, - num_steps: int - ) -> torch.Tensor: - """Generate predictions using the diffusion model.""" - # Use ldcast diffusion inference + def _train_autoencoder(self, da_rr: xr.DataArray, epochs: int, batch_size: int, **kwargs: Any) -> None: pass - def _build_autoencoder(self) -> nn.Module: - """Build autoencoder architecture from config.""" + def _train_latent_nowcaster(self, da_rr: xr.DataArray, num_timesteps: int, epochs: int, batch_size: int, **kwargs: Any) -> None: pass - - def _build_diffusion_model(self) -> nn.Module: - """Build diffusion model architecture from config.""" - pass \ No newline at end of file + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/utils.py b/src/mlcast/models/ldcast/utils.py index 65cada8..f38bd29 100644 --- a/src/mlcast/models/ldcast/utils.py +++ b/src/mlcast/models/ldcast/utils.py @@ -1,3 +1,5 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/utils.py + import torch from torch import nn From 5d9490d60093511f6960a40e222a95cdf7907558 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Tue, 24 Feb 2026 14:44:58 +0100 Subject: [PATCH 04/13] changes with respect to previous commit: - in base.py, I added a training_logic method to NowcastingLightningModule; this method can be rewritten in case the training logic is not straightforward (this is the case for diffusion models); I added also a print_log_loss method to take care of the printing and logging of the loss - in autoenc.py, I added the fact that, by default, autoencoder.decode returns only what is used as the latent encoding (which is the mean, see README.md) - I have understood that samplers are only used in inference ! The training (and validation) step is always done by predicting the noise (or a quantity which is related to it by a simple formula). The scheduler has some role to play before training, and I put the code of the scheduler in the diffusion/scheduler.py file (the SimpleSampler is not used anymore, because it was the scheduler) - I added the LatentDiffusion and LatentDiffusionLightning classes in diffusion/diffusion.py (the latter replaces the LatentNowcaster class) - to train the latent diffusion part, I created a LatentDataset class which converts in latent space the data with the autoencoder (to be used once the autoencoder has been trained) in data.py - I updated a bit the LDCast class in ldcast.py (this is where the main part of the work remains to be done) - I updated the README accordingly, with examples on how deal with these different parts; I also added some basic details I have understood on diffusion models and on the variational autoencoder --- LDCast.ipynb | 243 ------------------ README.md | 144 +++++++---- src/mlcast/models/base.py | 19 +- src/mlcast/models/ldcast/autoenc/autoenc.py | 9 +- src/mlcast/models/ldcast/data.py | 28 ++ .../models/ldcast/diffusion/diffusion.py | 133 +++++++++- .../models/ldcast/diffusion/scheduler.py | 39 +++ src/mlcast/models/ldcast/ldcast.py | 53 ++-- 8 files changed, 348 insertions(+), 320 deletions(-) delete mode 100644 LDCast.ipynb create mode 100644 src/mlcast/models/ldcast/data.py create mode 100644 src/mlcast/models/ldcast/diffusion/scheduler.py diff --git a/LDCast.ipynb b/LDCast.ipynb deleted file mode 100644 index 6ebf429..0000000 --- a/LDCast.ipynb +++ /dev/null @@ -1,243 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "aa7c06c7-6229-46bb-a06f-8aa8b03ab250", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "83a840b0-0705-4ece-a93d-6496ec075931", - "metadata": {}, - "outputs": [], - "source": [ - "#from src.mlcast.models.ldcast.ldcast import LDCast, LDCastLightningModule" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "5e248d5a-84b4-4a61-9fef-ed4a9c613d5a", - "metadata": {}, - "outputs": [], - "source": [ - "#LDCastLightningModule(nn.Module(), nn.Module())" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "2a5fbc98-56de-48b5-9afa-3a46882e0a8b", - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn\n", - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e7639166-2162-4f71-9954-0e0d7e01dde8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "take care of ema scope\n" - ] - } - ], - "source": [ - "from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKL\n", - "from src.mlcast.models.ldcast.autoenc.encoder import SimpleConvEncoder, SimpleConvDecoder\n", - "from src.mlcast.models.ldcast.context.context import AFNONowcastNetCascade\n", - "from src.mlcast.models.ldcast.diffusion.diffusion import DiffusionModel" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f0d7ac7d-0672-4a08-b08a-f06bcbb2d28e", - "metadata": {}, - "outputs": [], - "source": [ - "future_timesteps = 20\n", - "autoenc_time_ratio = 4 # number of timesteps encoded in the autoencoder" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "51608d56-c711-4da0-ab92-854be45d12ed", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# setup the different parts of LDCast\n", - "\n", - "# setup forecaster\n", - "conditioner = AFNONowcastNetCascade(\n", - " 32,\n", - " train_autoenc=False,\n", - " output_patches=future_timesteps//autoenc_time_ratio,\n", - " cascade_depth=3,\n", - " embed_dim=128,\n", - " analysis_depth=4\n", - ").to('cuda')\n", - "\n", - "enc = SimpleConvEncoder()\n", - "dec = SimpleConvDecoder()\n", - "autoencoder = AutoencoderKL(enc, dec).to('cuda')\n", - "\n", - "# setup denoiser\n", - "from src.mlcast.models.ldcast.diffusion.unet import UNetModel\n", - "denoiser = UNetModel(in_channels=autoencoder.hidden_width,\n", - " model_channels=256, out_channels=autoencoder.hidden_width,\n", - " num_res_blocks=2, attention_resolutions=(1,2), \n", - " dims=3, channel_mult=(1, 2, 4), num_heads=8,\n", - " num_timesteps=future_timesteps//autoenc_time_ratio,\n", - " # context channels (= analysis_net.cascade_dims)\n", - " context_ch=[128, 256, 512]).to('cuda')\n", - "\n", - "diffuser = DiffusionModel(denoiser).to('cuda')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d68fe5a1-3c0b-4d83-bac5-e3d64e08d719", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "aea2ed01-350c-48aa-8e43-464e63fd5d6b", - "metadata": {}, - "outputs": [], - "source": [ - "class LDCastLightningModule(nn.ModuleDict):\n", - " def __init__(self, autoencoder, conditioner, diffuser):\n", - " super().__init__({'autoencoder': autoencoder, 'conditioner': conditioner, 'diffuser': diffuser})\n", - "\n", - " def forward(self, x, timesteps):\n", - " \n", - " # encoded is tuple of 3 tensors, but only the first one is used !!\n", - " encoded = self.autoencoder.encode(x) \n", - "\n", - " # condition is a dict of tensors\n", - " condition = conditioner(encoded[0], timesteps)\n", - "\n", - " latent_diffused = diffuser(condition) # tensor\n", - "\n", - " prediction = self.autoencoder.decode(latent_diffused) # tensor\n", - " \n", - " return prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "8bd80398-d1c8-4a45-82a5-a31b80e5f02d", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "ldcast = LDCastLightningModule(autoencoder, conditioner, denoiser)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "55930f7c-ef89-4761-b053-39cdd52aba38", - "metadata": {}, - "outputs": [], - "source": [ - "# create fake data\n", - "timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32)\n", - "timesteps = timesteps.unsqueeze(0).expand(1,-1) # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why\n", - "x = torch.randn(1, 1, 4, 256, 256, device = 'cuda')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "43867cb4-3c8a-4851-879c-e62dc7ed96d1", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "PLMS Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:10<00:00, 4.84it/s]\n" - ] - } - ], - "source": [ - "prediction = ldcast(x, timesteps)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "3cb247ac-4b87-4013-a591-0e97e8f66413", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 1, 20, 256, 256])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prediction.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "700453df-61b2-4f04-93fc-f631e50d5ec0", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.14.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/README.md b/README.md index a54367e..1cf201a 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,46 @@ # MLCast implementation of LDCast -see main branch ([https://github.com/mlcast-community/mlcast]) for details. +see main branch https://github.com/mlcast-community/mlcast for details. -# Main LDCast class - -The main class is LDCast and takes an autoencoder and a latent_nowcaster modules. Only the predict method is implemented, to show the encode-latent_nowcasting-decode pattern. - -``` -from src.mlcast.models.ldcast.ldcast import LDCast -ldcast = LDCast(autoencoder, latent_nowcaster) +```python +future_timesteps = 20 +autoenc_time_ratio = 4 # number of timesteps encoded in the autoencoder ``` +Here, 4 consecutive radar images are encoded at once. # Autoencoder -``` +```python from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKLNet, autoenc_loss from src.mlcast.models.base import NowcastingLightningModule autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to('cuda') ``` The autoencoder is an instance of the NowcastingLightningModule. Training the autoencoder: -``` +```python # create fake data -x = torch.randn(2, 1, 4, 256, 256, device = 'cuda', requires_grad = False) -y = autoencoder(x, 4)[0] -y = y.detach() +inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') + +with torch.no_grad(): + # the forward pass of the autoencoder returns also the encoding + # so [0] is needed to select the decoded part only + y = autoencoder(x, 4)[0] batch = (x, y) import pytorch_lightning as L trainer = L.Trainer() trainer.fit(autoencoder, batch) ``` +The inputs tensors have shape `(batch_size, n_channels, number of input radar images,) + spatial shape`. In latent space, the tensors have shape `(batch_size, 32, n, 64, 64)`, where 32 is the `hidden_width` of the `autoencoder` and `n` is the number of consecutive encoded radar images divided by `autoenc_time_ratio` (set to 4). -# Latent nowcaster (= conditioner + denoiser + samplers) -The latent nowcaster manages the conditioner, the denoiser and the samplers. There can be two different samplers for training and for inference. +The original weights can be loaded directly as +```python +autoenc_weights_fn = '/path/to/original/autoencoder/weights' +autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) ``` + +# Latent diffusion (= conditioner + denoiser) +The `LatentDiffusion` class is a `nn.Module` combining the conditioner and the denoiser. +```python # setup forecaster conditioner = AFNONowcastNetCascade( 32, @@ -54,48 +61,97 @@ denoiser = UNetModel(in_channels=autoencoder.net.hidden_width, context_ch=[128, 256, 512] # context channels (= analysis_net.cascade_dims) ).to('cuda') -# define the training and inference samplers -training_sampler = SimpleSampler() -inference_sampler = PLMSSampler(denoiser, 1000) - -# define the latent_nowcaster -from torch.nn import L1Loss -from src.mlcast.models.ldcast.diffusion.diffusion import LatentNowcaster -latent_nowcaster = LatentNowcaster(conditioner, denoiser, L1Loss(), training_sampler, inference_sampler) +from src.mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion +ldm = LatentDiffusion(conditioner, denoiser) ``` -Create fake data for inference and training: +The `LatentDiffusion` class has a forward pass: it takes the noise, the timesteps of the diffusion and the encoded inputs +```python +latent_inputs = autoencoder.net.encode(inputs) +noise = torch.randn(2, 32, 5, 64, 64, device = latent_inputs.device) +t = torch.tensor([2, 3], device = latent_inputs.device) +ldm((t, noise, latent_inputs)) ``` -inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') -target = torch.randn(2, 1, 20, 256, 256, device = 'cuda') -loss = nn.L1Loss() -autoencoder.eval() -latent_inputs = autoencoder.net.encode(inputs)[0].detach() -latent_target = autoencoder.net.encode(target)[0].detach() -``` -Inference with the latent_nowcaster (PLMSSampler is used during inference) +The noise has to have the shape true radar images encoded in latent space. + +Create fake data to train the ldm: +```python +from torch.utils.data import TensorDataset +true = torch.randn(2, 1, future_timesteps, 256, 256, device = 'cuda') +dataset = TensorDataset(inputs, true) ``` -latent_nowcaster.infer(latent_inputs) +Create a ```Dataset``` which convert the samples in latent space with the autoencoder ``` -Training the latent_nowcaster (SimpleSampler is used during training) +self.autoencoder.net.eval() +latent_dataset = LatentDataset(dataset, autoencoder.net) +dataloader = DataLoader(latent_dataset, batch_size=2) ``` -from torch.utils.data import DataLoader, TensorDataset -dataset = TensorDataset(latent_inputs, latent_target) -dataloader = DataLoader(dataset, batch_size=2) +Put `ldm` in a `LightningModule` and train: +```python +from torch.nn import L1Loss +from src.mlcast.models.ldcast.diffusion.scheduler import Scheduler +from src.mlcast.models.ldcast.diffusion.diffusion import LatentDiffusionLightning -latent_batch = (latent_inputs, latent_target) -import pytorch_lightning as L +ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) trainer = L.Trainer() -trainer.fit(latent_nowcaster, dataloader) +trainer.fit(ldm_lightning, dataloader) ``` -# Notes +The original weights can not be directly loaded because the models are structured a little differently, but the original weights can be loaded with +```python +ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' +unexpected_keys = ldm_lightning.load_original_weights(ldm_weights_fn) +``` +`unexpected_keys` contains the keys that were not loaded (only the ema weights because I did not take care of the ema scope for the moment) -I did not manage to make LatentNowcaster a sublcass of NowcastingLightningModule because I would basically have to overwrite everything... LatentNowcaster needs two nets (denoiser and conditioner) and the training logic is not as straightforward as it is for the moment in NowcastingLightningModule. One should also take into account the fact that two different samplers are used for training and inference, so that the forward method can not just be self.net(x) +# Main LDCast class -It would be nice to have cleaner and consistent APIs for the samplers. For the moment, the PLMSSampler and the SimpleSampler are not totally consistent in their APIs, because the SimpleSampler (better/more common name for this one?) was only used during training, while the PLMSSampler was used during inference. The handling of the schedule of each sampler with respect to the schedule saved in the denoiser could also be clearer. +```python +from src.mlcast.models.ldcast.ldcast import LDCast +ldcast = LDCast(ldm_lightning, autoencoder) +``` + +# Notes During training, an EMA scope was used for the weights of the denoiser, I removed this for the moment, but it should reincluded in some way. The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. -In /src/mlcast/models/ldcast/diffusion/diffusion.py, one has to choose which sampler to use for testing \ No newline at end of file +I have understood that samplers are only used in inference ! The training (and validation) step is always done by predicting the noise (or a quantity which is related to it by a simple formula). What I called previously the SimpleSampler is actually simply a scheduler (which determines the values of alphas and betas, and add the noise on the latent samples during training) + +We might integrate this code within the Hugging Face Diffusers Library. + +# Basics on diffusion models + +See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. + +During training, we start from a sample $x_0$ and create a series of samples $x_0, x_1, ..., x_T$ according to the formula (forward diffusion) +$$ +x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon_t, \quad t = 1, ..., T +$$ +where $\epsilon_t \sim \mathcal{N}(0, 1)$. The constants $\bar{\alpha}_t$ are chosen, but they need the property that $\bar{\alpha}_t \to 0$ as $t \to T$, so that $x_T \sim \mathcal{N}(0, 1)$. These constants are computed with an algorithm called a scheduler. + +From a given $x_t$, the model can either be trained to predict $x_0$, $\epsilon_t$ or the velocity $v_t$. The latter is defined as +$$ +v_t = \sqrt{\bar{\alpha}_t} \epsilon_t - \sqrt{1-\bar{\alpha}_t} x_0, +$$ +which is equivalent to +$$ +x_0 = \sqrt{\bar{\alpha}_t} x_t - \sqrt{1-\bar{\alpha}_t} v_t. +$$ +The model is also given the timestep $t$. The loss is computed by comparing the target quantity ($\epsilon_t$, $x_0$ or $v_t$) with the predicted quantity by the model. Choosing to predict $x_0$, $\epsilon_t$ or $v_t$ is conceptually equivalent, the difference is in the numerical properties of the scheme (like in ODE integration schemes). + +The validation and test steps are done in the same way. + +So the model is trained to predict $x_0$ (or something from which we can compute $x_0$) from $x_T\sim \mathcal{N}(0, 1)$. But for large values of $t$, this prediction is actually quite bad. During actual prediction, the prediction is usually iteratively refined with sampler schemes. The idea is that, from the noise predicted based on $x_T$, the sampler scheme allows to compute $x_{T - \Delta t}$. The model is then used to predict the noise based on this estimation of $x_{T - \Delta t}$, and the sampler scheme allows to deduce $x_{T - 2\Delta t}$, etc. $\Delta t$ is usually taken of the order of 50 (while $T$ is usually 1000). + +In Hugging Face Diffusers library, the scheduler and the sampler parts are often combined in one object called a scheduler, but the sampler part is only used during inference. + +The original code was using antialiasing before feeding the samples to the model (at least during inference), I should add this + +# The variational autoencoder + +Source https://medium.com/@jpark7/finally-a-clear-derivation-of-the-vae-kl-loss-4cb38d2e47b3. + +Variational autoencoders encode the data through a normal distribution in latent space: each sample is represented by the mean and the standard deviation of the normal distribution. When decoding the sample, a new sample is created resembling the original sample, but is not quite the same. The degree to which we force the decoded samples to resemble the original ones is tuned by the `kl_weight` parameter of the KL loss function. + +When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.decode`. \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index bb4783d..fe068fb 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -131,6 +131,12 @@ def forward(self, x: torch.Tensor, n_timesteps: int) -> torch.Tensor: """ return self.net(x, n_timesteps) # Assuming net is a callable model + def training_logic(self, batch, batch_idx): + """Can be overwritten if needed""" + x, y = batch + predictions = self.forward(x, n_timesteps = 4) + loss = self.loss(predictions, y) + def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: """Generic model step for training or validation. @@ -141,9 +147,13 @@ def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> to Returns: Loss value for the current batch """ - x, y = batch - predictions = self.forward(x, n_timesteps=4) - loss = self.loss(predictions, y) + + loss = self.training_logic(batch, batch_idx) + loss = self.print_log_loss(loss, step_name) + + return loss + + def print_log_loss(self, loss, step_name): if isinstance(loss, dict): # append step name to loss keys for logging loss = {f"{step_name}/{k}": v for k, v in loss.items()} @@ -153,9 +163,8 @@ def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> to raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") else: self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """Training step for a single batch. diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py index c7f43a6..5d5a798 100644 --- a/src/mlcast/models/ldcast/autoenc/autoenc.py +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -46,12 +46,15 @@ def __init__( self.log_var = nn.Parameter(torch.zeros(size=())) self.kl_weight = kl_weight - def encode(self, x): + def encode(self, x, return_log_var = False): if len(x.shape) < 5: x = x[None] h = self.encoder(x) (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) - return (mean, log_var) + if return_log_var: + return (mean, log_var) + else: + return mean def decode(self, z): z = self.to_decoder(z) @@ -59,7 +62,7 @@ def decode(self, z): return dec def forward(self, x, n_timesteps, sample_posterior=True): - (mean, log_var) = self.encode(x) + (mean, log_var) = self.encode(x, return_log_var = True) if sample_posterior: z = sample_from_standard_normal(mean, log_var) else: diff --git a/src/mlcast/models/ldcast/data.py b/src/mlcast/models/ldcast/data.py new file mode 100644 index 0000000..edc6773 --- /dev/null +++ b/src/mlcast/models/ldcast/data.py @@ -0,0 +1,28 @@ +from torch.utils.data import Dataset +import torch + +class LatentDataset(Dataset): + def __init__(self, dataset, autoencoder): + super().__init__() + + self.autoencoder = autoencoder + self.dataset = dataset + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + + with torch.no_grad(): + inputs, true = self.dataset[idx] + latent_inputs = self.autoencoder.encode(inputs) + latent_true = self.autoencoder.encode(true) + + # until here, the first dimension of latent_inputs and latent_true is the 'batch dimension' + # if idx is a list, keep this batch dimension along this list + # if idx is not a list, this batch dimension is 1 and needs to be removed because the dataloader will repeatedly call __getitem__ and add an extra dimension for the batch dimension + if not isinstance(idx, list): + latent_inputs = latent_inputs[0] + latent_true = latent_true[0] + + return (latent_inputs, latent_true) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index 480a623..aeff27d 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -3,10 +3,139 @@ import pytorch_lightning as L from typing import Any import contextlib +from src.mlcast.models.base import NowcastingLightningModule +import numpy as np +from src.mlcast.models.ldcast.diffusion.utils import extract_into_tensor print('take care of ema scope, which was used as context manager each exactly when denoiser.forward was called, so it should be a taken care of in the code code about the denoiser or about the diffuser (nothing to do with samplers)') -import pytorch_lightning as L +class LatentDiffusion(nn.Module): + def __init__(self, conditioner, denoiser, parametrization = "eps"): + super().__init__() + self.conditioner = conditioner + self.denoiser = denoiser + self.parametrization = parametrization + + def forward(self, x, n_timesteps = 4): + # during training, noisy should be x_t + # during inference, noisy should be noise + t, noisy, latent_inputs = x + condition = self.conditioner(latent_inputs) + + # if parametrization is eps, out is the predicted noise + # if parametrization is x0, out is the guessed x0 + # if parametrization is v, out is the guessed v + out = self.denoiser(noisy, t, context = condition) + return out + + +class LatentDiffusionLightning(NowcastingLightningModule): + def __init__(self, ldm, loss, scheduler): + super().__init__(ldm, loss) + self.scheduler = scheduler + + # register the schedules (i.e. the values of alpha, beta etc). + self.register_schedule() + + def register_schedule(self): + + schedule = self.scheduler.schedule(torch.float32, next(self.net.parameters()).device) + + # check if the ldm has already some saved buffers + saved_buffers = dict(self.net.named_buffers()).keys() + already_saved = [name for name in schedule.keys() if name in saved_buffers] + if len(already_saved) > 0: + raise AttributeError(f'The denoiser has already some saved values for {already_saved}') + + for k, v in schedule.items(): + self.net.register_buffer(k, v) + + def training_logic(self, batch, batch_idx): + latent_inputs, latent_true = batch + x0 = latent_true + + t, noise, x_t = self.q_sample(x0) + + if self.net.parametrization == 'eps': + target = noise + if self.net.parametrization == 'x0': + target = x0 + + model_output = self.net((t, x_t, latent_inputs)) + + return self.loss(model_output, target) + + def q_sample(self, x0, noise = None, t = None): + '''generate noise target for training''' + if noise is None: + noise = torch.randn_like(x0) + if t is None: + t = torch.randint(0, self.scheduler.timesteps, (x0.shape[0],), device=x0.device).long() + + x_noisy = extract_into_tensor(self.net.sqrt_alphas_cumprod, t, x0.shape) * x0 + \ + extract_into_tensor(self.net.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + + return t, noise, x_noisy + + def load_original_weights(self, ldm_weights_fn): + ''' + load the weights of the denoiser and the conditioner, from the way they were saved originally + at the moment, the ema scope is not taken into account + returns the weights which were not loaded in one of the two nets (should be only those related to the ema scope) + ''' + ldm_state_dict = torch.load(ldm_weights_fn) + + # track the keys + ldm_keys = list(ldm_state_dict.keys()) + + # remove the weights of the autoencoder + for k in ldm_keys.copy(): + if k.startswith('autoencoder.') or k.startswith('context_encoder.autoencoder.'): + ldm_keys.remove(k) + + # extract the keys of the denoiser (it was called 'model' in the original code) + denoiser_state_dict = {} + for k in ldm_keys.copy(): + if k.startswith('model.'): + new_key = k.replace('model.', '') + denoiser_state_dict[new_key] = ldm_state_dict[k] + ldm_keys.remove(k) + + # extract the keys of the conditioner (it was called 'context_encoder' in the original code) + conditioner_state_dict = {} + for k in ldm_keys.copy(): + if k.startswith('context_encoder.'): + new_key = k.replace('context_encoder.', '') + conditioner_state_dict[new_key] = ldm_state_dict[k] + ldm_keys.remove(k) + + # proj, temporal_transformer and analysis were lists one only element, I simplified this + # the keys have to be adapted + new_conditioner_state_dict = {} + for k, v in conditioner_state_dict.items(): + new_key = k + if k.startswith('proj.0.'): + new_key = k.replace('proj.0.', 'proj.') + if k.startswith('temporal_transformer.0.'): + new_key = k.replace('temporal_transformer.0.', 'temporal_transformer.') + if k.startswith('analysis.0.'): + new_key = k.replace('analysis.0.', 'analysis.') + new_conditioner_state_dict[new_key] = v + conditioner_state_dict = new_conditioner_state_dict + + self.net.conditioner.load_state_dict(conditioner_state_dict) + self.net.denoiser.load_state_dict(denoiser_state_dict) + + # check that the buffers saved in self.net are the same than the original ones + for buffer in self.net.named_buffers(): + name, value = buffer + assert (value == ldm_state_dict[name].to(value.device)).all() + ldm_keys.remove(name) + + return ldm_keys + + + class LatentNowcaster(L.LightningModule): """Base class for PyTorch Lightning modules used in nowcasting models. @@ -129,4 +258,4 @@ def on_predict_start(self): def on_test_start(self): # training or inference sampler ??? self._current_sampler_mode = self.training_sampler - super().on_test_start() \ No newline at end of file + super().on_test_start() diff --git a/src/mlcast/models/ldcast/diffusion/scheduler.py b/src/mlcast/models/ldcast/diffusion/scheduler.py new file mode 100644 index 0000000..cd0a607 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/scheduler.py @@ -0,0 +1,39 @@ +from functools import partial +from src.mlcast.models.ldcast.diffusion.utils import make_beta_schedule +import numpy as np +import torch + +class Scheduler(): + def __init__(self, + timesteps = 1000, + beta_schedule = "linear", + linear_start = 1e-4, + linear_end = 2e-2, + cosine_s = 8e-3, + ): + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + def schedule(self, dtype, device): + + betas = make_beta_schedule( + self.beta_schedule, self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, + cosine_s=self.cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + assert alphas_cumprod.shape[0] == self.timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype = dtype, device = device) + + return {'betas': to_torch(betas), + 'alphas_cumprod': to_torch(alphas_cumprod), + 'alphas_cumprod_prev': to_torch(alphas_cumprod_prev), + 'sqrt_alphas_cumprod': to_torch(np.sqrt(alphas_cumprod)), + 'sqrt_one_minus_alphas_cumprod': to_torch(np.sqrt(1. - alphas_cumprod))} \ No newline at end of file diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index 444d6ce..9291866 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -1,37 +1,44 @@ # new file with respect to original code -import abc -from pathlib import Path -from typing import Any - -import numpy as np +from src.mlcast.models.base import NowcastingModelBase import pytorch_lightning as L -import torch -import xarray as xr -from torch import nn +from src.mlcast.models.ldcast.data import LatentDataset +from torch.utils.data import DataLoader -from ..base import NowcastingModelBase, NowcastingLightningModule class LDCast(NowcastingModelBase): - - def __init__(self, autoencoder, latent_nowcaster): - #super().__init__() + def __init__(self, ldm_lightning, autoencoder): + super().__init__() + self.ldm_lightning = ldm_lightning self.autoencoder = autoencoder - self.latent_nowcaster = latent_nowcaster - def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - pass - + def fit(self, dataset): + '''dataset should contains pairs of (inputs, true), with + inputs.shape = (batch_size, 1, 4, 256, 256) + true.shape = (batch_size, 1, 20, 256, 256) + ''' + self.fit_autoencoder(dataset) + self.fit_ldm(dataset) + + def fit_ldm(self, dataset): + self.autoencoder.net.eval() + + latent_dataset = LatentDataset(dataset, self.autoencoder.net) + dataloader = DataLoader(latent_dataset, batch_size=2) + trainer = L.Trainer() + trainer.fit(self.ldm_lightning, dataloader) + def fit_autoencoder(self, dataset): + pass + + def load(self): + pass def predict(self, inputs): - '''inputs is of shape (batch_size, 1, 4) + spatial_shape''' + '''inputs.shape = (batch_size, 1, 4, 256, 256)''' latent_inputs = self.autoencoder.net.encode(inputs) - latent_pred = self.latent_nowcaster(latent_inputs) + latent_pred = self.ldm_lightning(latent_inputs) return self.autoencoder.net.decode(latent_pred) - - def _train_autoencoder(self, da_rr: xr.DataArray, epochs: int, batch_size: int, **kwargs: Any) -> None: - pass - - def _train_latent_nowcaster(self, da_rr: xr.DataArray, num_timesteps: int, epochs: int, batch_size: int, **kwargs: Any) -> None: + + def save(self): pass \ No newline at end of file From 30b0427fd815c9bd49b808a0aac595fdcc3da5ad Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Tue, 24 Feb 2026 14:57:02 +0100 Subject: [PATCH 05/13] Changed the 'Notes' section in 'TO DO', and improved the display of equations (both in README) --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1cf201a..32b2063 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ from src.mlcast.models.ldcast.ldcast import LDCast ldcast = LDCast(ldm_lightning, autoencoder) ``` -# Notes +# TO DO During training, an EMA scope was used for the weights of the denoiser, I removed this for the moment, but it should reincluded in some way. @@ -120,24 +120,32 @@ I have understood that samplers are only used in inference ! The training (and v We might integrate this code within the Hugging Face Diffusers Library. +It remains mainly to write code in the main LDCast class (in `ldcast.py`) + # Basics on diffusion models See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. During training, we start from a sample $x_0$ and create a series of samples $x_0, x_1, ..., x_T$ according to the formula (forward diffusion) + $$ x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon_t, \quad t = 1, ..., T $$ + where $\epsilon_t \sim \mathcal{N}(0, 1)$. The constants $\bar{\alpha}_t$ are chosen, but they need the property that $\bar{\alpha}_t \to 0$ as $t \to T$, so that $x_T \sim \mathcal{N}(0, 1)$. These constants are computed with an algorithm called a scheduler. From a given $x_t$, the model can either be trained to predict $x_0$, $\epsilon_t$ or the velocity $v_t$. The latter is defined as + $$ v_t = \sqrt{\bar{\alpha}_t} \epsilon_t - \sqrt{1-\bar{\alpha}_t} x_0, $$ + which is equivalent to + $$ x_0 = \sqrt{\bar{\alpha}_t} x_t - \sqrt{1-\bar{\alpha}_t} v_t. $$ + The model is also given the timestep $t$. The loss is computed by comparing the target quantity ($\epsilon_t$, $x_0$ or $v_t$) with the predicted quantity by the model. Choosing to predict $x_0$, $\epsilon_t$ or $v_t$ is conceptually equivalent, the difference is in the numerical properties of the scheme (like in ODE integration schemes). The validation and test steps are done in the same way. @@ -154,4 +162,4 @@ Source https://medium.com/@jpark7/finally-a-clear-derivation-of-the-vae-kl-loss- Variational autoencoders encode the data through a normal distribution in latent space: each sample is represented by the mean and the standard deviation of the normal distribution. When decoding the sample, a new sample is created resembling the original sample, but is not quite the same. The degree to which we force the decoded samples to resemble the original ones is tuned by the `kl_weight` parameter of the KL loss function. -When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.decode`. \ No newline at end of file +When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.decode`. From 5d16e1a2e25cb360419120a35c076b26317245a3 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Wed, 25 Feb 2026 18:30:46 +0100 Subject: [PATCH 06/13] changes with respect to previous commit: - I reincluded the ema weights. I changed a little the implementation. First, in the original code, ema weights were used within a python scope. It seems more standard to have an object holding the ema weights and using lightning hooks, apply the ema weights before validation and test steps and before inference, and to restore the model weights after these. The original code was holding the ema weights as buffers to have them saved automatically, but it is simpler (and more standard it seems) to hold them in a dictionary. For that, I changed diffusion/ema.py and diffusion.diffusion.py - I removed the method to load the weights of the denoiser and of the conditioner from the original way they were saved, and put this code in a function in original_weights.py (I think it is cleaner not to have the LatentDiffusion class with this method). In original_weights.py, I added a function to check that the saved buffers are the same than the ones already in ldm. - I added to the LDCast class (ldcast.py) methods to load and save weights from a folder where the weights of the autoencoder, of the conditioner and of the denoiser (and ema weights if any) are stored. --- README.md | 29 ++- .../.ipynb_checkpoints/ldcast-checkpoint.py | 37 --- .../.ipynb_checkpoints/autoenc-checkpoint.py | 68 ------ .../diffusion-checkpoint.py | 132 ----------- .../.ipynb_checkpoints/ema-checkpoint.py | 78 ------- .../models/ldcast/diffusion/diffusion.py | 214 +++--------------- src/mlcast/models/ldcast/diffusion/ema.py | 100 ++++---- src/mlcast/models/ldcast/ldcast.py | 22 +- src/mlcast/models/ldcast/original_weights.py | 72 ++++++ 9 files changed, 180 insertions(+), 572 deletions(-) delete mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py create mode 100644 src/mlcast/models/ldcast/original_weights.py diff --git a/README.md b/README.md index 32b2063..64b6eae 100644 --- a/README.md +++ b/README.md @@ -96,12 +96,25 @@ trainer = L.Trainer() trainer.fit(ldm_lightning, dataloader) ``` -The original weights can not be directly loaded because the models are structured a little differently, but the original weights can be loaded with +The original weights can not be directly loaded because the models are structured a little differently, but the original weights files can be converted with ```python +from src.mlcast.models.ldcast.original_weights import convert_original_weights ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' -unexpected_keys = ldm_lightning.load_original_weights(ldm_weights_fn) +state_dict = convert_original_weights(ldm_weights_fn) +torch.save(state_dict['denoiser_state_dict'], 'denoiser.pt') +torch.save(state_dict['conditioner_state_dict'], 'conditioner.pt') ``` -`unexpected_keys` contains the keys that were not loaded (only the ema weights because I did not take care of the ema scope for the moment) +`state_dict['unmatched']` contains a `dict` with the elements that were not matched (only the ema weights because I did not take care of the ema scope for the moment, and the buffer keys for the scheduling). The weights for the conditioner and the denoiser can then be loaded with +```python +conditioner.load_state_dict(torch.load('conditioner_state_dict.pt')) +denoiser.load_state_dict(torch.load('denoiser_state_dict.pt')) +``` +One can check that the buffers have the same values with +```python +from src.mlcast.models.ldcast.original_weights import check_saved_buffers +unmatched = check_saved_buffers(state_dict['unmatched'], ldm) +``` +Here, `unmatched` contains the element which have not been matched (only the ema weights). # Main LDCast class @@ -109,11 +122,17 @@ unexpected_keys = ldm_lightning.load_original_weights(ldm_weights_fn) from src.mlcast.models.ldcast.ldcast import LDCast ldcast = LDCast(ldm_lightning, autoencoder) ``` +To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): +```python +ldcast.load('/path/to/folder') +``` +To save in a folder: +```python +ldcast.save('/path/to/folder') +``` # TO DO -During training, an EMA scope was used for the weights of the denoiser, I removed this for the moment, but it should reincluded in some way. - The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. I have understood that samplers are only used in inference ! The training (and validation) step is always done by predicting the noise (or a quantity which is related to it by a simple formula). What I called previously the SimpleSampler is actually simply a scheduler (which determines the values of alphas and betas, and add the noise on the latent samples during training) diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py deleted file mode 100644 index 444d6ce..0000000 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/ldcast-checkpoint.py +++ /dev/null @@ -1,37 +0,0 @@ -# new file with respect to original code - -import abc -from pathlib import Path -from typing import Any - -import numpy as np -import pytorch_lightning as L -import torch -import xarray as xr -from torch import nn - -from ..base import NowcastingModelBase, NowcastingLightningModule - -class LDCast(NowcastingModelBase): - - def __init__(self, autoencoder, latent_nowcaster): - #super().__init__() - self.autoencoder = autoencoder - self.latent_nowcaster = latent_nowcaster - - def fit(self, da_rr: xr.DataArray, **kwargs: Any) -> None: - pass - - - def predict(self, inputs): - '''inputs is of shape (batch_size, 1, 4) + spatial_shape''' - latent_inputs = self.autoencoder.net.encode(inputs) - latent_pred = self.latent_nowcaster(latent_inputs) - return self.autoencoder.net.decode(latent_pred) - - def _train_autoencoder(self, da_rr: xr.DataArray, epochs: int, batch_size: int, **kwargs: Any) -> None: - pass - - def _train_latent_nowcaster(self, da_rr: xr.DataArray, num_timesteps: int, epochs: int, batch_size: int, **kwargs: Any) -> None: - pass - \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py deleted file mode 100644 index c7f43a6..0000000 --- a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/autoenc-checkpoint.py +++ /dev/null @@ -1,68 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py - -import pytorch_lightning as pl -import torch -from torch import nn -from .encoder import SimpleConvEncoder, SimpleConvDecoder - -from ..distributions import ( - ensemble_nll_normal, - kl_from_standard_normal, - sample_from_standard_normal, -) - -class autoenc_loss(nn.Module): - def __init__(self, kl_weight = 0.01): - super().__init__() - self.kl_weight = kl_weight - - def forward(self, predictions, y): - (y_pred, mean, log_var) = predictions - - rec_loss = (y - y_pred).abs().mean() - kl_loss = kl_from_standard_normal(mean, log_var) - - total_loss = rec_loss + self.kl_weight * kl_loss - - return {'total_loss': total_loss, 'rec_loss': rec_loss, 'kl_loss': kl_loss} - - -class AutoencoderKLNet(pl.LightningModule): - def __init__( - self, - encoder = SimpleConvEncoder(), - decoder = SimpleConvDecoder(), - kl_weight=0.01, - encoded_channels=64, - hidden_width=32, - **kwargs, - ): - super().__init__(**kwargs) - self.encoder = encoder - self.decoder = decoder - self.hidden_width = hidden_width - self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) - self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) - self.log_var = nn.Parameter(torch.zeros(size=())) - self.kl_weight = kl_weight - - def encode(self, x): - if len(x.shape) < 5: - x = x[None] - h = self.encoder(x) - (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) - return (mean, log_var) - - def decode(self, z): - z = self.to_decoder(z) - dec = self.decoder(z) - return dec - - def forward(self, x, n_timesteps, sample_posterior=True): - (mean, log_var) = self.encode(x) - if sample_posterior: - z = sample_from_standard_normal(mean, log_var) - else: - z = mean - dec = self.decode(z) - return (dec, mean, log_var) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py deleted file mode 100644 index 480a623..0000000 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/diffusion-checkpoint.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import torch.nn as nn -import pytorch_lightning as L -from typing import Any -import contextlib - -print('take care of ema scope, which was used as context manager each exactly when denoiser.forward was called, so it should be a taken care of in the code code about the denoiser or about the diffuser (nothing to do with samplers)') - -import pytorch_lightning as L -class LatentNowcaster(L.LightningModule): - """Base class for PyTorch Lightning modules used in nowcasting models. - - This class provides a standard interface for training and validation - steps, as well as optimizer configuration. - """ - - def __init__( - self, - conditioner: nn.Module, - denoiser: nn.Module, - loss: nn.Module, - training_sampler: nn.Module, - inference_sampler: nn.Module, - optimizer_class: Any | None = None, - optimizer_kwargs: dict | None = None, - **kwargs: Any, - ): - super().__init__() - self.save_hyperparameters(ignore=["denoiser", "conditioner", "training_sampler", "inference_sampler", "loss"]) - self.conditioner = conditioner - self.denoiser = denoiser - self.loss = loss - self.training_sampler = training_sampler - self.inference_sampler = inference_sampler - self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class - - training_sampler.register_schedule(denoiser) - - def infer(self, latent_inputs, num_diffusion_iters = 50, verbose = True): - - condition = self.conditioner(latent_inputs) - - gen_shape = (32, 5, 256//4, 256//4) - batch_size = len(list(condition.values())[0]) - with contextlib.redirect_stdout(None): - (s, intermediates) = self.inference_sampler.sample( - num_diffusion_iters, - batch_size, - gen_shape, - condition, - progbar=verbose - ) - return s - - def model_step(self, latent_batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: - """Generic model step for training or validation. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch - - Returns: - Loss value for the current batch - """ - latent_inputs, latent_targets = latent_batch - - condition = self.conditioner(latent_inputs) - t, noise, latent_target_noisy = self.training_sampler.q_sample(self.denoiser, latent_targets) - guessed_noise = self.denoiser(latent_target_noisy, t, context = condition) - loss = self.loss(guessed_noise, noise) - - if isinstance(loss, dict): - # append step name to loss keys for logging - loss = {f"{step_name}/{k}": v for k, v in loss.items()} - self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - loss = loss.get(f"{step_name}/total_loss", None) - if loss is None: - raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") - else: - self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) - - return loss - - def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: - """Training step for a single batch. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch - - Returns: - Loss value for the current batch - """ - return self.model_step(batch, batch_idx, step_name="train") - - def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: - """Validation step for a single batch. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch - - Returns: - Loss value for the current batch - """ - return self.model_step(batch, batch_idx, step_name="val") - - def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure the optimizer for training. - - Returns: - Optimizer instance to use for training - """ - return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) - - - def on_train_start(self): - self._current_sampler = self.training_sampler - super().on_train_start() - - def on_validation_start(self): - self._current_sampler_mode = self.training_sampler - super().on_validation_start() - - def on_predict_start(self): - self._current_sampler_mode = self.inference_sampler - super().on_predict_start() - - def on_test_start(self): - # training or inference sampler ??? - self._current_sampler_mode = self.training_sampler - super().on_test_start() \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py deleted file mode 100644 index 296c8a3..0000000 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ema-checkpoint.py +++ /dev/null @@ -1,78 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/ema.py - -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') - - self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) - - for name, p in model.named_parameters(): - if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) - - self.collected_params = [] - - def forward(self,model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index aeff27d..e4c8b69 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -6,8 +6,7 @@ from src.mlcast.models.base import NowcastingLightningModule import numpy as np from src.mlcast.models.ldcast.diffusion.utils import extract_into_tensor - -print('take care of ema scope, which was used as context manager each exactly when denoiser.forward was called, so it should be a taken care of in the code code about the denoiser or about the diffuser (nothing to do with samplers)') +from .ema import EMA class LatentDiffusion(nn.Module): def __init__(self, conditioner, denoiser, parametrization = "eps"): @@ -30,25 +29,30 @@ def forward(self, x, n_timesteps = 4): class LatentDiffusionLightning(NowcastingLightningModule): - def __init__(self, ldm, loss, scheduler): + def __init__(self, ldm, loss, scheduler, use_ema = True): super().__init__(ldm, loss) self.scheduler = scheduler # register the schedules (i.e. the values of alpha, beta etc). self.register_schedule() + if use_ema: + self.ema = EMA(self.net.denoiser) + def register_schedule(self): schedule = self.scheduler.schedule(torch.float32, next(self.net.parameters()).device) # check if the ldm has already some saved buffers - saved_buffers = dict(self.net.named_buffers()).keys() - already_saved = [name for name in schedule.keys() if name in saved_buffers] - if len(already_saved) > 0: - raise AttributeError(f'The denoiser has already some saved values for {already_saved}') - - for k, v in schedule.items(): - self.net.register_buffer(k, v) + saved_buffers = dict(self.net.named_buffers()) + already_saved_and_different = [name for name in schedule.keys() + if (name in saved_buffers.keys() and (schedule[name] != saved_buffers[name]).any()) + ] + if len(already_saved_and_different) > 0: + raise AttributeError(f'The denoiser has already some different values for {already_saved}') + + for k in schedule.keys(): + self.net.register_buffer(k, schedule[k]) def training_logic(self, batch, batch_idx): latent_inputs, latent_true = batch @@ -77,185 +81,23 @@ def q_sample(self, x0, noise = None, t = None): return t, noise, x_noisy - def load_original_weights(self, ldm_weights_fn): - ''' - load the weights of the denoiser and the conditioner, from the way they were saved originally - at the moment, the ema scope is not taken into account - returns the weights which were not loaded in one of the two nets (should be only those related to the ema scope) - ''' - ldm_state_dict = torch.load(ldm_weights_fn) - - # track the keys - ldm_keys = list(ldm_state_dict.keys()) - - # remove the weights of the autoencoder - for k in ldm_keys.copy(): - if k.startswith('autoencoder.') or k.startswith('context_encoder.autoencoder.'): - ldm_keys.remove(k) - - # extract the keys of the denoiser (it was called 'model' in the original code) - denoiser_state_dict = {} - for k in ldm_keys.copy(): - if k.startswith('model.'): - new_key = k.replace('model.', '') - denoiser_state_dict[new_key] = ldm_state_dict[k] - ldm_keys.remove(k) - - # extract the keys of the conditioner (it was called 'context_encoder' in the original code) - conditioner_state_dict = {} - for k in ldm_keys.copy(): - if k.startswith('context_encoder.'): - new_key = k.replace('context_encoder.', '') - conditioner_state_dict[new_key] = ldm_state_dict[k] - ldm_keys.remove(k) - - # proj, temporal_transformer and analysis were lists one only element, I simplified this - # the keys have to be adapted - new_conditioner_state_dict = {} - for k, v in conditioner_state_dict.items(): - new_key = k - if k.startswith('proj.0.'): - new_key = k.replace('proj.0.', 'proj.') - if k.startswith('temporal_transformer.0.'): - new_key = k.replace('temporal_transformer.0.', 'temporal_transformer.') - if k.startswith('analysis.0.'): - new_key = k.replace('analysis.0.', 'analysis.') - new_conditioner_state_dict[new_key] = v - conditioner_state_dict = new_conditioner_state_dict - - self.net.conditioner.load_state_dict(conditioner_state_dict) - self.net.denoiser.load_state_dict(denoiser_state_dict) - - # check that the buffers saved in self.net are the same than the original ones - for buffer in self.net.named_buffers(): - name, value = buffer - assert (value == ldm_state_dict[name].to(value.device)).all() - ldm_keys.remove(name) - - return ldm_keys - - - -class LatentNowcaster(L.LightningModule): - """Base class for PyTorch Lightning modules used in nowcasting models. - - This class provides a standard interface for training and validation - steps, as well as optimizer configuration. - """ - - def __init__( - self, - conditioner: nn.Module, - denoiser: nn.Module, - loss: nn.Module, - training_sampler: nn.Module, - inference_sampler: nn.Module, - optimizer_class: Any | None = None, - optimizer_kwargs: dict | None = None, - **kwargs: Any, - ): - super().__init__() - self.save_hyperparameters(ignore=["denoiser", "conditioner", "training_sampler", "inference_sampler", "loss"]) - self.conditioner = conditioner - self.denoiser = denoiser - self.loss = loss - self.training_sampler = training_sampler - self.inference_sampler = inference_sampler - self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class - - training_sampler.register_schedule(denoiser) - - def infer(self, latent_inputs, num_diffusion_iters = 50, verbose = True): - - condition = self.conditioner(latent_inputs) - - gen_shape = (32, 5, 256//4, 256//4) - batch_size = len(list(condition.values())[0]) - with contextlib.redirect_stdout(None): - (s, intermediates) = self.inference_sampler.sample( - num_diffusion_iters, - batch_size, - gen_shape, - condition, - progbar=verbose - ) - return s - - def model_step(self, latent_batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: - """Generic model step for training or validation. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch - - Returns: - Loss value for the current batch - """ - latent_inputs, latent_targets = latent_batch - - condition = self.conditioner(latent_inputs) - t, noise, latent_target_noisy = self.training_sampler.q_sample(self.denoiser, latent_targets) - guessed_noise = self.denoiser(latent_target_noisy, t, context = condition) - loss = self.loss(guessed_noise, noise) - - if isinstance(loss, dict): - # append step name to loss keys for logging - loss = {f"{step_name}/{k}": v for k, v in loss.items()} - self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - loss = loss.get(f"{step_name}/total_loss", None) - if loss is None: - raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") - else: - self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) - - return loss - - def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: - """Training step for a single batch. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.ema.update() - Returns: - Loss value for the current batch - """ - return self.model_step(batch, batch_idx, step_name="train") - - def validation_step(self, batch: Any, batch_idx: int) -> torch.Tensor: - """Validation step for a single batch. - - Args: - batch: Input batch of data - batch_idx: Index of the current batch + def on_validation_start(self): + self.ema.apply_shadow() - Returns: - Loss value for the current batch - """ - return self.model_step(batch, batch_idx, step_name="val") - - def configure_optimizers(self) -> torch.optim.Optimizer: - """Configure the optimizer for training. + def on_validation_end(self): + self.ema.restore() - Returns: - Optimizer instance to use for training - """ - return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) - + def on_test_start(self): + self.ema.apply_shadow() - def on_train_start(self): - self._current_sampler = self.training_sampler - super().on_train_start() - - def on_validation_start(self): - self._current_sampler_mode = self.training_sampler - super().on_validation_start() + def on_test_end(self): + self.ema.restore() def on_predict_start(self): - self._current_sampler_mode = self.inference_sampler - super().on_predict_start() - - def on_test_start(self): - # training or inference sampler ??? - self._current_sampler_mode = self.training_sampler - super().on_test_start() + self.ema.apply_shadow() + + def on_predict_end(self): + self.ema.restore() diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py index 296c8a3..4ccec75 100644 --- a/src/mlcast/models/ldcast/diffusion/ema.py +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -1,78 +1,54 @@ # from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/ema.py +''' +modifications following https://medium.com/@heyamit10/exponential-moving-average-ema-in-pytorch-eb8b6f1718eb +In the original code, EMA was a subclass of nn.Module, in order to register the parameters as buffers and (I guess) to have them saved automatically when saving the model. This made things a bit tricky and messy, because '.'-characters naturally appear in the names of model parameters, while they cannot appear in buffers names... Instead, the EMA weights will have to be saved with torch.save(ema.shadow) and loaded with ema.shadow = torch.load('ema_weights')''' import torch from torch import nn - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): +class EMA(nn.Module): + def __init__(self, model, decay=0.9999, use_num_updates=True): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') - self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) + self.model = model + self.decay = decay + self.shadow = {} # to store EMA weights + self.backup = {} # to store the model weights when we replace them by ema weights + self.num_updates = 0 if use_num_updates else -1 # for dynamical decay - for name, p in model.named_parameters(): - if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) + self.register() + + def register(self): + '''initialize the ema weights with the model weights''' + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.clone() - self.collected_params = [] + def update(self): + '''update the shadow parameters''' - def forward(self,model): + # use dynamical decay if use_num_updates was true in __init__ decay = self.decay - if self.num_updates >= 0: self.num_updates += 1 decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) \ No newline at end of file + + for name, param in self.model.named_parameters(): + if param.requires_grad: + new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] + self.shadow[name] = new_average.clone() + + def apply_shadow(self): + '''apply shadow (EMA) weights to the model''' + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.backup[name] = param.data.clone() + param.data = self.shadow[name] + + def restore(self): + '''restore original model weights from backup''' + for name, param in self.model.named_parameters(): + if param.requires_grad: + param.data = self.backup[name] \ No newline at end of file diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index 9291866..5dd6777 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -4,6 +4,7 @@ import pytorch_lightning as L from src.mlcast.models.ldcast.data import LatentDataset from torch.utils.data import DataLoader +import torch class LDCast(NowcastingModelBase): @@ -31,14 +32,27 @@ def fit_ldm(self, dataset): def fit_autoencoder(self, dataset): pass - def load(self): - pass def predict(self, inputs): '''inputs.shape = (batch_size, 1, 4, 256, 256)''' latent_inputs = self.autoencoder.net.encode(inputs) latent_pred = self.ldm_lightning(latent_inputs) return self.autoencoder.net.decode(latent_pred) - def save(self): - pass + def save(self, folder): + torch.save(self.autoencoder.net.state_dict(), f'{folder}/autoencoder.pt') + torch.save(self.ldm_lightning.net.conditioner.state_dict(), f'{folder}/conditioner.pt') + torch.save(self.ldm_lightning.net.denoiser.state_dict(), f'{folder}/denoiser.pt') + + if hasattr(self.ldm_lightning, 'ema'): + torch.save(self.ldm_lightning.ema.shadow, f'{folder}/ema.pt') + + def load(self, folder): + self.autoencoder.net.load_state_dict(torch.load(f'{folder}/autoencoder.pt')) + self.ldm_lightning.net.conditioner.load_state_dict(torch.load(f'{folder}/conditioner.pt')) + self.ldm_lightning.net.denoiser.load_state_dict(torch.load(f'{folder}/denoiser.pt')) + + if hasattr(self.ldm_lightning, 'ema'): + self.ldm_lightning.ema.shadow = torch.load(f'{folder}/ema.pt') + + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/original_weights.py b/src/mlcast/models/ldcast/original_weights.py new file mode 100644 index 0000000..b54a76c --- /dev/null +++ b/src/mlcast/models/ldcast/original_weights.py @@ -0,0 +1,72 @@ +import torch + +def convert_original_weights(ldm_weights_fn): + ''' + returns the original weights of the denoiser and the conditioner from the way they were saved originally + at the moment, the ema scope is not taken into account + the unmatched_keys are the ema keys and the buffer keys for the schedule (at the moment) + ''' + ldm_state_dict = torch.load(ldm_weights_fn) + + # track unmatched keys + unmatched_keys = list(ldm_state_dict.keys()) + + # remove the weights of the autoencoder + for k in unmatched_keys.copy(): + if k.startswith('autoencoder.') or k.startswith('context_encoder.autoencoder.'): + unmatched_keys.remove(k) + + # extract the keys of the denoiser (it was called 'model' in the original code) + denoiser_state_dict = {} + for k in unmatched_keys.copy(): + if k.startswith('model.'): + new_key = k.replace('model.', '') + denoiser_state_dict[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + + # extract the keys of the conditioner (it was called 'context_encoder' in the original code) + conditioner_state_dict = {} + for k in unmatched_keys.copy(): + if k.startswith('context_encoder.'): + new_key = k.replace('context_encoder.', '') + conditioner_state_dict[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + + # proj, temporal_transformer and analysis were lists one only element, I simplified this + # the keys have to be adapted + new_conditioner_state_dict = {} + for k, v in conditioner_state_dict.items(): + new_key = k + if k.startswith('proj.0.'): + new_key = k.replace('proj.0.', 'proj.') + if k.startswith('temporal_transformer.0.'): + new_key = k.replace('temporal_transformer.0.', 'temporal_transformer.') + if k.startswith('analysis.0.'): + new_key = k.replace('analysis.0.', 'analysis.') + new_conditioner_state_dict[new_key] = v + conditioner_state_dict = new_conditioner_state_dict + + # create dict with unmatched keys + unmatched = {key: ldm_state_dict[key] for key in unmatched_keys} + + return {'denoiser_state_dict': denoiser_state_dict, + 'conditioner_state_dict': conditioner_state_dict, + 'unmatched': unmatched} + +def check_saved_buffers(d, ldm): + ''' + checks that the buffers saved in ldm are the same than the ones in d (which is a dict containing these values) + returns the unmatched elements in d + ''' + + unmatched_keys = list(d.keys()) + + for buffer in ldm.named_buffers(): + name, value = buffer + assert (value == d[name].to(value.device)).all() + unmatched_keys.remove(name) + + # create dict with unmatched keys + unmatched = {key: d[key] for key in unmatched_keys} + + return unmatched \ No newline at end of file From 8994ef6501c3f65510ef3da31a75521357179517 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Fri, 27 Feb 2026 11:05:04 +0100 Subject: [PATCH 07/13] changes with respect to previous commit: small changes. Apart from a few typos generating errors: - the EMA class can save its weights, and the weights can be loaded through methods of the class - change the obsolute imports into relative imports - worked on the LDCast class (in ldcast.py): it can be loaded from a yaml file or from dict containing the config and I implemented a very minimal version of the fit method - I also chnaged the convert_original_weights function in original_weights.py so that it handles all weights (conditioner, denoiser, scheduling buffers and ema weights) --- .gitignore | 4 + README.md | 39 +- original_config.yaml | 21 ++ src/mlcast/models/base.py | 1 + .../.ipynb_checkpoints/context-checkpoint.py | 38 -- .../.ipynb_checkpoints/nowcast-checkpoint.py | 127 ------- src/mlcast/models/ldcast/context/nowcast.py | 1 - .../.ipynb_checkpoints/plms-checkpoint.py | 345 ------------------ .../models/ldcast/diffusion/diffusion.py | 8 +- src/mlcast/models/ldcast/diffusion/ema.py | 19 +- src/mlcast/models/ldcast/diffusion/plms.py | 2 +- .../models/ldcast/diffusion/scheduler.py | 2 +- src/mlcast/models/ldcast/ldcast.py | 88 ++++- src/mlcast/models/ldcast/original_weights.py | 57 ++- 14 files changed, 188 insertions(+), 564 deletions(-) create mode 100644 original_config.yaml delete mode 100644 src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py diff --git a/.gitignore b/.gitignore index 9853d60..a663386 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,7 @@ __pycache__/ build/ dist/ .venv/ +.ipynb_checkpoints/ +src/mlcast/modules/.ipynb_checkpoints/ +src/mlcast/models/ldcast/context/.ipynb_checkpoints/ +src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ diff --git a/README.md b/README.md index 64b6eae..83fae2c 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,9 @@ autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) # Latent diffusion (= conditioner + denoiser) The `LatentDiffusion` class is a `nn.Module` combining the conditioner and the denoiser. ```python +from src.mlcast.models.ldcast.diffusion.unet import UNetModel +from src.mlcast.models.ldcast.context.context import AFNONowcastNetCascade + # setup forecaster conditioner = AFNONowcastNetCascade( 32, @@ -73,6 +76,8 @@ ldm((t, noise, latent_inputs)) ``` The noise has to have the shape true radar images encoded in latent space. +## LatentDiffusionLightning class and training of the ldm + Create fake data to train the ldm: ```python from torch.utils.data import TensorDataset @@ -95,32 +100,38 @@ ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) trainer = L.Trainer() trainer.fit(ldm_lightning, dataloader) ``` - +## Loading the original weights The original weights can not be directly loaded because the models are structured a little differently, but the original weights files can be converted with ```python from src.mlcast.models.ldcast.original_weights import convert_original_weights ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' state_dict = convert_original_weights(ldm_weights_fn) -torch.save(state_dict['denoiser_state_dict'], 'denoiser.pt') -torch.save(state_dict['conditioner_state_dict'], 'conditioner.pt') +torch.save(state_dict['denoiser'], 'denoiser_state_dict.pt') +torch.save(state_dict['conditioner'], 'conditioner_state_dict.pt') +torch.save(state_dict['ema'], 'ema.pt') ``` -`state_dict['unmatched']` contains a `dict` with the elements that were not matched (only the ema weights because I did not take care of the ema scope for the moment, and the buffer keys for the scheduling). The weights for the conditioner and the denoiser can then be loaded with +`state_dict['unmatched']` contains a `dict` with the elements that were not matched (should be empty). The weights for the conditioner and the denoiser (including the buffers for the scheduling) can then be loaded with ```python conditioner.load_state_dict(torch.load('conditioner_state_dict.pt')) denoiser.load_state_dict(torch.load('denoiser_state_dict.pt')) ``` -One can check that the buffers have the same values with +The EMA weights and parameters can be loaded with ```python -from src.mlcast.models.ldcast.original_weights import check_saved_buffers -unmatched = check_saved_buffers(state_dict['unmatched'], ldm) +ldm_lighting.ema.load('ema.pt') ``` -Here, `unmatched` contains the element which have not been matched (only the ema weights). # Main LDCast class ```python from src.mlcast.models.ldcast.ldcast import LDCast -ldcast = LDCast(ldm_lightning, autoencoder) +from src.mlcast.models.ldcast.diffusion.plms import PLMSSampler +sampler = PLMSSampler(denoiser) +ldcast = LDCast(ldm_lightning, autoencoder, sampler) +``` +Predictions can be produced with +```python +inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') +ldcast.predict(inputs) ``` To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): ```python @@ -130,17 +141,23 @@ To save in a folder: ```python ldcast.save('/path/to/folder') ``` +The original config for the conditioner and the autoencoder is in `original_config.yaml`, and can be laoded with: +```python +config = 'original_config' +ldcast = LDCast.from_config(config) +``` +Here, `config` can also be a `dict`. # TO DO The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. -I have understood that samplers are only used in inference ! The training (and validation) step is always done by predicting the noise (or a quantity which is related to it by a simple formula). What I called previously the SimpleSampler is actually simply a scheduler (which determines the values of alphas and betas, and add the noise on the latent samples during training) - We might integrate this code within the Hugging Face Diffusers Library. It remains mainly to write code in the main LDCast class (in `ldcast.py`) +It would be nice to rewrite the PLMS sampler, it is a little messy + # Basics on diffusion models See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. diff --git a/original_config.yaml b/original_config.yaml new file mode 100644 index 0000000..049fc3e --- /dev/null +++ b/original_config.yaml @@ -0,0 +1,21 @@ +autoencoder_hidden_width: &autoencoder_hidden_width 32 + +conditioner: + autoencoder_dim: *autoencoder_hidden_width + output_patches: &output_patches 5 + cascade_depth: 3 + embed_dim: 128 + analysis_depth: 4 + +denoiser: + in_channels: *autoencoder_hidden_width + model_channels: 256 + out_channels: *autoencoder_hidden_width + num_res_blocks: 2 + attention_resolutions: [1, 2] + dims: 3 + channel_mult: [1, 2, 4] + num_heads: 8 + num_timesteps: *output_patches + context_ch: [128, 256, 512] # should be equal to conditioner.cascade_dims ? + \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index fe068fb..1c004c3 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -136,6 +136,7 @@ def training_logic(self, batch, batch_idx): x, y = batch predictions = self.forward(x, n_timesteps = 4) loss = self.loss(predictions, y) + return loss def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: """Generic model step for training or validation. diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py deleted file mode 100644 index caedbb6..0000000 --- a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/context-checkpoint.py +++ /dev/null @@ -1,38 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/analysis.py - -import torch -from torch import nn -from torch.nn import functional as F - -from .nowcast import AFNONowcastNetBase -from ..blocks.resnet import ResBlock3D - - -class AFNONowcastNetCascade(AFNONowcastNetBase): - def __init__(self, *args, cascade_depth=4, **kwargs): - super().__init__(*args, **kwargs) - self.cascade_depth = cascade_depth - self.resnet = nn.ModuleList() - ch = self.embed_dim_out - self.cascade_dims = [ch] - for i in range(cascade_depth-1): - ch_out = 2*ch - self.cascade_dims.append(ch_out) - self.resnet.append( - ResBlock3D(ch, ch_out, kernel_size=(1,3,3), norm=None) - ) - ch = ch_out - - def forward(self, x): - # the past timesteps are needed here, but they are always the same... - # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why - past_timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32).unsqueeze(0).expand(1,-1) - x = super().forward(x, past_timesteps) - img_shape = tuple(x.shape[-2:]) - cascade = {img_shape: x} - for i in range(self.cascade_depth-1): - x = F.avg_pool3d(x, (1,2,2)) - x = self.resnet[i](x) - img_shape = tuple(x.shape[-2:]) - cascade[img_shape] = x - return cascade \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py b/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py deleted file mode 100644 index f13b994..0000000 --- a/src/mlcast/models/ldcast/context/.ipynb_checkpoints/nowcast-checkpoint.py +++ /dev/null @@ -1,127 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/nowcast/nowcast.py, but removed the Nowcaster, AFNONowcastNetBasic and AFNONowcastNet classes because they were not used. Reworked also the two remaining classes (FusionBlock3D and AFNONowcastNetBase) to simplify the code by removing the unused parts. - -import collections - -import torch -from torch import nn -from torch.nn import functional as F -import pytorch_lightning as pl - -from ..blocks.afno import AFNOBlock3d -from ..blocks.attention import positional_encoding, TemporalTransformer - -class FusionBlock3d(nn.Module): - def __init__(self, dim, size_ratio, dim_out=None, afno_fusion=False): - super().__init__() - - if dim_out is None: - dim_out = dim - - if size_ratio == 1: - scale = nn.Identity() - else: - scale = [] - while size_ratio > 1: - scale.append(nn.ConvTranspose3d( - dim[i], dim_out if size_ratio==2 else dim[i], - kernel_size=(1,3,3), stride=(1,2,2), - padding=(0,1,1), output_padding=(0,1,1) - )) - size_ratio //= 2 - scale = nn.Sequential(*scale) - self.scale = scale - - self.afno_fusion = afno_fusion - - if self.afno_fusion: - self.fusion = nn.Identity() - - def resize_proj(self, x): - x = x.permute(0,4,1,2,3) - x = self.scale(x) - x = x.permute(0,2,3,4,1) - return x - - def forward(self, x): - x = self.resize_proj(x) - return x - - -class AFNONowcastNetBase(nn.Module): - def __init__( - self, - autoencoder_dim, - embed_dim=128, - embed_dim_out=None, - analysis_depth=4, - forecast_depth=4, - input_patches=1, - input_size_ratios=1, - output_patches=2, - train_autoenc=False, - afno_fusion=False - ): - super().__init__() - - if embed_dim_out is None: - embed_dim_out = embed_dim - self.embed_dim = embed_dim - self.embed_dim_out = embed_dim_out - self.output_patches = output_patches - - self.proj = nn.Conv3d(autoencoder_dim, embed_dim, kernel_size=1) - - self.analysis = nn.Sequential( - *(AFNOBlock3d(embed_dim) for _ in range(analysis_depth)) - ) - - # temporal transformer - self.use_temporal_transformer = input_patches != output_patches - if self.use_temporal_transformer: - self.temporal_transformer = TemporalTransformer(embed_dim) - - # data fusion - self.fusion = FusionBlock3d(embed_dim, input_size_ratios, - afno_fusion=afno_fusion, dim_out=embed_dim_out) - - # forecast - self.forecast = nn.Sequential( - *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth)) - ) - - def add_pos_enc(self, x, t): - '''not sure the this does what it was supposed to do in the original LDCast code''' - if t.shape[1] != x.shape[1]: - # this can happen if x has been compressed - # by the autoencoder in the time dimension - ds_factor = t.shape[1] // x.shape[1] - t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:] - - pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3)) - return x + pos_enc - - def forward(self, z, timesteps): - '''z is the latent representation of the conditioning and timesteps is contains the timesteps of the input frames (it is [-3, -2, -1, 0])''' - z = self.proj(z) - z = z.permute(0,2,3,4,1) - z = self.analysis(z) - - if self.use_temporal_transformer: - # add positional encoding - z = self.add_pos_enc(z, timesteps) - - # transform to output shape and coordinates - expand_shape = z.shape[:1] + (-1,) + z.shape[2:] - pos_enc_output = positional_encoding( - torch.arange(1,self.output_patches+1, device=z.device), - self.embed_dim, add_dims=(0,2,3) - ) - pe_out = pos_enc_output.expand(*expand_shape) - z = self.temporal_transformer(pe_out, z) - - - # merge inputs - z = self.fusion(z) - # produce prediction - z = self.forecast(z) - return z.permute(0,4,1,2,3) # to channels-first order \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/nowcast.py b/src/mlcast/models/ldcast/context/nowcast.py index f13b994..b066816 100644 --- a/src/mlcast/models/ldcast/context/nowcast.py +++ b/src/mlcast/models/ldcast/context/nowcast.py @@ -58,7 +58,6 @@ def __init__( input_patches=1, input_size_ratios=1, output_patches=2, - train_autoenc=False, afno_fusion=False ): super().__init__() diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py deleted file mode 100644 index 3b733d8..0000000 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/plms-checkpoint.py +++ /dev/null @@ -1,345 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/plms.py, but changed model.apply_model into model.forward - - -""" -From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py -""" - -"""SAMPLING ONLY.""" - -import numpy as np -import torch -from tqdm import tqdm - -from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like - - -class PLMSSampler: - def __init__(self, model, timesteps, schedule="linear", **kwargs): - self.model = model - self.ddpm_num_timesteps = timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - # if type(attr) == torch.Tensor: - # if attr.device != torch.device("cuda"): - # attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule( - self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True - ): - if ddim_eta != 0: - raise ValueError("ddim_eta must be 0 for PLMS") - self.ddim_timesteps = make_ddim_timesteps( - ddim_discr_method=ddim_discretize, - num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps, - verbose=verbose, - ) - alphas_cumprod = self.model.alphas_cumprod - assert ( - alphas_cumprod.shape[0] == self.ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - device = next(self.model.parameters()).device - to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) - - self.register_buffer("betas", to_torch(self.model.betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer( - "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) - ) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer( - "sqrt_alphas_cumprod", to_torch(torch.sqrt(alphas_cumprod)) - ) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", - to_torch(torch.sqrt(1.0 - alphas_cumprod)), - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(torch.log(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(torch.sqrt(1.0 / alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", - to_torch(torch.sqrt(1.0 / alphas_cumprod - 1)), - ) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod, - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta, - verbose=verbose, - ) - self.register_buffer("ddim_sigmas", ddim_sigmas) - self.register_buffer("ddim_alphas", ddim_alphas) - self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) - self.register_buffer( - "ddim_sqrt_one_minus_alphas", torch.sqrt(1.0 - ddim_alphas) - ) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) - / (1 - self.alphas_cumprod) - * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) - ) - self.register_buffer( - "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps - ) - - @torch.no_grad() - def sample( - self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, - x0=None, - temperature=1.0, - noise_dropout=0.0, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - progbar=True, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs, - ): - """ - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - """ - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - size = (batch_size,) + shape - print(f"Data shape for PLMS sampling is {size}") - - samples, intermediates = self.plms_sampling( - conditioning, - size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - progbar=progbar, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling( - self, - cond, - shape, - x_T=None, - ddim_use_original_steps=False, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - log_every_t=100, - temperature=1.0, - noise_dropout=0.0, - score_corrector=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - progbar=True, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = ( - self.ddpm_num_timesteps - if ddim_use_original_steps - else self.ddim_timesteps - ) - elif timesteps is not None and not ddim_use_original_steps: - subset_end = ( - int( - min(timesteps / self.ddim_timesteps.shape[0], 1) - * self.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {"x_inter": [img], "pred_x0": [img]} - time_range = ( - list(reversed(range(0, timesteps))) - if ddim_use_original_steps - else np.flip(timesteps) - ) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = time_range - if progbar: - iterator = tqdm(iterator, desc="PLMS Sampler", total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full( - (b,), - time_range[min(i + 1, len(time_range) - 1)], - device=device, - dtype=torch.long, - ) - - outs = self.p_sample_plms( - img, - cond, - ts, - index=index, - use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, - temperature=temperature, - noise_dropout=noise_dropout, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, - t_next=ts_next, - ) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates["x_inter"].append(img) - intermediates["pred_x0"].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms( - self, - x, - condition, - t, - index, - repeat_noise=False, - use_original_steps=False, - quantize_denoised=False, - temperature=1.0, - noise_dropout=0.0, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - old_eps=None, - t_next=None, - ): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 - ): - e_t = self.model(x, t, condition) - else: - pass - '''is never used - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, condition]) - e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - ''' - - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = ( - self.model.alphas_cumprod_prev - if use_original_steps - else self.ddim_alphas_prev - ) - sqrt_one_minus_alphas = ( - self.model.sqrt_one_minus_alphas_cumprod - if use_original_steps - else self.ddim_sqrt_one_minus_alphas - ) - sigmas = ( - self.model.ddim_sigmas_for_original_num_steps - if use_original_steps - else self.ddim_sigmas - ) - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - param_shape = (b,) + (1,) * (x.ndim - 1) - a_t = torch.full(param_shape, alphas[index], device=device) - a_prev = torch.full(param_shape, alphas_prev[index], device=device) - sigma_t = torch.full(param_shape, sigmas[index], device=device) - sqrt_one_minus_at = torch.full( - param_shape, sqrt_one_minus_alphas[index], device=device - ) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.0: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = ( - 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] - ) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index e4c8b69..08c429e 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -3,9 +3,9 @@ import pytorch_lightning as L from typing import Any import contextlib -from src.mlcast.models.base import NowcastingLightningModule +from ...base import NowcastingLightningModule import numpy as np -from src.mlcast.models.ldcast.diffusion.utils import extract_into_tensor +from .utils import extract_into_tensor from .ema import EMA class LatentDiffusion(nn.Module): @@ -49,10 +49,10 @@ def register_schedule(self): if (name in saved_buffers.keys() and (schedule[name] != saved_buffers[name]).any()) ] if len(already_saved_and_different) > 0: - raise AttributeError(f'The denoiser has already some different values for {already_saved}') + raise AttributeError(f'The denoiser has already some different values for {already_saved_and_different}') for k in schedule.keys(): - self.net.register_buffer(k, schedule[k]) + self.net.denoiser.register_buffer(k, schedule[k]) def training_logic(self, batch, batch_idx): latent_inputs, latent_true = batch diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py index 4ccec75..a76ec51 100644 --- a/src/mlcast/models/ldcast/diffusion/ema.py +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -6,7 +6,7 @@ import torch from torch import nn -class EMA(nn.Module): +class EMA(): def __init__(self, model, decay=0.9999, use_num_updates=True): super().__init__() if decay < 0.0 or decay > 1.0: @@ -51,4 +51,19 @@ def restore(self): '''restore original model weights from backup''' for name, param in self.model.named_parameters(): if param.requires_grad: - param.data = self.backup[name] \ No newline at end of file + param.data = self.backup[name] + + def load(self, filename): + '''load the ema (shadow) weights parameters''' + self.shadow = torch.load(filename) + self.decay = self.shadow.pop('decay') + self.num_updates = self.shadow.pop('num_updates') + + def save(self, filename): + '''save the ema (shadow) weights parameters''' + self.shadow['decay'] = self.decay + self.shadow['num_updates'] = self.num_updates + torch.save(self.shadow, filename) + + self.shadow.pop('decay') + self.shadow.pop('num_updates') \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/plms.py b/src/mlcast/models/ldcast/diffusion/plms.py index 3b733d8..ce27e0c 100644 --- a/src/mlcast/models/ldcast/diffusion/plms.py +++ b/src/mlcast/models/ldcast/diffusion/plms.py @@ -15,7 +15,7 @@ class PLMSSampler: - def __init__(self, model, timesteps, schedule="linear", **kwargs): + def __init__(self, model, timesteps = 1000, schedule = "linear", **kwargs): self.model = model self.ddpm_num_timesteps = timesteps self.schedule = schedule diff --git a/src/mlcast/models/ldcast/diffusion/scheduler.py b/src/mlcast/models/ldcast/diffusion/scheduler.py index cd0a607..95cb1a3 100644 --- a/src/mlcast/models/ldcast/diffusion/scheduler.py +++ b/src/mlcast/models/ldcast/diffusion/scheduler.py @@ -1,5 +1,5 @@ from functools import partial -from src.mlcast.models.ldcast.diffusion.utils import make_beta_schedule +from .utils import make_beta_schedule import numpy as np import torch diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index 5dd6777..10065d9 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -1,40 +1,71 @@ # new file with respect to original code -from src.mlcast.models.base import NowcastingModelBase +from ..base import NowcastingModelBase import pytorch_lightning as L -from src.mlcast.models.ldcast.data import LatentDataset +from .data import LatentDataset from torch.utils.data import DataLoader import torch +import contextlib +from torch.utils.data import TensorDataset class LDCast(NowcastingModelBase): - def __init__(self, ldm_lightning, autoencoder): + def __init__(self, ldm_lightning, autoencoder, sampler): super().__init__() self.ldm_lightning = ldm_lightning self.autoencoder = autoencoder + self.sampler = sampler - def fit(self, dataset): + def fit(self, inputs, true, batch_size, max_epochs): '''dataset should contains pairs of (inputs, true), with inputs.shape = (batch_size, 1, 4, 256, 256) true.shape = (batch_size, 1, 20, 256, 256) ''' - self.fit_autoencoder(dataset) - self.fit_ldm(dataset) + print('Training autoencoder') + self.fit_autoencoder(inputs, batch_size, max_epochs) - def fit_ldm(self, dataset): + print('Training ldm') + self.fit_ldm(inputs, true, batch_size, max_epochs) + + def fit_ldm(self, inputs, true, batch_size, max_epochs): self.autoencoder.net.eval() + self.ldm_lightning.net.train() + dataset = TensorDataset(inputs, true) latent_dataset = LatentDataset(dataset, self.autoencoder.net) - dataloader = DataLoader(latent_dataset, batch_size=2) - trainer = L.Trainer() + dataloader = DataLoader(latent_dataset, batch_size = batch_size) + trainer = L.Trainer(max_epochs = max_epochs) trainer.fit(self.ldm_lightning, dataloader) - def fit_autoencoder(self, dataset): - pass + def fit_autoencoder(self, inputs, batch_size, max_epochs): + self.autoencoder.net.train() + + dataset = TensorDataset(inputs, inputs) + dataloader = DataLoader(dataset, batch_size = batch_size) + trainer = L.Trainer(max_epochs = max_epochs) + trainer.fit(self.autoencoder, dataloader) - def predict(self, inputs): + def predict(self, inputs, num_diffusion_iters = 50, verbose = True): '''inputs.shape = (batch_size, 1, 4, 256, 256)''' + + assert False, 'prediction should be implemented with a trainer, to take into account the switches of ema weights for example''' + latent_inputs = self.autoencoder.net.encode(inputs) + condition = self.ldm_lightning.net.conditioner(latent_inputs) + + gen_shape = (32, 5, 256//4, 256//4) + batch_size = len(latent_inputs) + + with contextlib.redirect_stdout(None): + (s, intermediates) = self.sampler.sample( + num_diffusion_iters, + batch_size, + gen_shape, + condition, + progbar = verbose) + + return s + latent_pred = self.ldm_lightning(latent_inputs) return self.autoencoder.net.decode(latent_pred) @@ -44,7 +75,7 @@ def save(self, folder): torch.save(self.ldm_lightning.net.denoiser.state_dict(), f'{folder}/denoiser.pt') if hasattr(self.ldm_lightning, 'ema'): - torch.save(self.ldm_lightning.ema.shadow, f'{folder}/ema.pt') + self.ldm_lightning.ema.save(f'{folder}/ema.pt') def load(self, folder): self.autoencoder.net.load_state_dict(torch.load(f'{folder}/autoencoder.pt')) @@ -52,7 +83,36 @@ def load(self, folder): self.ldm_lightning.net.denoiser.load_state_dict(torch.load(f'{folder}/denoiser.pt')) if hasattr(self.ldm_lightning, 'ema'): - self.ldm_lightning.ema.shadow = torch.load(f'{folder}/ema.pt') + self.ldm_lightning.ema.load(f'{folder}/ema.pt') + + @classmethod + def from_config(cls, config): + + if isinstance(config, str): + import yaml + with open(config, 'r') as file: + config = yaml.safe_load(file) + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + from .autoenc.autoenc import AutoencoderKLNet, autoenc_loss + from ..base import NowcastingLightningModule + from .diffusion.unet import UNetModel + from .context.context import AFNONowcastNetCascade + from .diffusion.diffusion import LatentDiffusion, LatentDiffusionLightning + from torch.nn import L1Loss + from .diffusion.scheduler import Scheduler + from .diffusion.plms import PLMSSampler + + autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to(device) + conditioner = AFNONowcastNetCascade(**config['conditioner']).to(device) + denoiser = UNetModel(**config['denoiser']).to(device) + ldm = LatentDiffusion(conditioner, denoiser) + ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) + sampler = PLMSSampler(denoiser) + + return cls(ldm_lightning, autoencoder, sampler) + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/original_weights.py b/src/mlcast/models/ldcast/original_weights.py index b54a76c..38ddb7a 100644 --- a/src/mlcast/models/ldcast/original_weights.py +++ b/src/mlcast/models/ldcast/original_weights.py @@ -1,4 +1,5 @@ import torch +import re def convert_original_weights(ldm_weights_fn): ''' @@ -23,6 +24,12 @@ def convert_original_weights(ldm_weights_fn): new_key = k.replace('model.', '') denoiser_state_dict[new_key] = ldm_state_dict[k] unmatched_keys.remove(k) + + denoiser_buffers_keys = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod'] + for k in unmatched_keys.copy(): + if k in denoiser_buffers_keys: + denoiser_state_dict[k] = ldm_state_dict[k] + unmatched_keys.remove(k) # extract the keys of the conditioner (it was called 'context_encoder' in the original code) conditioner_state_dict = {} @@ -32,7 +39,7 @@ def convert_original_weights(ldm_weights_fn): conditioner_state_dict[new_key] = ldm_state_dict[k] unmatched_keys.remove(k) - # proj, temporal_transformer and analysis were lists one only element, I simplified this + # proj, temporal_transformer and analysis were lists with one only element, I simplified this # the keys have to be adapted new_conditioner_state_dict = {} for k, v in conditioner_state_dict.items(): @@ -46,27 +53,37 @@ def convert_original_weights(ldm_weights_fn): new_conditioner_state_dict[new_key] = v conditioner_state_dict = new_conditioner_state_dict + ema = {} + for k in unmatched_keys.copy(): + if k.startswith('model_ema.'): + new_key = restore_name(k.replace('model_ema.', '')) + ema[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + # create dict with unmatched keys unmatched = {key: ldm_state_dict[key] for key in unmatched_keys} - return {'denoiser_state_dict': denoiser_state_dict, - 'conditioner_state_dict': conditioner_state_dict, + return {'denoiser': denoiser_state_dict, + 'conditioner': conditioner_state_dict, + 'ema': ema, 'unmatched': unmatched} -def check_saved_buffers(d, ldm): - ''' - checks that the buffers saved in ldm are the same than the ones in d (which is a dict containing these values) - returns the unmatched elements in d - ''' - - unmatched_keys = list(d.keys()) - - for buffer in ldm.named_buffers(): - name, value = buffer - assert (value == d[name].to(value.device)).all() - unmatched_keys.remove(name) - - # create dict with unmatched keys - unmatched = {key: d[key] for key in unmatched_keys} - - return unmatched \ No newline at end of file +def restore_name(s): + '''for the EMA, all the dots were removed from the parameters names in the original code, so they should be added again to match during swapping''' + # add dots before and after every digit + res = re.sub(r'(\d)', r'.\1.', s) + # if the digit was in 'fc1', 'fc2', it should not be preceded by a dot + res = res.replace('fc.1', 'fc1') + res = res.replace('fc.2', 'fc2') + # same, but there should be in addition a dot before w1, w2, b1 and b2 + res = res.replace('w.1.', '.w1') + res = res.replace('w.2.', '.w2') + res = res.replace('b.1.', '.b1') + res = res.replace('b.2.', '.b2') + # add a dot before each 'weights' and each 'bias' + res = res.replace('weight', '.weight').replace('bias', '.bias') + # add dot after mlp + res = res.replace('mlp', 'mlp.') + # if two dots are inserted, replace them by one (happens if two digits follow each other, or if a digit is followed by b or w) + res = res.replace('..', '.') + return res \ No newline at end of file From 63e70f0a563802467d567a8660897fd3b62fd99c Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Mon, 2 Mar 2026 14:45:34 +0100 Subject: [PATCH 08/13] changes with respect to previous commit: small changes. Apart from a few typos generating errors: - diffusion/diffusion.py: I took into account the fact that ema might not be used in lightning hooks (on_train_batch_end, etc.) and updated the way the EMA config is passed to the EMA class - diffusion/ema.py: I made sure that the gradient graphs are not kept when the weights are stored in self.backup and self.shadow, and added the possibility to store on CPU the weights which are not currently on the model through the store_device keyword - data.py: I added the code to construct a dataset to train the autoencoder (AutoencoderDataset), and to construct a dataset to train the ldm (LatentDataset) from a sampled radar dataset. I added a DataModule class - ldcast.py: I mainly wrote the fit_autoencoder and fit_ldm methods taking as argument a sampled radar dataset and using the AutoencoderDataset and LatentDataset classes --- README.md | 22 ++--- original_config.yaml | 5 ++ src/mlcast/models/ldcast/autoenc/autoenc.py | 3 + src/mlcast/models/ldcast/data.py | 88 ++++++++++++++++--- .../models/ldcast/diffusion/diffusion.py | 35 +++++--- src/mlcast/models/ldcast/diffusion/ema.py | 26 +++--- src/mlcast/models/ldcast/ldcast.py | 31 ++++--- 7 files changed, 145 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 83fae2c..a0d9ce4 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,8 @@ Here, 4 consecutive radar images are encoded at once. # Autoencoder ```python -from src.mlcast.models.ldcast.autoenc.autoenc import AutoencoderKLNet, autoenc_loss -from src.mlcast.models.base import NowcastingLightningModule +from mlcast.models.ldcast.autoenc.autoenc import AutoencoderKLNet, autoenc_loss +from mlcast.models.base import NowcastingLightningModule autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to('cuda') ``` The autoencoder is an instance of the NowcastingLightningModule. Training the autoencoder: @@ -41,8 +41,8 @@ autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) # Latent diffusion (= conditioner + denoiser) The `LatentDiffusion` class is a `nn.Module` combining the conditioner and the denoiser. ```python -from src.mlcast.models.ldcast.diffusion.unet import UNetModel -from src.mlcast.models.ldcast.context.context import AFNONowcastNetCascade +from mlcast.models.ldcast.diffusion.unet import UNetModel +from mlcast.models.ldcast.context.context import AFNONowcastNetCascade # setup forecaster conditioner = AFNONowcastNetCascade( @@ -55,7 +55,7 @@ conditioner = AFNONowcastNetCascade( ).to('cuda') # setup denoiser -from src.mlcast.models.ldcast.diffusion.unet import UNetModel +from mlcast.models.ldcast.diffusion.unet import UNetModel denoiser = UNetModel(in_channels=autoencoder.net.hidden_width, model_channels=256, out_channels=autoencoder.net.hidden_width, num_res_blocks=2, attention_resolutions=(1,2), @@ -64,7 +64,7 @@ denoiser = UNetModel(in_channels=autoencoder.net.hidden_width, context_ch=[128, 256, 512] # context channels (= analysis_net.cascade_dims) ).to('cuda') -from src.mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion +from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion ldm = LatentDiffusion(conditioner, denoiser) ``` The `LatentDiffusion` class has a forward pass: it takes the noise, the timesteps of the diffusion and the encoded inputs @@ -93,8 +93,8 @@ dataloader = DataLoader(latent_dataset, batch_size=2) Put `ldm` in a `LightningModule` and train: ```python from torch.nn import L1Loss -from src.mlcast.models.ldcast.diffusion.scheduler import Scheduler -from src.mlcast.models.ldcast.diffusion.diffusion import LatentDiffusionLightning +from mlcast.models.ldcast.diffusion.scheduler import Scheduler +from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusionLightning ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) trainer = L.Trainer() @@ -103,7 +103,7 @@ trainer.fit(ldm_lightning, dataloader) ## Loading the original weights The original weights can not be directly loaded because the models are structured a little differently, but the original weights files can be converted with ```python -from src.mlcast.models.ldcast.original_weights import convert_original_weights +from mlcast.models.ldcast.original_weights import convert_original_weights ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' state_dict = convert_original_weights(ldm_weights_fn) torch.save(state_dict['denoiser'], 'denoiser_state_dict.pt') @@ -123,8 +123,8 @@ ldm_lighting.ema.load('ema.pt') # Main LDCast class ```python -from src.mlcast.models.ldcast.ldcast import LDCast -from src.mlcast.models.ldcast.diffusion.plms import PLMSSampler +from mlcast.models.ldcast.ldcast import LDCast +from mlcast.models.ldcast.diffusion.plms import PLMSSampler sampler = PLMSSampler(denoiser) ldcast = LDCast(ldm_lightning, autoencoder, sampler) ``` diff --git a/original_config.yaml b/original_config.yaml index 049fc3e..8807382 100644 --- a/original_config.yaml +++ b/original_config.yaml @@ -18,4 +18,9 @@ denoiser: num_heads: 8 num_timesteps: *output_patches context_ch: [128, 256, 512] # should be equal to conditioner.cascade_dims ? + +ema: + use: True + kwargs: + store_device: 'cpu' \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py index 5d5a798..d673a2a 100644 --- a/src/mlcast/models/ldcast/autoenc/autoenc.py +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -54,6 +54,9 @@ def encode(self, x, return_log_var = False): if return_log_var: return (mean, log_var) else: + # if the first axis has length 1, it is the batch dimension and should be removed + if mean.shape[0] == 1: + mean = mean[0] return mean def decode(self, z): diff --git a/src/mlcast/models/ldcast/data.py b/src/mlcast/models/ldcast/data.py index edc6773..447d853 100644 --- a/src/mlcast/models/ldcast/data.py +++ b/src/mlcast/models/ldcast/data.py @@ -1,12 +1,15 @@ -from torch.utils.data import Dataset +from torch.utils.data import Dataset, random_split, DataLoader import torch +import pytorch_lightning as pl class LatentDataset(Dataset): - def __init__(self, dataset, autoencoder): + def __init__(self, sampled_radar_dataset, autoencoder, autoenc_time_ratio = 4): super().__init__() self.autoencoder = autoencoder - self.dataset = dataset + self.dataset = sampled_radar_dataset + self.autoenc_time_ratio = autoenc_time_ratio + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' def __len__(self): return len(self.dataset) @@ -14,15 +17,72 @@ def __len__(self): def __getitem__(self, idx): with torch.no_grad(): - inputs, true = self.dataset[idx] - latent_inputs = self.autoencoder.encode(inputs) - latent_true = self.autoencoder.encode(true) - - # until here, the first dimension of latent_inputs and latent_true is the 'batch dimension' - # if idx is a list, keep this batch dimension along this list - # if idx is not a list, this batch dimension is 1 and needs to be removed because the dataloader will repeatedly call __getitem__ and add an extra dimension for the batch dimension - if not isinstance(idx, list): - latent_inputs = latent_inputs[0] - latent_true = latent_true[0] + sequence = self.dataset[idx]['data'] + x = sequence[:self.autoenc_time_ratio] + y = sequence[self.autoenc_time_ratio:] + + # for some reason, Gabriele put the time axis before the channel axis, change this + x = x.swapaxes(0, 1).to(self.device) + y = y.swapaxes(0, 1).to(self.device) + + latent_x = self.autoencoder.encode(x) + latent_y = self.autoencoder.encode(y) + + return latent_x, latent_y + +class AutoencoderDataset(Dataset): + ''' + shape of one sample of sampled_radar_dataset = (1, 24, 1,) + spatial_shape + But, for the LDCast autoencoder, we want to have samples of (1, 4, 1,) + spatial_shape + So 1 sample of sampled_radar_dataset is partitioned in 6 samples for the autoencoder + ''' + def __init__(self, sampled_radar_dataset, autoenc_time_ratio = 4): + super().__init__() + self.srd = sampled_radar_dataset + self.autoenc_time_ratio = autoenc_time_ratio + self.samples_ratio = int(self.srd.steps / self.autoenc_time_ratio) # is 6 in the usual case where steps = 24 and autoenc_time_ratio = 4 + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + + def __len__(self): + return self.samples_ratio * len(self.srd) + + def __getitem__(self, idx): + ''' + when given idx between 0 and 6 * len(srd) - 1, one has first to find in which sample of srd we are (the index of this sample is index_srd) + then, within this sample, one has to find in which partition of this sample we are (this is given by index_in_srd_sample) + ''' + index_srd = idx // self.samples_ratio + index_in_srd_sample = idx - index_srd * self.samples_ratio + x = self.srd[index_srd]['data'].reshape(self.samples_ratio, self.autoenc_time_ratio, 1, self.srd.w, self.srd.h)[index_in_srd_sample] + + # for some reason, Gabriele put the time axis before the channel axis, change this + x = x.swapaxes(0, 1).to(self.device) + + # for the autoencoder, y is equal to x + y = x + + return x, y + +class DataModule(pl.LightningDataModule): + def __init__(self, dataset, train_ratio = 0.6, val_ratio = 0.2, **dataloader_kwargs): + super().__init__() + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = 1 - self.train_ratio - self.val_ratio + + train_ds, val_ds, test_ds = random_split(dataset, [self.train_ratio, self.val_ratio, self.test_ratio]) + self.train_dataset = train_ds + self.val_dataset = val_ds + self.test_dataset = test_ds + + self.dataloader_kwargs = dataloader_kwargs + + def train_dataloader(self): + return DataLoader(self.train_dataset, shuffle = True, **self.dataloader_kwargs) + + def val_dataloader(self): + return DataLoader(self.val_dataset, shuffle = False, **self.dataloader_kwargs) - return (latent_inputs, latent_true) \ No newline at end of file + def test_dataloader(self): + return DataLoader(self.test_dataset, shuffle = False, **self.dataloader_kwargs) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index 08c429e..fa18de6 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -29,16 +29,16 @@ def forward(self, x, n_timesteps = 4): class LatentDiffusionLightning(NowcastingLightningModule): - def __init__(self, ldm, loss, scheduler, use_ema = True): + def __init__(self, ldm, loss, scheduler, ema_config = {'use': True}): super().__init__(ldm, loss) self.scheduler = scheduler # register the schedules (i.e. the values of alpha, beta etc). self.register_schedule() - if use_ema: - self.ema = EMA(self.net.denoiser) - + if ema_config['use']: + self.ema = EMA(self.net.denoiser, **ema_config['kwargs']) + def register_schedule(self): schedule = self.scheduler.schedule(torch.float32, next(self.net.parameters()).device) @@ -76,28 +76,35 @@ def q_sample(self, x0, noise = None, t = None): if t is None: t = torch.randint(0, self.scheduler.timesteps, (x0.shape[0],), device=x0.device).long() - x_noisy = extract_into_tensor(self.net.sqrt_alphas_cumprod, t, x0.shape) * x0 + \ - extract_into_tensor(self.net.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + x_noisy = extract_into_tensor(self.net.denoiser.sqrt_alphas_cumprod, t, x0.shape) * x0 + \ + extract_into_tensor(self.net.denoiser.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise return t, noise, x_noisy - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.ema.update() + def on_train_batch_end(self, outputs, batch, batch_idx): + if hasattr(self, 'ema'): + self.ema.update() def on_validation_start(self): - self.ema.apply_shadow() + if hasattr(self, 'ema'): + self.ema.apply_shadow() def on_validation_end(self): - self.ema.restore() + if hasattr(self, 'ema'): + self.ema.restore() def on_test_start(self): - self.ema.apply_shadow() + if hasattr(self, 'ema'): + self.ema.apply_shadow() def on_test_end(self): - self.ema.restore() + if hasattr(self, 'ema'): + self.ema.restore() def on_predict_start(self): - self.ema.apply_shadow() + if hasattr(self, 'ema'): + self.ema.apply_shadow() def on_predict_end(self): - self.ema.restore() + if hasattr(self, 'ema'): + self.ema.restore() diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py index a76ec51..af13022 100644 --- a/src/mlcast/models/ldcast/diffusion/ema.py +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -7,16 +7,18 @@ from torch import nn class EMA(): - def __init__(self, model, decay=0.9999, use_num_updates=True): + def __init__(self, model, decay = 0.9999, use_num_updates = True, store_device = 'cuda'): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.model = model self.decay = decay - self.shadow = {} # to store EMA weights - self.backup = {} # to store the model weights when we replace them by ema weights - self.num_updates = 0 if use_num_updates else -1 # for dynamical decay + self.shadow = {} # to store EMA weights + self.backup = {} # to store the model weights when we replace them by ema weights + self.num_updates = 0 if use_num_updates else -1 # for dynamical decay + self.store_device = store_device # device on which to store the weights + self.model_device = next(model.parameters()).device # device on which the weights are used self.register() @@ -24,7 +26,7 @@ def register(self): '''initialize the ema weights with the model weights''' for name, param in self.model.named_parameters(): if param.requires_grad: - self.shadow[name] = param.data.clone() + self.shadow[name] = param.data.clone().detach().to(self.store_device) def update(self): '''update the shadow parameters''' @@ -37,21 +39,21 @@ def update(self): for name, param in self.model.named_parameters(): if param.requires_grad: - new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] - self.shadow[name] = new_average.clone() + new_average = (1.0 - self.decay) * param.data.detach().to(self.store_device) + self.decay * self.shadow[name] + self.shadow[name] = new_average def apply_shadow(self): '''apply shadow (EMA) weights to the model''' for name, param in self.model.named_parameters(): if param.requires_grad: - self.backup[name] = param.data.clone() - param.data = self.shadow[name] + self.backup[name] = param.data.clone().detach().to(self.store_device) + param.data = self.shadow[name].to(self.model_device) def restore(self): '''restore original model weights from backup''' for name, param in self.model.named_parameters(): if param.requires_grad: - param.data = self.backup[name] + param.data = self.backup[name].to(self.model_device) def load(self, filename): '''load the ema (shadow) weights parameters''' @@ -59,6 +61,10 @@ def load(self, filename): self.decay = self.shadow.pop('decay') self.num_updates = self.shadow.pop('num_updates') + # put the shadow tensors on the correct device + for k in self.shadow.keys(): + self.shadow[k] = self.shadow[k].to(self.store_device) + def save(self, filename): '''save the ema (shadow) weights parameters''' self.shadow['decay'] = self.decay diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index 10065d9..8abccf6 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -2,7 +2,7 @@ from ..base import NowcastingModelBase import pytorch_lightning as L -from .data import LatentDataset +from .data import LatentDataset, AutoencoderDataset, DataModule from torch.utils.data import DataLoader import torch import contextlib @@ -16,34 +16,33 @@ def __init__(self, ldm_lightning, autoencoder, sampler): self.autoencoder = autoencoder self.sampler = sampler - def fit(self, inputs, true, batch_size, max_epochs): + def fit(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): '''dataset should contains pairs of (inputs, true), with inputs.shape = (batch_size, 1, 4, 256, 256) true.shape = (batch_size, 1, 20, 256, 256) ''' print('Training autoencoder') - self.fit_autoencoder(inputs, batch_size, max_epochs) + self.fit_autoencoder(sampled_radar_dataset, dataloader_kwargs = dataloader_kwargs, trainer_kwargs = trainer_kwargs) print('Training ldm') - self.fit_ldm(inputs, true, batch_size, max_epochs) + self.fit_ldm(sampled_radar_dataset, dataloader_kwargs = dataloader_kwargs, trainer_kwargs = trainer_kwargs) - def fit_ldm(self, inputs, true, batch_size, max_epochs): + def fit_ldm(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): self.autoencoder.net.eval() self.ldm_lightning.net.train() - dataset = TensorDataset(inputs, true) - latent_dataset = LatentDataset(dataset, self.autoencoder.net) - dataloader = DataLoader(latent_dataset, batch_size = batch_size) - trainer = L.Trainer(max_epochs = max_epochs) - trainer.fit(self.ldm_lightning, dataloader) + dataset = LatentDataset(sampled_radar_dataset, self.autoencoder.net) + datamodule = DataModule(dataset, **dataloader_kwargs) + trainer = L.Trainer(**trainer_kwargs) + trainer.fit(self.ldm_lightning, datamodule) - def fit_autoencoder(self, inputs, batch_size, max_epochs): + def fit_autoencoder(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): self.autoencoder.net.train() - dataset = TensorDataset(inputs, inputs) - dataloader = DataLoader(dataset, batch_size = batch_size) - trainer = L.Trainer(max_epochs = max_epochs) - trainer.fit(self.autoencoder, dataloader) + dataset = AutoencoderDataset(sampled_radar_dataset) + datamodule = DataModule(dataset, **dataloader_kwargs) + trainer = L.Trainer(**trainer_kwargs) + trainer.fit(self.autoencoder, datamodule) def predict(self, inputs, num_diffusion_iters = 50, verbose = True): '''inputs.shape = (batch_size, 1, 4, 256, 256)''' @@ -108,7 +107,7 @@ def from_config(cls, config): conditioner = AFNONowcastNetCascade(**config['conditioner']).to(device) denoiser = UNetModel(**config['denoiser']).to(device) ldm = LatentDiffusion(conditioner, denoiser) - ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) + ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler(), ema_config = config['ema']) sampler = PLMSSampler(denoiser) return cls(ldm_lightning, autoencoder, sampler) From 7c62c2dce27cae0151696f14f5b20bee25384549 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Mon, 2 Mar 2026 15:01:14 +0100 Subject: [PATCH 09/13] Forgot to update the README section on the main LDCast class and the fit_autoencoder and fit_ldm methods --- README.md | 71 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index a0d9ce4..9fc9d51 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,49 @@ autoenc_time_ratio = 4 # number of timesteps encoded in the autoencoder ``` Here, 4 consecutive radar images are encoded at once. +# Main LDCast class + +```python +from mlcast.models.ldcast.ldcast import LDCast +from mlcast.models.ldcast.diffusion.plms import PLMSSampler +sampler = PLMSSampler(denoiser) +ldcast = LDCast(ldm_lightning, autoencoder, sampler) +``` +The original config for the conditioner and the autoencoder is in `original_config.yaml`, and can be laoded with: +```python +config = 'original_config' +ldcast = LDCast.from_config(config) +``` +Here, `config` can also be a `dict`. +## Predictions + +Predictions can be produced with +```python +inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') +ldcast.predict(inputs) +``` + +## Loading/saving weights +To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): +```python +ldcast.load('/path/to/folder') +``` +To save in a folder: +```python +ldcast.save('/path/to/folder') +``` +## Training + +If `sampled_radar_dataset` is a `SampledRadarDataset` built with Gabriele's code (https://github.com/DSIP-FBK/ConvGRU-Ensemble/blob/main/convgru_ensemble/datamodule.py), the autoencoder can be trained with +```python +ldcast.fit_autoencoder(sampled_radar_dataset) +``` +and the ldm can be trained +```python +ldcast.fit_ldm(sampled_radar_dataset) +``` +Keyword arguments can be passed to the trainer and the dataloader through the `trainer_kwargs` and `dataloader_kwargs` keywords. + # Autoencoder ```python @@ -120,34 +163,6 @@ The EMA weights and parameters can be loaded with ldm_lighting.ema.load('ema.pt') ``` -# Main LDCast class - -```python -from mlcast.models.ldcast.ldcast import LDCast -from mlcast.models.ldcast.diffusion.plms import PLMSSampler -sampler = PLMSSampler(denoiser) -ldcast = LDCast(ldm_lightning, autoencoder, sampler) -``` -Predictions can be produced with -```python -inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') -ldcast.predict(inputs) -``` -To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): -```python -ldcast.load('/path/to/folder') -``` -To save in a folder: -```python -ldcast.save('/path/to/folder') -``` -The original config for the conditioner and the autoencoder is in `original_config.yaml`, and can be laoded with: -```python -config = 'original_config' -ldcast = LDCast.from_config(config) -``` -Here, `config` can also be a `dict`. - # TO DO The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. From a1bb23975b88289f02d63f95be64b04b2c731b69 Mon Sep 17 00:00:00 2001 From: vsc47929 Date: Fri, 6 Mar 2026 12:57:05 +0100 Subject: [PATCH 10/13] changes with respect to previous commit: - reorganized the documentation: everything was in the README.md of the file, and I created a docs folder with markdown files to organize the documentation a little better. The README now contains only general informations and things to do - reorganized the config.yaml file (which is not anymore named original_config.yaml but config.yaml) - in the NowcastingLightningModule in base.py, I added the possibility to add a scheduler for the learning rate and removed the n_timesteps argument of the forward method (I think it is not very appropriate to have it there, since not every subclass will need this argument, e.g. the Autoencoder subclass) - in autoencoder/autoencoder.py: the loss is now name AutoencoderLoss; the subclass of NowcastingLightningModule is now named Autoencoder while AutoencoderKLNet is a subclass of torch.nn.Module. I added the possibility to do antialiasing before feeding the autoencoder with samples (done by default by an Antialiaser object (in transforms/antialiasing.py)). I added also the possibility to create an instance of Autoencoder via a config dict, based on the original autoencoder architecture - in LatentDiffusion in diffusion/diffusion.py: I added the possibility to construct an instance from a config dict, based on the architecture of the corresponding part in the original code - in the code in general, ldm was an instance of LatentDiffusionNet and ldm_lightning was an instance of LatentDiffusion, and I changed this: ldm is now an instance of LatentDiffusion and net is the instance of LatentDiffusionNet, to be consistent with the .net attribute of NowcastingLightningModule - in ldcast/ldcast.py: I also added the possibility to build the LDCast class from dict config --- README.md | 220 ++---------------- config.yaml | 92 ++++++++ docs/autoencoder.md | 80 +++++++ docs/ldcast.md | 57 +++++ docs/ldm.md | 106 +++++++++ original_config.yaml | 26 --- src/mlcast/models/base.py | 29 ++- src/mlcast/models/ldcast/autoenc/autoenc.py | 39 +++- src/mlcast/models/ldcast/data.py | 14 +- .../models/ldcast/diffusion/diffusion.py | 36 ++- src/mlcast/models/ldcast/ldcast.py | 60 ++--- .../models/ldcast/transforms/antialiasing.py | 30 +++ 12 files changed, 505 insertions(+), 284 deletions(-) create mode 100644 config.yaml create mode 100644 docs/autoencoder.md create mode 100644 docs/ldcast.md create mode 100644 docs/ldm.md delete mode 100644 original_config.yaml create mode 100644 src/mlcast/models/ldcast/transforms/antialiasing.py diff --git a/README.md b/README.md index 9fc9d51..9b3ef6d 100644 --- a/README.md +++ b/README.md @@ -1,216 +1,42 @@ # MLCast implementation of LDCast -see main branch https://github.com/mlcast-community/mlcast for details. - -```python -future_timesteps = 20 -autoenc_time_ratio = 4 # number of timesteps encoded in the autoencoder -``` -Here, 4 consecutive radar images are encoded at once. - -# Main LDCast class - -```python -from mlcast.models.ldcast.ldcast import LDCast -from mlcast.models.ldcast.diffusion.plms import PLMSSampler -sampler = PLMSSampler(denoiser) -ldcast = LDCast(ldm_lightning, autoencoder, sampler) -``` -The original config for the conditioner and the autoencoder is in `original_config.yaml`, and can be laoded with: -```python -config = 'original_config' -ldcast = LDCast.from_config(config) -``` -Here, `config` can also be a `dict`. -## Predictions - -Predictions can be produced with -```python -inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') -ldcast.predict(inputs) -``` - -## Loading/saving weights -To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): -```python -ldcast.load('/path/to/folder') -``` -To save in a folder: -```python -ldcast.save('/path/to/folder') -``` -## Training - -If `sampled_radar_dataset` is a `SampledRadarDataset` built with Gabriele's code (https://github.com/DSIP-FBK/ConvGRU-Ensemble/blob/main/convgru_ensemble/datamodule.py), the autoencoder can be trained with -```python -ldcast.fit_autoencoder(sampled_radar_dataset) -``` -and the ldm can be trained -```python -ldcast.fit_ldm(sampled_radar_dataset) -``` -Keyword arguments can be passed to the trainer and the dataloader through the `trainer_kwargs` and `dataloader_kwargs` keywords. - -# Autoencoder - -```python -from mlcast.models.ldcast.autoenc.autoenc import AutoencoderKLNet, autoenc_loss -from mlcast.models.base import NowcastingLightningModule -autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to('cuda') -``` -The autoencoder is an instance of the NowcastingLightningModule. Training the autoencoder: -```python -# create fake data -inputs = torch.randn(2, 1, 4, 256, 256, device = 'cuda') - -with torch.no_grad(): - # the forward pass of the autoencoder returns also the encoding - # so [0] is needed to select the decoded part only - y = autoencoder(x, 4)[0] -batch = (x, y) - -import pytorch_lightning as L -trainer = L.Trainer() -trainer.fit(autoencoder, batch) -``` -The inputs tensors have shape `(batch_size, n_channels, number of input radar images,) + spatial shape`. In latent space, the tensors have shape `(batch_size, 32, n, 64, 64)`, where 32 is the `hidden_width` of the `autoencoder` and `n` is the number of consecutive encoded radar images divided by `autoenc_time_ratio` (set to 4). - -The original weights can be loaded directly as -```python -autoenc_weights_fn = '/path/to/original/autoencoder/weights' -autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) -``` - -# Latent diffusion (= conditioner + denoiser) -The `LatentDiffusion` class is a `nn.Module` combining the conditioner and the denoiser. -```python -from mlcast.models.ldcast.diffusion.unet import UNetModel -from mlcast.models.ldcast.context.context import AFNONowcastNetCascade - -# setup forecaster -conditioner = AFNONowcastNetCascade( - 32, - train_autoenc=False, - output_patches=future_timesteps//autoenc_time_ratio, - cascade_depth=3, - embed_dim=128, - analysis_depth=4 -).to('cuda') - -# setup denoiser -from mlcast.models.ldcast.diffusion.unet import UNetModel -denoiser = UNetModel(in_channels=autoencoder.net.hidden_width, - model_channels=256, out_channels=autoencoder.net.hidden_width, - num_res_blocks=2, attention_resolutions=(1,2), - dims=3, channel_mult=(1, 2, 4), num_heads=8, - num_timesteps=future_timesteps//autoenc_time_ratio, - context_ch=[128, 256, 512] # context channels (= analysis_net.cascade_dims) - ).to('cuda') - -from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion -ldm = LatentDiffusion(conditioner, denoiser) -``` -The `LatentDiffusion` class has a forward pass: it takes the noise, the timesteps of the diffusion and the encoded inputs -```python -latent_inputs = autoencoder.net.encode(inputs) -noise = torch.randn(2, 32, 5, 64, 64, device = latent_inputs.device) -t = torch.tensor([2, 3], device = latent_inputs.device) -ldm((t, noise, latent_inputs)) -``` -The noise has to have the shape true radar images encoded in latent space. - -## LatentDiffusionLightning class and training of the ldm - -Create fake data to train the ldm: -```python -from torch.utils.data import TensorDataset -true = torch.randn(2, 1, future_timesteps, 256, 256, device = 'cuda') -dataset = TensorDataset(inputs, true) -``` -Create a ```Dataset``` which convert the samples in latent space with the autoencoder -``` -self.autoencoder.net.eval() -latent_dataset = LatentDataset(dataset, autoencoder.net) -dataloader = DataLoader(latent_dataset, batch_size=2) -``` -Put `ldm` in a `LightningModule` and train: -```python -from torch.nn import L1Loss -from mlcast.models.ldcast.diffusion.scheduler import Scheduler -from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusionLightning - -ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler()) -trainer = L.Trainer() -trainer.fit(ldm_lightning, dataloader) -``` -## Loading the original weights -The original weights can not be directly loaded because the models are structured a little differently, but the original weights files can be converted with -```python -from mlcast.models.ldcast.original_weights import convert_original_weights -ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' -state_dict = convert_original_weights(ldm_weights_fn) -torch.save(state_dict['denoiser'], 'denoiser_state_dict.pt') -torch.save(state_dict['conditioner'], 'conditioner_state_dict.pt') -torch.save(state_dict['ema'], 'ema.pt') -``` -`state_dict['unmatched']` contains a `dict` with the elements that were not matched (should be empty). The weights for the conditioner and the denoiser (including the buffers for the scheduling) can then be loaded with -```python -conditioner.load_state_dict(torch.load('conditioner_state_dict.pt')) -denoiser.load_state_dict(torch.load('denoiser_state_dict.pt')) -``` -The EMA weights and parameters can be loaded with -```python -ldm_lighting.ema.load('ema.pt') -``` - -# TO DO +see main branch https://github.com/mlcast-community/mlcast for context. -The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. - -We might integrate this code within the Hugging Face Diffusers Library. - -It remains mainly to write code in the main LDCast class (in `ldcast.py`) - -It would be nice to rewrite the PLMS sampler, it is a little messy +## Code structure -# Basics on diffusion models +There is one main `LDCast` class, subclassing the `NowcastingModelBase` class. There are three main nets in LDCast: + - the autoencoder + - the conditioner + - the denoiser -See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. +The `NowcastingLightningModule` is subclassed by the smaller composites of nets that should be trained at once. This gives two subclasses in this case: + - the autoencoder (encoder + decoder) has to be trained on its own, so there is one subclass of `NowcastingLightningModule` called `Autoencoder` + - the conditioner and the denoiser have to be trained together, so they are combined into one neural network (the `LatentDiffusionNet` class), whose training is handled by the `LatentDiffusion` subclass of the `NowcastingLightningModule` -During training, we start from a sample $x_0$ and create a series of samples $x_0, x_1, ..., x_T$ according to the formula (forward diffusion) +## Documentation -$$ -x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon_t, \quad t = 1, ..., T -$$ +See `docs` folder for some documenation on the main `LDCast` class, on the autoencoder and on the latent diffusion part. -where $\epsilon_t \sim \mathcal{N}(0, 1)$. The constants $\bar{\alpha}_t$ are chosen, but they need the property that $\bar{\alpha}_t \to 0$ as $t \to T$, so that $x_T \sim \mathcal{N}(0, 1)$. These constants are computed with an algorithm called a scheduler. +## TO DO -From a given $x_t$, the model can either be trained to predict $x_0$, $\epsilon_t$ or the velocity $v_t$. The latter is defined as +reorganize the `LatentDiffusion` class ? for the moment, `LatentDiffusionNet.forward` is never called during inference because the inference process is quite different than in training (see `docs/ldm.md). It might be maybe a bit clearer to reorganize that by implementing explicitly different training and inference step methods in the `LatentDiffusion` class (that being said, `AutoencoderKLNet.forward` is never called either during inference) -$$ -v_t = \sqrt{\bar{\alpha}_t} \epsilon_t - \sqrt{1-\bar{\alpha}_t} x_0, -$$ - -which is equivalent to - -$$ -x_0 = \sqrt{\bar{\alpha}_t} x_t - \sqrt{1-\bar{\alpha}_t} v_t. -$$ +The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. -The model is also given the timestep $t$. The loss is computed by comparing the target quantity ($\epsilon_t$, $x_0$ or $v_t$) with the predicted quantity by the model. Choosing to predict $x_0$, $\epsilon_t$ or $v_t$ is conceptually equivalent, the difference is in the numerical properties of the scheme (like in ODE integration schemes). +We might integrate this code within the Hugging Face Diffusers Library. -The validation and test steps are done in the same way. +It remains mainly to write code in the main LDCast class (in `ldcast.py`) -So the model is trained to predict $x_0$ (or something from which we can compute $x_0$) from $x_T\sim \mathcal{N}(0, 1)$. But for large values of $t$, this prediction is actually quite bad. During actual prediction, the prediction is usually iteratively refined with sampler schemes. The idea is that, from the noise predicted based on $x_T$, the sampler scheme allows to compute $x_{T - \Delta t}$. The model is then used to predict the noise based on this estimation of $x_{T - \Delta t}$, and the sampler scheme allows to deduce $x_{T - 2\Delta t}$, etc. $\Delta t$ is usually taken of the order of 50 (while $T$ is usually 1000). +It would be nice to rewrite the PLMS sampler, it is a little messy -In Hugging Face Diffusers library, the scheduler and the sampler parts are often combined in one object called a scheduler, but the sampler part is only used during inference. +implement different parametrization than 'eps' -The original code was using antialiasing before feeding the samples to the model (at least during inference), I should add this +use ZarrDataModule and ZarrDataset ! -# The variational autoencoder +add the computation of the EMA loss during the ldm training, change the LDCast.predict method so that EMA weights are automatically used during inference -Source https://medium.com/@jpark7/finally-a-clear-derivation-of-the-vae-kl-loss-4cb38d2e47b3. +add in the code (and in the doc) the input and output shapes of the nets -Variational autoencoders encode the data through a normal distribution in latent space: each sample is represented by the mean and the standard deviation of the normal distribution. When decoding the sample, a new sample is created resembling the original sample, but is not quite the same. The degree to which we force the decoded samples to resemble the original ones is tuned by the `kl_weight` parameter of the KL loss function. +understand which parameters can be changed, which have to be adapted when others change -When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.decode`. +make the implementation of the `AutoencoderDataset` more efficient ? (see docs/autoencoder) \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..17ee1c9 --- /dev/null +++ b/config.yaml @@ -0,0 +1,92 @@ +# everything is from the original config, except for the batch size + +device: &device 'cuda' + +model: + autoencoder: + optimizer_class: "${as_class: 'torch.optim.AdamW'}" + optimizer_kwargs: + lr: 0.001 + betas: [0.5, 0.9] + weight_decay: 0.001 + lr_scheduler: + class: "${as_class: 'torch.optim.lr_scheduler.ReduceLROnPlateau'}" + kwargs: + patience: 3 + factor: 0.25 + verbose: True + extra: + monitor: 'val/rec_loss' + frequency: 1 + interval: 'epoch' + antialiaser: + use: True + kwargs: + device: *device + encoder: {} + decoder: {} + net_kwargs: + hidden_width: &autoencoder_hidden_width 32 + loss: + kl_weight: 0.01 + device: *device + + ldm: + device: *device + conditioner: + autoencoder_dim: *autoencoder_hidden_width + output_patches: &output_patches 5 + cascade_depth: 3 + embed_dim: 128 + analysis_depth: 4 + denoiser: + in_channels: *autoencoder_hidden_width + model_channels: 256 + out_channels: *autoencoder_hidden_width + num_res_blocks: 2 + attention_resolutions: [1, 2] + dims: 3 + channel_mult: [1, 2, 4] + num_heads: 8 + num_timesteps: *output_patches + context_ch: [128, 256, 512] # should be equal to conditioner.cascade_dims ? + ema: + use: True + kwargs: + store_device: 'cpu' + optimizer_class: "${as_class: 'torch.optim.AdamW'}" + optimizer_kwargs: + lr: 0.0001 + betas: [0.5, 0.9] + weight_decay: 0.001 + lr_scheduler: + class: "${as_class: 'torch.optim.lr_scheduler.ReduceLROnPlateau'}" + kwargs: + patience: 3 + factor: 0.25 + verbose: True + extra: + monitor: 'val/ema_loss' + frequency: 1 + interval: 'epoch' + scheduler: {} # diffusion scheduler + +dataloader: + batch_size: 32 + num_workers: 4 + persistent_workers: True + +trainer: + max_epochs: 200 + accelerator: 'gpu' + log_every_n_steps: 5 + callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/rec_loss\", patience=6, verbose=True, check_finite=False)]'}" + +sampled_radar_dataset: + zarr_path: 'test.zarr' + csv_path: 'mlcast-dataset-sampler/sampled_datacubes_2017-01-01-2017-02-01_24x256x256_3x16x16_10000.csv' + steps: 24 + augment: False + data_var: 'precip_intensity_EDK' + + diff --git a/docs/autoencoder.md b/docs/autoencoder.md new file mode 100644 index 0000000..3f8da7d --- /dev/null +++ b/docs/autoencoder.md @@ -0,0 +1,80 @@ +# Autoencoder documentation + +1. [Autoencoder class](#autoencoder-class) +2. [Tensor shapes](#tensor-shapes) +3. [Encoding and decoding](#encoding-and-decoding) +4. [Loading original weights](#loading-original-weights) +5. [Antialiasing](#antialiasing) +6. [Autoencoder training dataset](#autoencoder-training-dataset) +7. [Background on variational autoencoders](#background-on-variational-autoencoders) + +## Autoencoder class + +The `Autoencoder` class is a subclass of `NowcastingLightningModule`, and takes three arguments: + - the `net` (an instance of `AutoencoderKLNet` for LDCast), which is the neural network of the autoencoder, containing the decoder and the autoencoder + - the `loss` (an instance of `AutoencoderLoss` for LDCast) +Options for the optimizer and the learning rate scheduler can be passed as well. + +An instance can be created from a `dict` containing the configuration, based on the architecture of LDCast's autoencoder: +```python +from mlcast.models.ldcast.autoencoder.autoencoder import Autoencoder +autoencoder = Autoencoder.from_config(config) +``` + +## Tensor shapes + +The autoencoder encodes sequences of radar images (not image by image). The number of radar images encoded at once is given by `autoenc_time_ratio` and was set to 4 in the original code (and kept here). `Conv3d` layers are used for the encoding, so input tensors have shape +``` +(batch_size, n_channels, autoenc_time_ratio,) + spatial shape +``` +`n_channels` is always 1 for radar images. + +In latent space, the tensors have shape `(batch_size, 32, n, 64, 64)`, where 32 is the `hidden_width` of the `autoencoder` and `n` is the number of consecutive encoded radar images divided by `autoenc_time_ratio`. **I should still clarify which of these parameters can be changed freely, and how it affects other shapes. Can `autoencoder.net` encode a e.g. 8 images at once (in which case `n` is 2) ?** + + +## Encoding and decoding + +Doing the following +```python +import torch +inputs = torch.randn(1, 1, 4, 256, 256, device = 'cuda') # fake sample +autoencoder(inputs) +``` +is equivalent to `autoencoder.net(inputs)` and computes the whole forward pass through the `net` (encoding + decoding). To encode only, one needs to do +```python +autoencoder.net.encode(inputs). +``` +If `encoded` is an encoded sample, it can be decoded as +```python +autoencoder.net.decode(encoded) +``` + +## Laoding original weights + +The original weights can be loaded directly as +```python +autoenc_weights_fn = '/path/to/original/autoencoder/weights' +autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) +``` + +## Antialiasing + +As in the original code, antialiasing is applied by default (by an Antialiaser object) to the inputs before being fed to the `net`. + +## Autoencoder training dataset + +Gabriele's code produces a dataset whose samples are sequences of `steps` images (`steps` is usually set to 24, to have 4 input images and 20 ground truth images). + +But the autoencoder needs samples which are sequences of only 4 images, so each sample in `SampledRadarDataset` needs to be divided in 6 samples. This is done by the `AutoencoderDataset`. Its samples are tuple `(x, y)` where `y = x` since we want the autoencoder to reconstruct the sequences. + +**The current implementation of this class is not the most efficient since, when going through the `AutoencoderDataset`, each sample of the `SampledRadarDataset` is loaded 6 times.** + +## Background on variational autoencoders + +The autoencoder used in LDCast is a variational autoencoder. Here is some background on that kind of autoencoder. + +Source https://medium.com/@jpark7/finally-a-clear-derivation-of-the-vae-kl-loss-4cb38d2e47b3. + +Variational autoencoders encode the data through a normal distribution in latent space: each sample is represented by the mean and the standard deviation of the normal distribution. When decoding the sample, a new sample is created resembling the original sample, but is not quite the same. The degree to which we force the decoded samples to resemble the original ones is tuned by the `kl_weight` parameter of the KL loss function. + +When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.net.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.net.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.net.decode`. \ No newline at end of file diff --git a/docs/ldcast.md b/docs/ldcast.md new file mode 100644 index 0000000..a7554db --- /dev/null +++ b/docs/ldcast.md @@ -0,0 +1,57 @@ +# Main LDCast class documentation + +1. [LDCast class](#ldcast-class) +2. [Inference](#inference) +3. [Loading/saving weights](#loading/saving-weights) +4. [Training](#training) + +## LDCast class + +The `LDCast` class is a subclass of `NowcastingModelBase` and takes three arguments + - the `ldm` (typically, an instance of `LatentDiffusion`) + - the `autoencoder` (typically, an instance of `Autoencoder`) + - the `sampler` + +An instance can be created from a `dict` containing the configuration, based on the architecture of LDCast: +```python +from mlcast.models.ldcast.ldcast import LDCast +ldcast = LDCast.from_config(config) +``` +A config very close to what was used in the original code is in 'config.yaml'. It should be loaded as +```python +from omegaconf import OmageConf +OmegaConf.register_new_resolver("as_class", lambda class_name: eval(class_name)) +config = OmegaConf.load('config.yaml') +``` + +## Inference + +Predictions can be produced with +```python +import torch +inputs = torch.randn(1, 1, 4, 256, 256, device = 'cuda') # fake data +ldcast.predict(inputs) +``` +**Do not use for the moment, since the EMA weights (if used) are not automatically used for inference** + +## Loading/saving weights +To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): +```python +ldcast.load('/path/to/folder') +``` +To save in a folder: +```python +ldcast.save('/path/to/folder') +``` + +## Training + +If `sampled_radar_dataset` is a `SampledRadarDataset` built with Gabriele's code (https://github.com/DSIP-FBK/ConvGRU-Ensemble/blob/main/convgru_ensemble/datamodule.py), the autoencoder can be trained with +```python +ldcast.fit_autoencoder(sampled_radar_dataset) +``` +and the ldm can be trained +```python +ldcast.fit_ldm(sampled_radar_dataset) +``` +Keyword arguments can be passed to the trainer and the dataloader through the `trainer_kwargs` and `dataloader_kwargs` keywords. \ No newline at end of file diff --git a/docs/ldm.md b/docs/ldm.md new file mode 100644 index 0000000..18fab0f --- /dev/null +++ b/docs/ldm.md @@ -0,0 +1,106 @@ +# Latent diffusion documentation + +1. [LatentDiffusion class](#latentdiffusion-class) +2. [LatentDiffusionNet](#latentdiffusionnet) +3. [Training vs inference modes](#training-vs-inference-modes) +4. [Loading original weights](#loading-original-weights) +5. [Exponential Moving Average](#exponential-moving-average) +6. [LatentDataset](#latentdataset) +7. [Background on diffusion models](#background-on-diffusion-models) + +## LatentDiffusion class + +The `LatentDiffusion` class is a subclass of `NowcastingLightningModule` and takes three arguments + - the `net` (typically, an instance of `LatentDiffusionNet`) + - the `loss` (a `torch.nn.MSELoss` for LDCast) + - the `scheduler`, scheduling the diffusion process +Options for the optimizer and the learning rate scheduler can be passed as well. + +An instance can be created from a `dict` containing the configuration, based on the corresponding part of the architecture of LDCast: +```python +from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion +ldm = LatentDiffusion.from_config(config) +``` + +## LatentDiffusionNet + +The `LatentDiffusionNet` class combines two elements: the `conditioner` and the `denoiser`. + +The `denoiser` takes some noise and performs the backward diffusion process to produce samples (in latent space). Since we want a nowcast based on input images, the denoiser needs with some condition based on the input images. + +The role of the `conditioner` is to provide this condition to the denoiser. It takes input images (encoded in latent space) and returns a condition (also called context ?) to help the denoiser to produce relevant predictions. The `conditioner` could also be called a forecaster. + +In the original LDCast code, the `conditioner` was called `analysis_net` and the `denoiser` was called `denoiser` or `model`. + +As in the original code, the `conditioner` is an instance of `AFNONowcastNet` and the `denoiser` is an instance of `UNetModel`. + +**check the output shape of the conditioner and the input shape of the denoiser** + +## Training vs inference modes + +The `net` combines the `conditioner` and the `denoiser` in the way they should be trained (i.e. the `denoiser` is always called after the `conditioner`). The inference process is however different: once the input has been converted in latent space, it is passed to the conditioner to produce a condition; then, the denoiser is repeatedly called to iteratively denoise a completely noisy image according to a scheme defined by a sampler (see [Background on diffusion models](background-on-diffusion-models)). During the inference, `net.forward` is thus never called. + +## Loading original weights + +The structure of this part of LDCast has changed a little with respect to the original code, so the weights need to be reorganized before being loaded. The `convert_original_weights` function does this: +```python +from mlcast.models.ldcast.original_weights import convert_original_weights +ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' +state_dict = convert_original_weights(ldm_weights_fn) +``` +`state_dict`is a `dict` whose keys are `conditioner`, `denoiser`, `ema` and `unmatched`. `state_dict['unmatched']` contains the elements which were not matched in `convert_original_weights` (should be empty). The weights of the conditioner, of the denoiser and the EMA can then be loaded as +``` +ldm.net.conditioner.load_state_dict(state_dict['conditioner']) +ldm.net.denoiser.load_state_dict(state_dict['denoiser']) +ldm.ema.load(state_dict['ema']) +``` + +## Exponential Moving Average + +The original code included an Exponential Moving Average (EMA) of the weights of the denoiser. This seems quite common for diffusion models. + +The idea is two versions of the weights of the models: + 1. the usual weights, which are updated through the optimization of the training loss + 2. the EMA weights, which are computed as an average of the last values of the usual weights + +The average is exponentially weighted, so that the latest weights are more taken into account. + +This is useful because the usual weights are quite unstable, while the EMA weights are more stable because of the average. + +The EMA is switched on by default, and can be switched off when creating the `LatentDiffusion` class (setting the keyword `ema_config` to `{'use': False}`). When switched on, the original weights have to be loaded into the model for the computation of the training loss, but the loss with the EMA weights should be computed during validation (and also after each training step ?). The EMA weights should also be used during inference. Everything is handled automatically if a `pl.Trainer` is used (through lightning hooks on the `LatentDiffusion` class). This means that one also needs to use a `pl.Trainer` at inference. + +## LatentDataset + +Gabriele's `SampledRadarDataset` returns sequence of `steps` radar images (with `steps` usually set to 24). However, the training of the `ldm` requires samples in latent space. This is handled by the `LatentDataset` class: it returns samples `(x, y)` with `x` being the latent encoding of the input radar images, and `y` being the latent encoding of the radar images to predict. The `LatentDataset` thus needs the trained `autoencoder.net` ! + +## Background on diffusion models + +See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. + +During training, we start from a sample $x_0$ and create a series of samples $x_0, x_1, ..., x_T$ according to the formula (forward diffusion) + +$$ +x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon_t, \quad t = 1, ..., T +$$ + +where $\epsilon_t \sim \mathcal{N}(0, 1)$. The constants $\bar{\alpha}_t$ are chosen, but they need the property that $\bar{\alpha}_t \to 0$ as $t \to T$, so that $x_T \sim \mathcal{N}(0, 1)$. These constants are computed with an algorithm called a scheduler. + +From a given $x_t$, the model can either be trained to predict $x_0$, $\epsilon_t$ or the velocity $v_t$. The latter is defined as + +$$ +v_t = \sqrt{\bar{\alpha}_t} \epsilon_t - \sqrt{1-\bar{\alpha}_t} x_0, +$$ + +which is equivalent to + +$$ +x_0 = \sqrt{\bar{\alpha}_t} x_t - \sqrt{1-\bar{\alpha}_t} v_t. +$$ + +The model is also given the timestep $t$. The loss is computed by comparing the target quantity ($\epsilon_t$, $x_0$ or $v_t$) with the predicted quantity by the model. Choosing to predict $x_0$, $\epsilon_t$ or $v_t$ is conceptually equivalent, the difference is in the numerical properties of the scheme (like in ODE integration schemes). + +The validation and test steps are done in the same way. + +So the model is trained to predict $x_0$ (or something from which we can compute $x_0$) from $x_T\sim \mathcal{N}(0, 1)$. But for large values of $t$, this prediction is actually quite bad. During actual prediction, the prediction is usually iteratively refined with sampler schemes. The idea is that, from the noise predicted based on $x_T$, the sampler scheme allows to compute $x_{T - \Delta t}$. The model is then used to predict the noise based on this estimation of $x_{T - \Delta t}$, and the sampler scheme allows to deduce $x_{T - 2\Delta t}$, etc. $\Delta t$ is usually taken of the order of 50 (while $T$ is usually 1000). + +In Hugging Face Diffusers library, the scheduler and the sampler parts are often combined in one object called a scheduler, but the sampler part is only used during inference. \ No newline at end of file diff --git a/original_config.yaml b/original_config.yaml deleted file mode 100644 index 8807382..0000000 --- a/original_config.yaml +++ /dev/null @@ -1,26 +0,0 @@ -autoencoder_hidden_width: &autoencoder_hidden_width 32 - -conditioner: - autoencoder_dim: *autoencoder_hidden_width - output_patches: &output_patches 5 - cascade_depth: 3 - embed_dim: 128 - analysis_depth: 4 - -denoiser: - in_channels: *autoencoder_hidden_width - model_channels: 256 - out_channels: *autoencoder_hidden_width - num_res_blocks: 2 - attention_resolutions: [1, 2] - dims: 3 - channel_mult: [1, 2, 4] - num_heads: 8 - num_timesteps: *output_patches - context_ch: [128, 256, 512] # should be equal to conditioner.cascade_dims ? - -ema: - use: True - kwargs: - store_device: 'cpu' - \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index 1c004c3..a4fe87b 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -113,15 +113,16 @@ def __init__( loss: nn.Module, optimizer_class: Any | None = None, optimizer_kwargs: dict | None = None, - **kwargs: Any, + lr_scheduler_config: dict | None = None, ): super().__init__() self.save_hyperparameters(ignore=["net", "loss"]) self.net = net self.loss = loss self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class + self.lr_scheduler_config = lr_scheduler_config - def forward(self, x: torch.Tensor, n_timesteps: int) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the model. Args: x: Input tensor with shape (batch, seq_len, channels, height, width) @@ -129,13 +130,14 @@ def forward(self, x: torch.Tensor, n_timesteps: int) -> torch.Tensor: Returns: Output tensor with shape (batch, n_timesteps, channels, height, width) """ - return self.net(x, n_timesteps) # Assuming net is a callable model + return self.net(x) # Assuming net is a callable model def training_logic(self, batch, batch_idx): """Can be overwritten if needed""" x, y = batch - predictions = self.forward(x, n_timesteps = 4) + predictions = self.forward(x) loss = self.loss(predictions, y) + return loss def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: @@ -195,5 +197,22 @@ def configure_optimizers(self) -> torch.optim.Optimizer: Returns: Optimizer instance to use for training + + following https://lightning.ai/docs/pytorch/stable/common/optimization.html for the scheduler part """ - return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + optimizer = self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + + if self.lr_scheduler_config is None: + return optimizer + + else: + cls = self.lr_scheduler_config['class'] + kwargs = self.lr_scheduler_config['kwargs'] + extra = self.lr_scheduler_config['extra'] + + # up to here, extra might be a omegaconf dictconfig object, but we need add the scheduler to it + # so convert it to a dict + extra = {k: v for k, v in extra.items()} + extra['scheduler'] = cls(optimizer, **kwargs) + + return {'optimizer': optimizer, 'lr_scheduler': extra} diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py index d673a2a..d14a244 100644 --- a/src/mlcast/models/ldcast/autoenc/autoenc.py +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -1,9 +1,10 @@ # from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py -import pytorch_lightning as pl import torch from torch import nn from .encoder import SimpleConvEncoder, SimpleConvDecoder +from ...base import NowcastingLightningModule +from ..transforms.antialiasing import Antialiaser from ..distributions import ( ensemble_nll_normal, @@ -11,7 +12,7 @@ sample_from_standard_normal, ) -class autoenc_loss(nn.Module): +class AutoencoderLoss(nn.Module): def __init__(self, kl_weight = 0.01): super().__init__() self.kl_weight = kl_weight @@ -26,13 +27,37 @@ def forward(self, predictions, y): return {'total_loss': total_loss, 'rec_loss': rec_loss, 'kl_loss': kl_loss} +class Autoencoder(NowcastingLightningModule): + def __init__(self, net, loss, antialiaser = Antialiaser(), **kwargs): + super().__init__(net, loss, **kwargs) + self.antialiaser = antialiaser -class AutoencoderKLNet(pl.LightningModule): + def forward(self, x): + if self.antialiaser is not None: + x = self.antialiaser(x) + return self.net(x) + + @classmethod + def from_config(cls, config): + + antialiaser = Antialiaser(**config['antialiaser']['kwargs']) if config['antialiaser']['use'] else None + encoder = SimpleConvEncoder(**config['encoder']) + decoder = SimpleConvDecoder(**config['decoder']) + net = AutoencoderKLNet(encoder = encoder, decoder = decoder, **config['net_kwargs']) + loss = AutoencoderLoss(**config['loss']) + + return cls(net, loss, + antialiaser = antialiaser, + optimizer_class = config['optimizer_class'], + optimizer_kwargs = config['optimizer_kwargs'], + lr_scheduler_config = config['lr_scheduler'] + ).to(config['device']) + +class AutoencoderKLNet(nn.Module): def __init__( self, encoder = SimpleConvEncoder(), decoder = SimpleConvDecoder(), - kl_weight=0.01, encoded_channels=64, hidden_width=32, **kwargs, @@ -44,11 +69,11 @@ def __init__( self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) self.log_var = nn.Parameter(torch.zeros(size=())) - self.kl_weight = kl_weight - def encode(self, x, return_log_var = False): + def encode(self, x, return_log_var = False): if len(x.shape) < 5: x = x[None] + h = self.encoder(x) (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) if return_log_var: @@ -64,7 +89,7 @@ def decode(self, z): dec = self.decoder(z) return dec - def forward(self, x, n_timesteps, sample_posterior=True): + def forward(self, x, sample_posterior=True): (mean, log_var) = self.encode(x, return_log_var = True) if sample_posterior: z = sample_from_standard_normal(mean, log_var) diff --git a/src/mlcast/models/ldcast/data.py b/src/mlcast/models/ldcast/data.py index 447d853..a2a08b8 100644 --- a/src/mlcast/models/ldcast/data.py +++ b/src/mlcast/models/ldcast/data.py @@ -1,6 +1,7 @@ from torch.utils.data import Dataset, random_split, DataLoader import torch import pytorch_lightning as pl +from tqdm import tqdm class LatentDataset(Dataset): def __init__(self, sampled_radar_dataset, autoencoder, autoenc_time_ratio = 4): @@ -17,7 +18,7 @@ def __len__(self): def __getitem__(self, idx): with torch.no_grad(): - sequence = self.dataset[idx]['data'] + sequence = self.dataset[idx] x = sequence[:self.autoenc_time_ratio] y = sequence[self.autoenc_time_ratio:] @@ -54,7 +55,7 @@ def __getitem__(self, idx): ''' index_srd = idx // self.samples_ratio index_in_srd_sample = idx - index_srd * self.samples_ratio - x = self.srd[index_srd]['data'].reshape(self.samples_ratio, self.autoenc_time_ratio, 1, self.srd.w, self.srd.h)[index_in_srd_sample] + x = self.srd[index_srd].reshape(self.samples_ratio, self.autoenc_time_ratio, 1, self.srd.w, self.srd.h)[index_in_srd_sample] # for some reason, Gabriele put the time axis before the channel axis, change this x = x.swapaxes(0, 1).to(self.device) @@ -85,4 +86,11 @@ def val_dataloader(self): return DataLoader(self.val_dataset, shuffle = False, **self.dataloader_kwargs) def test_dataloader(self): - return DataLoader(self.test_dataset, shuffle = False, **self.dataloader_kwargs) \ No newline at end of file + return DataLoader(self.test_dataset, shuffle = False, **self.dataloader_kwargs) + +def load_in_memory(dataset): + ds = [] + for i in tqdm(range(len(dataset)), desc = 'Loading data in memory'): + # append the sample and create an extra dimension (to be the batch dimension) + ds.append(dataset[i][None].to('cpu')) + return torch.cat(ds, axis = 0) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index fa18de6..328480f 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -1,21 +1,18 @@ import torch import torch.nn as nn -import pytorch_lightning as L -from typing import Any -import contextlib from ...base import NowcastingLightningModule import numpy as np from .utils import extract_into_tensor -from .ema import EMA +from .ema import EMA -class LatentDiffusion(nn.Module): +class LatentDiffusionNet(nn.Module): def __init__(self, conditioner, denoiser, parametrization = "eps"): super().__init__() self.conditioner = conditioner self.denoiser = denoiser self.parametrization = parametrization - def forward(self, x, n_timesteps = 4): + def forward(self, x): # during training, noisy should be x_t # during inference, noisy should be noise t, noisy, latent_inputs = x @@ -28,9 +25,9 @@ def forward(self, x, n_timesteps = 4): return out -class LatentDiffusionLightning(NowcastingLightningModule): - def __init__(self, ldm, loss, scheduler, ema_config = {'use': True}): - super().__init__(ldm, loss) +class LatentDiffusion(NowcastingLightningModule): + def __init__(self, net, loss, scheduler, ema_config = {'use': True}, **kwargs): + super().__init__(net, loss, **kwargs) self.scheduler = scheduler # register the schedules (i.e. the values of alpha, beta etc). @@ -108,3 +105,24 @@ def on_predict_start(self): def on_predict_end(self): if hasattr(self, 'ema'): self.ema.restore() + + @classmethod + def from_config(cls, config): + + from .scheduler import Scheduler + from .unet import UNetModel + from ..context.context import AFNONowcastNetCascade + + conditioner = AFNONowcastNetCascade(**config['conditioner']) + denoiser = UNetModel(**config['denoiser']) + net = LatentDiffusionNet(conditioner, denoiser) + loss = nn.MSELoss() + scheduler = Scheduler(**config['scheduler']) + + return cls(net, loss, scheduler, + ema_config = config['ema'], + optimizer_class = config['optimizer_class'], + optimizer_kwargs = config['optimizer_kwargs'], + lr_scheduler_config = config['lr_scheduler'] + ).to(config['device']) + diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index 8abccf6..bfb7b9e 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -2,17 +2,16 @@ from ..base import NowcastingModelBase import pytorch_lightning as L -from .data import LatentDataset, AutoencoderDataset, DataModule -from torch.utils.data import DataLoader +from .data import LatentDataset, AutoencoderDataset, DataModule, load_in_memory import torch import contextlib -from torch.utils.data import TensorDataset +torch.multiprocessing.set_start_method('spawn') class LDCast(NowcastingModelBase): - def __init__(self, ldm_lightning, autoencoder, sampler): + def __init__(self, ldm, autoencoder, sampler): super().__init__() - self.ldm_lightning = ldm_lightning + self.ldm = ldm self.autoencoder = autoencoder self.sampler = sampler @@ -29,16 +28,16 @@ def fit(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {} def fit_ldm(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): self.autoencoder.net.eval() - self.ldm_lightning.net.train() + self.ldm.net.train() dataset = LatentDataset(sampled_radar_dataset, self.autoencoder.net) datamodule = DataModule(dataset, **dataloader_kwargs) trainer = L.Trainer(**trainer_kwargs) - trainer.fit(self.ldm_lightning, datamodule) + trainer.fit(self.ldm, datamodule) def fit_autoencoder(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): self.autoencoder.net.train() - + dataset = AutoencoderDataset(sampled_radar_dataset) datamodule = DataModule(dataset, **dataloader_kwargs) trainer = L.Trainer(**trainer_kwargs) @@ -50,7 +49,7 @@ def predict(self, inputs, num_diffusion_iters = 50, verbose = True): assert False, 'prediction should be implemented with a trainer, to take into account the switches of ema weights for example''' latent_inputs = self.autoencoder.net.encode(inputs) - condition = self.ldm_lightning.net.conditioner(latent_inputs) + condition = self.ldm.net.conditioner(latent_inputs) gen_shape = (32, 5, 256//4, 256//4) batch_size = len(latent_inputs) @@ -65,52 +64,39 @@ def predict(self, inputs, num_diffusion_iters = 50, verbose = True): return s - latent_pred = self.ldm_lightning(latent_inputs) + latent_pred = self.ldm(latent_inputs) return self.autoencoder.net.decode(latent_pred) def save(self, folder): torch.save(self.autoencoder.net.state_dict(), f'{folder}/autoencoder.pt') - torch.save(self.ldm_lightning.net.conditioner.state_dict(), f'{folder}/conditioner.pt') - torch.save(self.ldm_lightning.net.denoiser.state_dict(), f'{folder}/denoiser.pt') + torch.save(self.ldm.net.conditioner.state_dict(), f'{folder}/conditioner.pt') + torch.save(self.ldm.net.denoiser.state_dict(), f'{folder}/denoiser.pt') - if hasattr(self.ldm_lightning, 'ema'): - self.ldm_lightning.ema.save(f'{folder}/ema.pt') + if hasattr(self.ldm, 'ema'): + self.ldm.ema.save(f'{folder}/ema.pt') def load(self, folder): self.autoencoder.net.load_state_dict(torch.load(f'{folder}/autoencoder.pt')) - self.ldm_lightning.net.conditioner.load_state_dict(torch.load(f'{folder}/conditioner.pt')) - self.ldm_lightning.net.denoiser.load_state_dict(torch.load(f'{folder}/denoiser.pt')) + self.ldm.net.conditioner.load_state_dict(torch.load(f'{folder}/conditioner.pt')) + self.ldm.net.denoiser.load_state_dict(torch.load(f'{folder}/denoiser.pt')) - if hasattr(self.ldm_lightning, 'ema'): - self.ldm_lightning.ema.load(f'{folder}/ema.pt') + if hasattr(self.ldm, 'ema'): + self.ldm.ema.load(f'{folder}/ema.pt') @classmethod def from_config(cls, config): - if isinstance(config, str): - import yaml - with open(config, 'r') as file: - config = yaml.safe_load(file) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - from .autoenc.autoenc import AutoencoderKLNet, autoenc_loss - from ..base import NowcastingLightningModule - from .diffusion.unet import UNetModel - from .context.context import AFNONowcastNetCascade - from .diffusion.diffusion import LatentDiffusion, LatentDiffusionLightning - from torch.nn import L1Loss - from .diffusion.scheduler import Scheduler + from .autoenc.autoenc import Autoencoder + from .diffusion.diffusion import LatentDiffusion from .diffusion.plms import PLMSSampler - autoencoder = NowcastingLightningModule(AutoencoderKLNet(), autoenc_loss()).to(device) - conditioner = AFNONowcastNetCascade(**config['conditioner']).to(device) - denoiser = UNetModel(**config['denoiser']).to(device) - ldm = LatentDiffusion(conditioner, denoiser) - ldm_lightning = LatentDiffusionLightning(ldm, L1Loss(), Scheduler(), ema_config = config['ema']) - sampler = PLMSSampler(denoiser) + autoencoder = Autoencoder.from_config(config['autoencoder']) + ldm = LatentDiffusion.from_config(config['ldm']) + sampler = PLMSSampler(ldm.net.denoiser) - return cls(ldm_lightning, autoencoder, sampler) + return cls(ldm, autoencoder, sampler) diff --git a/src/mlcast/models/ldcast/transforms/antialiasing.py b/src/mlcast/models/ldcast/transforms/antialiasing.py new file mode 100644 index 0000000..bf0a70c --- /dev/null +++ b/src/mlcast/models/ldcast/transforms/antialiasing.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +class Antialiaser: + def __init__(self, device = 'cpu'): + + self.device = device + + # construct the kernel (symmetric in both directions), shape = (5, 5) + (x, y) = torch.meshgrid(torch.arange(-2, 3), torch.arange(-2, 3), indexing = 'ij') + kernel = torch.exp(-0.5*(x**2+y**2)/(0.5**2)) + kernel /= kernel.sum() + kernel = kernel.to(self.device) + + # the convolution will be done on x (shape = (1, autoenc_time_ratio) + spatial_shape) + # so treat the autoenc_time_ratio as one axis of the convolution -> Conv3d + # but we do not want to convolve on this axis, so use kernl_size = 1 and padding = 0 on this axis + self.conv = nn.Conv3d(1, 1, bias = False, kernel_size = (1, 5, 5), padding = (0, 2, 2)) + + # set the weights to be those of the kernel + self.conv.weight = nn.Parameter(kernel[None, None, None], requires_grad = False) + + def __call__(self, x): + + # factor is 1 in the bulk of x, but is greater than 1 near the border to accounr that less values were used in the convolution + # recomputed each time because the image shape could change + factor = self.conv(torch.ones(x.shape, device = self.device)) + factor = 1./factor + + return self.conv(x) * factor \ No newline at end of file From ae4ec11b0eddf441157bc925ba123de6a11887bf Mon Sep 17 00:00:00 2001 From: vsc47929 Date: Fri, 6 Mar 2026 13:44:39 +0100 Subject: [PATCH 11/13] changed .gitignore to remove .ipynb_checkpoints files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index a663386..97c46a7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ dist/ src/mlcast/modules/.ipynb_checkpoints/ src/mlcast/models/ldcast/context/.ipynb_checkpoints/ src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ +src/mlcast/models/ldcast/.ipynb_checkpoints/ +src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/ +src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/ From 1129729f33aa3f6aaff95ee51785c2f77bdc0d61 Mon Sep 17 00:00:00 2001 From: vsc47929 Date: Fri, 6 Mar 2026 13:48:56 +0100 Subject: [PATCH 12/13] removed .ipynb_checkpoints files --- .../distributions-checkpoint.py | 31 -- .../.ipynb_checkpoints/utils-checkpoint.py | 30 -- .../.ipynb_checkpoints/encoder-checkpoint.py | 59 --- .../.ipynb_checkpoints/afno-checkpoint.py | 350 ------------- .../attention-checkpoint.py | 106 ---- .../.ipynb_checkpoints/resnet-checkpoint.py | 91 ---- .../.ipynb_checkpoints/unet-checkpoint.py | 492 ------------------ .../.ipynb_checkpoints/utils-checkpoint.py | 249 --------- 8 files changed, 1408 deletions(-) delete mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py delete mode 100644 src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py deleted file mode 100644 index 3dcb183..0000000 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/distributions-checkpoint.py +++ /dev/null @@ -1,31 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/distributions.py - -import numpy as np -import torch - - -def kl_from_standard_normal(mean, log_var): - kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var) - return kl.mean() - - -def sample_from_standard_normal(mean, log_var, num=None): - std = (0.5 * log_var).exp() - shape = mean.shape - if num is not None: - # expand channel 1 to create several samples - shape = shape[:1] + (num,) + shape[1:] - mean = mean[:, None, ...] - std = std[:, None, ...] - return mean + std * torch.randn(shape, device=mean.device) - - -def ensemble_nll_normal(ensemble, sample, epsilon=1e-5): - mean = ensemble.mean(dim=1) - var = ensemble.var(dim=1, unbiased=True) + epsilon - logvar = var.log() - - diff = sample[:, None, ...] - mean - logtwopi = np.log(2 * np.pi) - nll = (logtwopi + logvar + diff.square() / var).mean() - return nll \ No newline at end of file diff --git a/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index f38bd29..0000000 --- a/src/mlcast/models/ldcast/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,30 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/utils.py - -import torch -from torch import nn - - -def normalization(channels, norm_type="group", num_groups=32): - if norm_type == "batch": - return nn.BatchNorm3d(channels) - elif norm_type == "group": - return nn.GroupNorm(num_groups=num_groups, num_channels=channels) - elif (not norm_type) or (norm_type.tolower() == "none"): - return nn.Identity() - else: - raise NotImplementedError(norm) - - -def activation(act_type="swish"): - if act_type == "swish": - return nn.SiLU() - elif act_type == "gelu": - return nn.GELU() - elif act_type == "relu": - return nn.ReLU() - elif act_type == "tanh": - return nn.Tanh() - elif not act_type: - return nn.Identity() - else: - raise NotImplementedError(act_type) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py b/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py deleted file mode 100644 index 157af11..0000000 --- a/src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/encoder-checkpoint.py +++ /dev/null @@ -1,59 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py - -import numpy as np -import torch.nn as nn - -from ..blocks.resnet import ResBlock3D -from ..utils import activation, normalization - - -class SimpleConvEncoder(nn.Sequential): - def __init__(self, in_dim=1, levels=2, min_ch=64): - sequence = [] - channels = np.hstack([ - in_dim, - (8**np.arange(1,levels+1)).clip(min=min_ch) - ]) - - for i in range(levels): - in_channels = int(channels[i]) - out_channels = int(channels[i+1]) - res_kernel_size = (3,3,3) if i == 0 else (1,3,3) - res_block = ResBlock3D( - in_channels, out_channels, - kernel_size=res_kernel_size, - norm_kwargs={"num_groups": 1} - ) - sequence.append(res_block) - downsample = nn.Conv3d(out_channels, out_channels, - kernel_size=(2,2,2), stride=(2,2,2)) - sequence.append(downsample) - in_channels = out_channels - - super().__init__(*sequence) - - -class SimpleConvDecoder(nn.Sequential): - def __init__(self, in_dim=1, levels=2, min_ch=64): - sequence = [] - channels = np.hstack([ - in_dim, - (8**np.arange(1,levels+1)).clip(min=min_ch) - ]) - - for i in reversed(list(range(levels))): - in_channels = int(channels[i+1]) - out_channels = int(channels[i]) - upsample = nn.ConvTranspose3d(in_channels, in_channels, - kernel_size=(2,2,2), stride=(2,2,2)) - sequence.append(upsample) - res_kernel_size = (3,3,3) if (i == 0) else (1,3,3) - res_block = ResBlock3D( - in_channels, out_channels, - kernel_size=res_kernel_size, - norm_kwargs={"num_groups": 1} - ) - sequence.append(res_block) - in_channels = out_channels - - super().__init__(*sequence) diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py deleted file mode 100644 index 84c73d0..0000000 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/afno-checkpoint.py +++ /dev/null @@ -1,350 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/afno.py - -#reference: https://github.com/NVlabs/AFNO-transformer -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - -from .attention import TemporalAttention - -class Mlp(nn.Module): - def __init__( - self, - in_features, hidden_features=None, out_features=None, - act_layer=nn.GELU, drop=0.0 - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class AFNO2D(nn.Module): - def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): - super().__init__() - assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" - - self.hidden_size = hidden_size - self.sparsity_threshold = sparsity_threshold - self.num_blocks = num_blocks - self.block_size = self.hidden_size // self.num_blocks - self.hard_thresholding_fraction = hard_thresholding_fraction - self.hidden_size_factor = hidden_size_factor - self.scale = 0.02 - - self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) - self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) - self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) - self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) - - def forward(self, x): - bias = x - - dtype = x.dtype - x = x.float() - B, H, W, C = x.shape - - x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") - x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) - - o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) - o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) - o2_real = torch.zeros(x.shape, device=x.device) - o2_imag = torch.zeros(x.shape, device=x.device) - - total_modes = H // 2 + 1 - kept_modes = int(total_modes * self.hard_thresholding_fraction) - - o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( - torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ - torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ - self.b1[0] - ) - - o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( - torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ - torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ - self.b1[1] - ) - - o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( - torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ - torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ - self.b2[0] - ) - - o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( - torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ - torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ - self.b2[1] - ) - - x = torch.stack([o2_real, o2_imag], dim=-1) - x = F.softshrink(x, lambd=self.sparsity_threshold) - x = torch.view_as_complex(x) - x = x.reshape(B, H, W // 2 + 1, C) - x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho") - x = x.type(dtype) - - return x + bias - - -class Block(nn.Module): - def __init__( - self, - dim, - mlp_ratio=4., - drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - double_skip=True, - num_blocks=8, - sparsity_threshold=0.01, - hard_thresholding_fraction=1.0 - ): - super().__init__() - self.norm1 = norm_layer(dim) - self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - self.double_skip = double_skip - - def forward(self, x): - residual = x - x = self.norm1(x) - x = self.filter(x) - - if self.double_skip: - x = x + residual - residual = x - - x = self.norm2(x) - x = self.mlp(x) - x = x + residual - return x - - - -class AFNO3D(nn.Module): - def __init__( - self, hidden_size, num_blocks=8, sparsity_threshold=0.01, - hard_thresholding_fraction=1, hidden_size_factor=1 - ): - super().__init__() - assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" - - self.hidden_size = hidden_size - self.sparsity_threshold = sparsity_threshold - self.num_blocks = num_blocks - self.block_size = self.hidden_size // self.num_blocks - self.hard_thresholding_fraction = hard_thresholding_fraction - self.hidden_size_factor = hidden_size_factor - self.scale = 0.02 - - self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) - self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) - self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) - self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) - - def forward(self, x): - bias = x - - dtype = x.dtype - x = x.float() - B, D, H, W, C = x.shape - - x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho") - x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size) - - o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) - o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) - o2_real = torch.zeros(x.shape, device=x.device) - o2_imag = torch.zeros(x.shape, device=x.device) - - total_modes = H // 2 + 1 - kept_modes = int(total_modes * self.hard_thresholding_fraction) - - o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( - torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ - torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ - self.b1[0] - ) - - o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( - torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ - torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ - self.b1[1] - ) - - o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( - torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ - torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ - self.b2[0] - ) - - o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( - torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ - torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ - self.b2[1] - ) - - x = torch.stack([o2_real, o2_imag], dim=-1) - x = F.softshrink(x, lambd=self.sparsity_threshold) - x = torch.view_as_complex(x) - x = x.reshape(B, D, H, W // 2 + 1, C) - x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho") - x = x.type(dtype) - - return x + bias - - -class AFNOBlock3d(nn.Module): - def __init__( - self, - dim, - mlp_ratio=4., - drop=0., - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - double_skip=True, - num_blocks=8, - sparsity_threshold=0.01, - hard_thresholding_fraction=1.0, - data_format="channels_last", - mlp_out_features=None, - ): - super().__init__() - self.norm_layer = norm_layer - self.norm1 = norm_layer(dim) - self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, - hard_thresholding_fraction) - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, out_features=mlp_out_features, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop - ) - self.double_skip = double_skip - self.channels_first = (data_format == "channels_first") - - def forward(self, x): - if self.channels_first: - # AFNO natively uses a channels-last data format - x = x.permute(0,2,3,4,1) - - residual = x - x = self.norm1(x) - x = self.filter(x) - - if self.double_skip: - x = x + residual - residual = x - - x = self.norm2(x) - x = self.mlp(x) - x = x + residual - - if self.channels_first: - x = x.permute(0,4,1,2,3) - - return x - - -class PatchEmbed3d(nn.Module): - def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256): - super().__init__() - self.patch_size = patch_size - self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - x = self.proj(x) - x = x.permute(0,2,3,4,1) # convert to BHWC - return x - - -class PatchExpand3d(nn.Module): - def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256): - super().__init__() - self.patch_size = patch_size - self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size)) - - def forward(self, x): - x = self.proj(x) - x = rearrange( - x, - "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)", - p0=self.patch_size[0], - p1=self.patch_size[1], - p2=self.patch_size[2], - d=x.shape[1], - h=x.shape[2], - w=x.shape[3], - ) - return x - - -class AFNOCrossAttentionBlock3d(nn.Module): - """ AFNO 3D Block with channel mixing from two sources, used (only?) in the Unet denoiser - """ - def __init__( - self, - dim, - context_dim, - mlp_ratio=2., - drop=0., - act_layer=nn.GELU, - norm_layer=nn.Identity, - double_skip=True, - num_blocks=8, - sparsity_threshold=0.01, - hard_thresholding_fraction=1.0, - data_format="channels_last", - timesteps=None - ): - super().__init__() - - self.norm1 = norm_layer(dim) - self.norm2 = norm_layer(dim+context_dim) - mlp_hidden_dim = int((dim+context_dim) * mlp_ratio) - self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim) - self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold, - hard_thresholding_fraction) - self.mlp = Mlp( - in_features=dim+context_dim, - out_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, drop=drop - ) - self.channels_first = (data_format == "channels_first") - - def forward(self, x, y): - if self.channels_first: - # AFNO natively uses a channels-last order - x = x.permute(0,2,3,4,1) - y = y.permute(0,2,3,4,1) - - xy = torch.concat((self.norm1(x),y), axis=-1) - xy = self.pre_proj(xy) + xy - xy = self.filter(self.norm2(xy)) + xy # AFNO filter - x = self.mlp(xy) + x # feed-forward - - if self.channels_first: - x = x.permute(0,4,1,2,3) - - return x \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py deleted file mode 100644 index b8b3149..0000000 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/attention-checkpoint.py +++ /dev/null @@ -1,106 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/attention.py - -import math - -import torch -from torch import nn -import torch.nn.functional as F - - -class TemporalAttention(nn.Module): - def __init__( - self, channels, context_channels=None, - head_dim=32, num_heads=8 - ): - super().__init__() - self.channels = channels - if context_channels is None: - context_channels = channels - self.context_channels = context_channels - self.head_dim = head_dim - self.num_heads = num_heads - self.inner_dim = head_dim * num_heads - self.attn_scale = self.head_dim ** -0.5 - if channels % num_heads: - raise ValueError("channels must be divisible by num_heads") - self.KV = nn.Linear(context_channels, self.inner_dim*2) - self.Q = nn.Linear(channels, self.inner_dim) - self.proj = nn.Linear(self.inner_dim, channels) - - def forward(self, x, y=None): - if y is None: - y = x - - (K,V) = self.KV(y).chunk(2, dim=-1) - (B, Dk, H, W, C) = K.shape - shape = (B, Dk, H, W, self.num_heads, self.head_dim) - K = K.reshape(shape) - V = V.reshape(shape) - - Q = self.Q(x) - (B, Dq, H, W, C) = Q.shape - shape = (B, Dq, H, W, self.num_heads, self.head_dim) - Q = Q.reshape(shape) - - K = K.permute((0,2,3,4,5,1)) # K^T - V = V.permute((0,2,3,4,1,5)) - Q = Q.permute((0,2,3,4,1,5)) - - attn = torch.matmul(Q, K) * self.attn_scale - attn = F.softmax(attn, dim=-1) - y = torch.matmul(attn, V) - y = y.permute((0,4,1,2,3,5)) - y = y.reshape((B,Dq,H,W,C)) - y = self.proj(y) - return y - - -class TemporalTransformer(nn.Module): - def __init__(self, - channels, - mlp_dim_mul=1, - **kwargs - ): - super().__init__() - self.attn1 = TemporalAttention(channels, **kwargs) - self.attn2 = TemporalAttention(channels, **kwargs) - self.norm1 = nn.LayerNorm(channels) - self.norm2 = nn.LayerNorm(channels) - self.norm3 = nn.LayerNorm(channels) - self.mlp = MLP(channels, dim_mul=mlp_dim_mul) - - def forward(self, x, y): - x = self.attn1(self.norm1(x)) + x # self attention - x = self.attn2(self.norm2(x), y) + x # cross attention - return self.mlp(self.norm3(x)) + x # feed-forward - - -class MLP(nn.Sequential): - def __init__(self, dim, dim_mul=4): - inner_dim = dim * dim_mul - sequence = [ - nn.Linear(dim, inner_dim), - nn.SiLU(), - nn.Linear(inner_dim, dim) - ] - super().__init__(*sequence) - - -def positional_encoding(position, dims, add_dims=()): - div_term = torch.exp( - torch.arange(0, dims, 2, device=position.device) * - (-math.log(10000.0) / dims) - ) - if position.ndim == 1: - arg = position[:,None] * div_term[None,:] - else: - arg = position[:,:,None] * div_term[None,None,:] - - pos_enc = torch.concat( - [torch.sin(arg), torch.cos(arg)], - dim=-1 - ) - if add_dims: - for dim in add_dims: - pos_enc = pos_enc.unsqueeze(dim) - return pos_enc \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py b/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py deleted file mode 100644 index 983092d..0000000 --- a/src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/resnet-checkpoint.py +++ /dev/null @@ -1,91 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/resnet.py - -from torch import nn -from torch.nn.utils.parametrizations import spectral_norm as sn - -from ..utils import activation, normalization - - -class ResBlock3D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - resample=None, - resample_factor=(1, 1, 1), - kernel_size=(3, 3, 3), - act="swish", - norm="group", - norm_kwargs=None, - spectral_norm=False, - **kwargs, - ): - super().__init__(**kwargs) - if in_channels != out_channels: - self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1) - else: - self.proj = nn.Identity() - - padding = tuple(k // 2 for k in kernel_size) - if resample == "down": - self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True) - self.conv1 = nn.Conv3d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=resample_factor, - padding=padding, - ) - self.conv2 = nn.Conv3d( - out_channels, out_channels, kernel_size=kernel_size, padding=padding - ) - elif resample == "up": - self.resample = nn.Upsample(scale_factor=resample_factor, mode="trilinear") - self.conv1 = nn.ConvTranspose3d( - in_channels, out_channels, kernel_size=kernel_size, padding=padding - ) - output_padding = tuple( - 2 * p + s - k - for (p, s, k) in zip(padding, resample_factor, kernel_size) - ) - self.conv2 = nn.ConvTranspose3d( - out_channels, - out_channels, - kernel_size=kernel_size, - stride=resample_factor, - padding=padding, - output_padding=output_padding, - ) - else: - self.resample = nn.Identity() - self.conv1 = nn.Conv3d( - in_channels, out_channels, kernel_size=kernel_size, padding=padding - ) - self.conv2 = nn.Conv3d( - out_channels, out_channels, kernel_size=kernel_size, padding=padding - ) - - if isinstance(act, str): - act = (act, act) - self.act1 = activation(act_type=act[0]) - self.act2 = activation(act_type=act[1]) - - if norm_kwargs is None: - norm_kwargs = {} - self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) - self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) - if spectral_norm: - self.conv1 = sn(self.conv1) - self.conv2 = sn(self.conv2) - if not isinstance(self.proj, nn.Identity): - self.proj = sn(self.proj) - - def forward(self, x): - x_in = self.resample(self.proj(x)) - x = self.norm1(x) - x = self.act1(x) - x = self.conv1(x) - x = self.norm2(x) - x = self.act2(x) - x = self.conv2(x) - return x + x_in \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py deleted file mode 100644 index f63d2d4..0000000 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/unet-checkpoint.py +++ /dev/null @@ -1,492 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/unet.py: weirdly, this part (which is the denoiser) was in not in the diffusion folder in the original code - -from abc import abstractmethod -from functools import partial -import math -from typing import Iterable - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F -import pytorch_lightning as pl - -from .utils import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from ..blocks.afno import AFNOCrossAttentionBlock3d -SpatialTransformer = type(None) -#from ldm.modules.attention import SpatialTransformer - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, AFNOCrossAttentionBlock3d): - img_shape = tuple(x.shape[-2:]) - x = layer(x, context[img_shape]) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - - """ - - def __init__( - self, - model_channels, - in_channels=1, - out_channels=1, - num_res_blocks=2, - attention_resolutions=(1,2,4), - context_ch=128, - dropout=0, - channel_mult=(1, 2, 4, 4), - conv_resample=True, - dims=3, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - legacy=True, - num_timesteps=1 - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' - - if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - timesteps = th.arange(1, num_timesteps+1) - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - dim_head = num_head_channels - layers.append( - AFNOCrossAttentionBlock3d( - ch, context_dim=context_ch[level], num_blocks=num_heads, - data_format="channels_first", timesteps=timesteps - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - dim_head = num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AFNOCrossAttentionBlock3d( - ch, context_dim=context_ch[-1], num_blocks=num_heads, - data_format="channels_first", timesteps=timesteps - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = num_head_channels - layers.append( - AFNOCrossAttentionBlock3d( - ch, context_dim=context_ch[level], num_blocks=num_heads, - data_format="channels_first", timesteps=timesteps - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - - def forward(self, x, timesteps=None, context=None): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :return: an [N x C x ...] Tensor of outputs. - """ - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - return self.out(h) - diff --git a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py b/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index e908cd1..0000000 --- a/src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,249 +0,0 @@ -# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/utils.py - -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! - -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat - - -def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if schedule == "linear": - betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) - else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps].cpu() - numpied = [alphacums[0].cpu()] + alphacums[ddim_timesteps[:-1]].cpu().tolist() - alphas_prev = np.asarray(numpied) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) - if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return nn.Identity() #GroupNorm32(32, channels) - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") \ No newline at end of file From db6a2a9466bb54f1b0dab5dd7bccdce94a4df038 Mon Sep 17 00:00:00 2001 From: Martin Bonte Date: Thu, 19 Mar 2026 14:34:13 +0100 Subject: [PATCH 13/13] I have essentially reorganized the code so that the nets can be trained in parallel on multiple GPUs. The main thing for that is that the LatentDiffusion class has to have the autoencoder as an attribute so that Lightning creates one instance of the autoencoder and one instance of the ldm on each GPU (in DDP strategy). --- config.yaml | 58 ++++++++++--------- src/mlcast/models/base.py | 2 +- src/mlcast/models/ldcast/autoenc/autoenc.py | 21 +++++-- src/mlcast/models/ldcast/data.py | 26 ++++----- .../models/ldcast/diffusion/diffusion.py | 26 +++++++-- src/mlcast/models/ldcast/diffusion/ema.py | 20 +++---- src/mlcast/models/ldcast/ldcast.py | 15 +++-- .../models/ldcast/transforms/antialiasing.py | 16 +++-- 8 files changed, 103 insertions(+), 81 deletions(-) diff --git a/config.yaml b/config.yaml index 17ee1c9..9e6eba5 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,3 @@ -# everything is from the original config, except for the batch size - -device: &device 'cuda' - model: autoencoder: optimizer_class: "${as_class: 'torch.optim.AdamW'}" @@ -14,25 +10,33 @@ model: kwargs: patience: 3 factor: 0.25 - verbose: True extra: monitor: 'val/rec_loss' frequency: 1 interval: 'epoch' antialiaser: use: True - kwargs: - device: *device + kwargs: {} encoder: {} decoder: {} net_kwargs: hidden_width: &autoencoder_hidden_width 32 loss: kl_weight: 0.01 - device: *device + trainer: + max_epochs: 200 + accelerator: 'gpu' + log_every_n_steps: 5 + callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/loss_epoch\", patience=6, verbose=True, check_finite=False)]'}" + strategy: 'ddp' + num_nodes: 1 + sync_batchnorm: True + dataloader: + batch_size: 1 + num_workers: 0 + persistent_workers: False ldm: - device: *device conditioner: autoencoder_dim: *autoencoder_hidden_width output_patches: &output_patches 5 @@ -53,7 +57,7 @@ model: ema: use: True kwargs: - store_device: 'cpu' + store_device: 'cuda' optimizer_class: "${as_class: 'torch.optim.AdamW'}" optimizer_kwargs: lr: 0.0001 @@ -64,29 +68,27 @@ model: kwargs: patience: 3 factor: 0.25 - verbose: True extra: - monitor: 'val/ema_loss' + monitor: 'val/loss' # is actually the ema loss, since the ema weights are used for validation frequency: 1 interval: 'epoch' scheduler: {} # diffusion scheduler - -dataloader: - batch_size: 32 - num_workers: 4 - persistent_workers: True - -trainer: - max_epochs: 200 - accelerator: 'gpu' - log_every_n_steps: 5 - callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/rec_loss\", patience=6, verbose=True, check_finite=False)]'}" + trainer: + max_epochs: 200 + accelerator: 'gpu' + log_every_n_steps: 5 + callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/loss_epoch\", patience=6, verbose=True, check_finite=False)]'}" + strategy: 'ddp' + num_nodes: 1 + sync_batchnorm: True + dataloader: + batch_size: 1 + num_workers: 0 + persistent_workers: False sampled_radar_dataset: - zarr_path: 'test.zarr' - csv_path: 'mlcast-dataset-sampler/sampled_datacubes_2017-01-01-2017-02-01_24x256x256_3x16x16_10000.csv' + zarr_path: '/scratch/martinbo/MLCast/radklim.zarr' + csv_path: '/scratch/martinbo/MLCast/LDCastTraining/indexes_radklim/sampled_datacubes_2001-01-01-2001-01-01_24x256x256_3x16x16_1500000.csv' steps: 24 augment: False - data_var: 'precip_intensity_EDK' - - + data_var: 'RR' \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index a4fe87b..5d2a832 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -160,7 +160,7 @@ def print_log_loss(self, loss, step_name): if isinstance(loss, dict): # append step name to loss keys for logging loss = {f"{step_name}/{k}": v for k, v in loss.items()} - self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) loss = loss.get(f"{step_name}/total_loss", None) if loss is None: raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py index d14a244..d42df86 100644 --- a/src/mlcast/models/ldcast/autoenc/autoenc.py +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -30,12 +30,27 @@ def forward(self, predictions, y): class Autoencoder(NowcastingLightningModule): def __init__(self, net, loss, antialiaser = Antialiaser(), **kwargs): super().__init__(net, loss, **kwargs) + self.save_hyperparameters(ignore=['net', 'loss', 'antialiaser']) self.antialiaser = antialiaser def forward(self, x): if self.antialiaser is not None: x = self.antialiaser(x) return self.net(x) + + def encode(self, x): + if self.antialiaser is not None: + x = self.antialiaser(x) + return self.net.encode(x) + + def training_logic(self, batch, batch_idx): + x, y = batch + predictions = self.forward(x) + if self.antialiaser is not None: + y = self.antialiser(y) + loss = self.loss(predictions, y) + + return loss @classmethod def from_config(cls, config): @@ -51,7 +66,7 @@ def from_config(cls, config): optimizer_class = config['optimizer_class'], optimizer_kwargs = config['optimizer_kwargs'], lr_scheduler_config = config['lr_scheduler'] - ).to(config['device']) + ) class AutoencoderKLNet(nn.Module): def __init__( @@ -68,7 +83,6 @@ def __init__( self.hidden_width = hidden_width self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) - self.log_var = nn.Parameter(torch.zeros(size=())) def encode(self, x, return_log_var = False): if len(x.shape) < 5: @@ -79,9 +93,6 @@ def encode(self, x, return_log_var = False): if return_log_var: return (mean, log_var) else: - # if the first axis has length 1, it is the batch dimension and should be removed - if mean.shape[0] == 1: - mean = mean[0] return mean def decode(self, z): diff --git a/src/mlcast/models/ldcast/data.py b/src/mlcast/models/ldcast/data.py index a2a08b8..9ea69a3 100644 --- a/src/mlcast/models/ldcast/data.py +++ b/src/mlcast/models/ldcast/data.py @@ -7,10 +7,10 @@ class LatentDataset(Dataset): def __init__(self, sampled_radar_dataset, autoencoder, autoenc_time_ratio = 4): super().__init__() - self.autoencoder = autoencoder + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.autoencoder = autoencoder.to(self.device) self.dataset = sampled_radar_dataset self.autoenc_time_ratio = autoenc_time_ratio - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' def __len__(self): return len(self.dataset) @@ -19,17 +19,17 @@ def __getitem__(self, idx): with torch.no_grad(): sequence = self.dataset[idx] - x = sequence[:self.autoenc_time_ratio] - y = sequence[self.autoenc_time_ratio:] + x = sequence[:, :self.autoenc_time_ratio] + y = sequence[:, self.autoenc_time_ratio:] # for some reason, Gabriele put the time axis before the channel axis, change this - x = x.swapaxes(0, 1).to(self.device) - y = y.swapaxes(0, 1).to(self.device) - - latent_x = self.autoencoder.encode(x) - latent_y = self.autoencoder.encode(y) + #x = x.swapaxes(0, 1).to(self.device) + #y = y.swapaxes(0, 1).to(self.device) - return latent_x, latent_y + #latent_x = self.autoencoder.encode(x) + #latent_y = self.autoencoder.encode(y) + + return x, y #latent_x, latent_y class AutoencoderDataset(Dataset): ''' @@ -42,8 +42,6 @@ def __init__(self, sampled_radar_dataset, autoenc_time_ratio = 4): self.srd = sampled_radar_dataset self.autoenc_time_ratio = autoenc_time_ratio self.samples_ratio = int(self.srd.steps / self.autoenc_time_ratio) # is 6 in the usual case where steps = 24 and autoenc_time_ratio = 4 - - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' def __len__(self): return self.samples_ratio * len(self.srd) @@ -55,10 +53,10 @@ def __getitem__(self, idx): ''' index_srd = idx // self.samples_ratio index_in_srd_sample = idx - index_srd * self.samples_ratio - x = self.srd[index_srd].reshape(self.samples_ratio, self.autoenc_time_ratio, 1, self.srd.w, self.srd.h)[index_in_srd_sample] + x = self.srd[index_srd].reshape(self.samples_ratio, 1, self.autoenc_time_ratio, self.srd.w, self.srd.h)[index_in_srd_sample] # for some reason, Gabriele put the time axis before the channel axis, change this - x = x.swapaxes(0, 1).to(self.device) + # x = x.swapaxes(0, 1) # for the autoencoder, y is equal to x y = x diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py index 328480f..2e33b47 100644 --- a/src/mlcast/models/ldcast/diffusion/diffusion.py +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -26,11 +26,17 @@ def forward(self, x): class LatentDiffusion(NowcastingLightningModule): - def __init__(self, net, loss, scheduler, ema_config = {'use': True}, **kwargs): + def __init__(self, net, loss, scheduler, autoencoder, ema_config = {'use': True}, **kwargs): super().__init__(net, loss, **kwargs) + self.save_hyperparameters(ignore=['net', 'loss', 'autoencoder']) self.scheduler = scheduler + self.autoencoder = autoencoder # the LatentDiffusion class needs the autoencoder so that, in a parallel training setup, Lightning creates one instance of the autoencoder on each GPU - # register the schedules (i.e. the values of alpha, beta etc). + # Freeze autoencoder + for param in self.autoencoder.parameters(): + param.requires_grad = False + + # register the schedules (i.e. the values of alpha, beta etc) self.register_schedule() if ema_config['use']: @@ -52,7 +58,11 @@ def register_schedule(self): self.net.denoiser.register_buffer(k, schedule[k]) def training_logic(self, batch, batch_idx): - latent_inputs, latent_true = batch + + inputs, true = batch + with torch.no_grad(): + latent_inputs = self.autoencoder.encode(inputs) + latent_true = self.autoencoder.encode(true) x0 = latent_true t, noise, x_t = self.q_sample(x0) @@ -78,6 +88,10 @@ def q_sample(self, x0, noise = None, t = None): return t, noise, x_noisy + def on_fit_start(self): + if hasattr(self, 'ema'): + self.ema.register() + def on_train_batch_end(self, outputs, batch, batch_idx): if hasattr(self, 'ema'): self.ema.update() @@ -107,7 +121,7 @@ def on_predict_end(self): self.ema.restore() @classmethod - def from_config(cls, config): + def from_config(cls, config, autoencoder): from .scheduler import Scheduler from .unet import UNetModel @@ -119,10 +133,10 @@ def from_config(cls, config): loss = nn.MSELoss() scheduler = Scheduler(**config['scheduler']) - return cls(net, loss, scheduler, + return cls(net, loss, scheduler, autoencoder, ema_config = config['ema'], optimizer_class = config['optimizer_class'], optimizer_kwargs = config['optimizer_kwargs'], lr_scheduler_config = config['lr_scheduler'] - ).to(config['device']) + ) diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py index af13022..fd7e3d4 100644 --- a/src/mlcast/models/ldcast/diffusion/ema.py +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -17,17 +17,15 @@ def __init__(self, model, decay = 0.9999, use_num_updates = True, store_device = self.shadow = {} # to store EMA weights self.backup = {} # to store the model weights when we replace them by ema weights self.num_updates = 0 if use_num_updates else -1 # for dynamical decay - self.store_device = store_device # device on which to store the weights - self.model_device = next(model.parameters()).device # device on which the weights are used - - self.register() - + self.store_device = store_device # device on which to store the weights + def register(self): '''initialize the ema weights with the model weights''' + print(next(self.model.parameters()).device) for name, param in self.model.named_parameters(): if param.requires_grad: - self.shadow[name] = param.data.clone().detach().to(self.store_device) - + self.shadow[name] = param.data.detach().to(self.store_device) + def update(self): '''update the shadow parameters''' @@ -44,16 +42,18 @@ def update(self): def apply_shadow(self): '''apply shadow (EMA) weights to the model''' + model_device = next(self.model.parameters()).device for name, param in self.model.named_parameters(): if param.requires_grad: - self.backup[name] = param.data.clone().detach().to(self.store_device) - param.data = self.shadow[name].to(self.model_device) + self.backup[name] = param.data.detach().to(self.store_device) + param.data = self.shadow[name].to(model_device) def restore(self): '''restore original model weights from backup''' + model_device = next(self.model.parameters()).device for name, param in self.model.named_parameters(): if param.requires_grad: - param.data = self.backup[name].to(self.model_device) + param.data = self.backup[name].to(model_device) def load(self, filename): '''load the ema (shadow) weights parameters''' diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py index bfb7b9e..ed94d42 100644 --- a/src/mlcast/models/ldcast/ldcast.py +++ b/src/mlcast/models/ldcast/ldcast.py @@ -6,7 +6,7 @@ import torch import contextlib -torch.multiprocessing.set_start_method('spawn') +#torch.multiprocessing.set_start_method('spawn') class LDCast(NowcastingModelBase): def __init__(self, ldm, autoencoder, sampler): @@ -27,16 +27,16 @@ def fit(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {} self.fit_ldm(sampled_radar_dataset, dataloader_kwargs = dataloader_kwargs, trainer_kwargs = trainer_kwargs) def fit_ldm(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): + + #assert False, 'need to add a trainer instance in the LatentDataset class to automatically move the autoencoder to cuda etc.' self.autoencoder.net.eval() - self.ldm.net.train() - dataset = LatentDataset(sampled_radar_dataset, self.autoencoder.net) + dataset = LatentDataset(sampled_radar_dataset, self.autoencoder) datamodule = DataModule(dataset, **dataloader_kwargs) trainer = L.Trainer(**trainer_kwargs) trainer.fit(self.ldm, datamodule) def fit_autoencoder(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): - self.autoencoder.net.train() dataset = AutoencoderDataset(sampled_radar_dataset) datamodule = DataModule(dataset, **dataloader_kwargs) @@ -48,12 +48,13 @@ def predict(self, inputs, num_diffusion_iters = 50, verbose = True): assert False, 'prediction should be implemented with a trainer, to take into account the switches of ema weights for example''' - latent_inputs = self.autoencoder.net.encode(inputs) + latent_inputs = self.autoencoder.encode(inputs) condition = self.ldm.net.conditioner(latent_inputs) gen_shape = (32, 5, 256//4, 256//4) batch_size = len(latent_inputs) + # this could also be put in the LatentDiffusion class, by overriding the predict_step method (https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#inference) with contextlib.redirect_stdout(None): (s, intermediates) = self.sampler.sample( num_diffusion_iters, @@ -86,14 +87,12 @@ def load(self, folder): @classmethod def from_config(cls, config): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - from .autoenc.autoenc import Autoencoder from .diffusion.diffusion import LatentDiffusion from .diffusion.plms import PLMSSampler autoencoder = Autoencoder.from_config(config['autoencoder']) - ldm = LatentDiffusion.from_config(config['ldm']) + ldm = LatentDiffusion.from_config(config['ldm'], autoencoder) sampler = PLMSSampler(ldm.net.denoiser) return cls(ldm, autoencoder, sampler) diff --git a/src/mlcast/models/ldcast/transforms/antialiasing.py b/src/mlcast/models/ldcast/transforms/antialiasing.py index bf0a70c..c6ab4c9 100644 --- a/src/mlcast/models/ldcast/transforms/antialiasing.py +++ b/src/mlcast/models/ldcast/transforms/antialiasing.py @@ -1,16 +1,15 @@ import torch import torch.nn as nn -class Antialiaser: - def __init__(self, device = 'cpu'): +class Antialiaser(nn.ModuleDict): + def __init__(self): - self.device = device + super().__init__() # construct the kernel (symmetric in both directions), shape = (5, 5) (x, y) = torch.meshgrid(torch.arange(-2, 3), torch.arange(-2, 3), indexing = 'ij') kernel = torch.exp(-0.5*(x**2+y**2)/(0.5**2)) kernel /= kernel.sum() - kernel = kernel.to(self.device) # the convolution will be done on x (shape = (1, autoenc_time_ratio) + spatial_shape) # so treat the autoenc_time_ratio as one axis of the convolution -> Conv3d @@ -19,12 +18,11 @@ def __init__(self, device = 'cpu'): # set the weights to be those of the kernel self.conv.weight = nn.Parameter(kernel[None, None, None], requires_grad = False) - - def __call__(self, x): + + def forward(self, x): # factor is 1 in the bulk of x, but is greater than 1 near the border to accounr that less values were used in the convolution # recomputed each time because the image shape could change - factor = self.conv(torch.ones(x.shape, device = self.device)) - factor = 1./factor + factor = self.conv(torch.ones(x.shape, device = x.device)) - return self.conv(x) * factor \ No newline at end of file + return self.conv(x) / factor \ No newline at end of file