diff --git a/.gitignore b/.gitignore index 9853d60..97c46a7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,10 @@ __pycache__/ build/ dist/ .venv/ +.ipynb_checkpoints/ +src/mlcast/modules/.ipynb_checkpoints/ +src/mlcast/models/ldcast/context/.ipynb_checkpoints/ +src/mlcast/models/ldcast/diffusion/.ipynb_checkpoints/ +src/mlcast/models/ldcast/.ipynb_checkpoints/ +src/mlcast/models/ldcast/autoenc/.ipynb_checkpoints/ +src/mlcast/models/ldcast/blocks/.ipynb_checkpoints/ diff --git a/README.md b/README.md index d6e9b47..9b3ef6d 100644 --- a/README.md +++ b/README.md @@ -1,83 +1,42 @@ -# mlcast +# MLCast implementation of LDCast - +see main branch https://github.com/mlcast-community/mlcast for context. -The MLCast Community is a collaborative effort bringing together meteorological services, research institutions, and academia across Europe to develop a unified Python package for AI-based nowcasting. This is an initiative of the E-AI WG6 (Nowcasting) of EUMETNET. +## Code structure -This repo contains the `mlcast` package for machine learning-based weather nowcasting. +There is one main `LDCast` class, subclassing the `NowcastingModelBase` class. There are three main nets in LDCast: + - the autoencoder + - the conditioner + - the denoiser -## Project Status +The `NowcastingLightningModule` is subclassed by the smaller composites of nets that should be trained at once. This gives two subclasses in this case: + - the autoencoder (encoder + decoder) has to be trained on its own, so there is one subclass of `NowcastingLightningModule` called `Autoencoder` + - the conditioner and the denoiser have to be trained together, so they are combined into one neural network (the `LatentDiffusionNet` class), whose training is handled by the `LatentDiffusion` subclass of the `NowcastingLightningModule` -⚠️ **Under Development** - This package is currently in early development stages and not usable by end users. The API and functionality are subject to change. +## Documentation -## Installation -```bash -# Install from pypi -pip install mlcast -``` +See `docs` folder for some documenation on the main `LDCast` class, on the autoencoder and on the latent diffusion part. -or -```bash -# Install from source -git clone https://github.com/mlcast-community/mlcast -cd mlcast -uv pip install -e . +## TO DO -# For development -uv pip install -e ".[dev]" -``` +reorganize the `LatentDiffusion` class ? for the moment, `LatentDiffusionNet.forward` is never called during inference because the inference process is quite different than in training (see `docs/ldm.md). It might be maybe a bit clearer to reorganize that by implementing explicitly different training and inference step methods in the `LatentDiffusion` class (that being said, `AutoencoderKLNet.forward` is never called either during inference) -## Project Structure +The 'timesteps' variable sometimes refers to the timesteps of the diffusion process (= 1000 during training) and sometimes refers to the nowcasting timesteps (where each time step = 5 minutes). Better to have different names. -``` -mlcast/ -├── src/mlcast/ # Main package source code -│ ├── __init__.py # Package initialization and version -│ ├── data/ # Data loading and preprocessing -│ │ ├── zarr_datamodule.py # PyTorch Lightning data module for Zarr -│ │ └── zarr_dataset.py # PyTorch dataset for Zarr arrays -│ ├── models/ # Lightning model implementations -│ │ └── base.py # Abstract base classes for nowcasting models -│ └── modules/ # Pure PyTorch neural network modules -│ └── convgru_modules.py # ConvGRU encoder-decoder modules -├── examples/ # Example scripts and notebooks -│ └── scripts/ -│ └── simple_train.py # Basic training example -├── pyproject.toml # Project metadata and dependencies -├── LICENSE # Apache 2.0 license -└── README.md # This file -``` +We might integrate this code within the Hugging Face Diffusers Library. -## Development +It remains mainly to write code in the main LDCast class (in `ldcast.py`) -This project uses `uv` for dependency management. To set up the development environment: +It would be nice to rewrite the PLMS sampler, it is a little messy -```bash -# Install uv if not already installed -curl -LsSf https://astral.sh/uv/install.sh | sh +implement different parametrization than 'eps' -# Install dependencies -uv sync +use ZarrDataModule and ZarrDataset ! -# Run pre-commit hooks -uv run pre-commit install -``` +add the computation of the EMA loss during the ldm training, change the LDCast.predict method so that EMA weights are automatically used during inference -## Contributing +add in the code (and in the doc) the input and output shapes of the nets -Please feel free to raise issues or PRs if you have any suggestions or questions. +understand which parameters can be changed, which have to be adapted when others change -## Links to presentations for discussion about the API - -- [2025/02/04 first design discussions](https://docs.google.com/presentation/d/1oWmnyxOfUMWgeQi0XyX4fX9YDMX1vl6h/edit?usp=drive_link&rtpof=true&sd=true) - -## License - -This project is dual-licensed under either: - -* Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) -* BSD 3-Clause License ([LICENSE-BSD](LICENSE-BSD) or https://opensource.org/licenses/BSD-3-Clause) - -at your option. - -See [LICENSE](LICENSE) for more details. +make the implementation of the `AutoencoderDataset` more efficient ? (see docs/autoencoder) \ No newline at end of file diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..9e6eba5 --- /dev/null +++ b/config.yaml @@ -0,0 +1,94 @@ +model: + autoencoder: + optimizer_class: "${as_class: 'torch.optim.AdamW'}" + optimizer_kwargs: + lr: 0.001 + betas: [0.5, 0.9] + weight_decay: 0.001 + lr_scheduler: + class: "${as_class: 'torch.optim.lr_scheduler.ReduceLROnPlateau'}" + kwargs: + patience: 3 + factor: 0.25 + extra: + monitor: 'val/rec_loss' + frequency: 1 + interval: 'epoch' + antialiaser: + use: True + kwargs: {} + encoder: {} + decoder: {} + net_kwargs: + hidden_width: &autoencoder_hidden_width 32 + loss: + kl_weight: 0.01 + trainer: + max_epochs: 200 + accelerator: 'gpu' + log_every_n_steps: 5 + callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/loss_epoch\", patience=6, verbose=True, check_finite=False)]'}" + strategy: 'ddp' + num_nodes: 1 + sync_batchnorm: True + dataloader: + batch_size: 1 + num_workers: 0 + persistent_workers: False + + ldm: + conditioner: + autoencoder_dim: *autoencoder_hidden_width + output_patches: &output_patches 5 + cascade_depth: 3 + embed_dim: 128 + analysis_depth: 4 + denoiser: + in_channels: *autoencoder_hidden_width + model_channels: 256 + out_channels: *autoencoder_hidden_width + num_res_blocks: 2 + attention_resolutions: [1, 2] + dims: 3 + channel_mult: [1, 2, 4] + num_heads: 8 + num_timesteps: *output_patches + context_ch: [128, 256, 512] # should be equal to conditioner.cascade_dims ? + ema: + use: True + kwargs: + store_device: 'cuda' + optimizer_class: "${as_class: 'torch.optim.AdamW'}" + optimizer_kwargs: + lr: 0.0001 + betas: [0.5, 0.9] + weight_decay: 0.001 + lr_scheduler: + class: "${as_class: 'torch.optim.lr_scheduler.ReduceLROnPlateau'}" + kwargs: + patience: 3 + factor: 0.25 + extra: + monitor: 'val/loss' # is actually the ema loss, since the ema weights are used for validation + frequency: 1 + interval: 'epoch' + scheduler: {} # diffusion scheduler + trainer: + max_epochs: 200 + accelerator: 'gpu' + log_every_n_steps: 5 + callbacks: "${as_class: '[pl.callbacks.EarlyStopping(\"val/loss_epoch\", patience=6, verbose=True, check_finite=False)]'}" + strategy: 'ddp' + num_nodes: 1 + sync_batchnorm: True + dataloader: + batch_size: 1 + num_workers: 0 + persistent_workers: False + +sampled_radar_dataset: + zarr_path: '/scratch/martinbo/MLCast/radklim.zarr' + csv_path: '/scratch/martinbo/MLCast/LDCastTraining/indexes_radklim/sampled_datacubes_2001-01-01-2001-01-01_24x256x256_3x16x16_1500000.csv' + steps: 24 + augment: False + data_var: 'RR' \ No newline at end of file diff --git a/docs/autoencoder.md b/docs/autoencoder.md new file mode 100644 index 0000000..3f8da7d --- /dev/null +++ b/docs/autoencoder.md @@ -0,0 +1,80 @@ +# Autoencoder documentation + +1. [Autoencoder class](#autoencoder-class) +2. [Tensor shapes](#tensor-shapes) +3. [Encoding and decoding](#encoding-and-decoding) +4. [Loading original weights](#loading-original-weights) +5. [Antialiasing](#antialiasing) +6. [Autoencoder training dataset](#autoencoder-training-dataset) +7. [Background on variational autoencoders](#background-on-variational-autoencoders) + +## Autoencoder class + +The `Autoencoder` class is a subclass of `NowcastingLightningModule`, and takes three arguments: + - the `net` (an instance of `AutoencoderKLNet` for LDCast), which is the neural network of the autoencoder, containing the decoder and the autoencoder + - the `loss` (an instance of `AutoencoderLoss` for LDCast) +Options for the optimizer and the learning rate scheduler can be passed as well. + +An instance can be created from a `dict` containing the configuration, based on the architecture of LDCast's autoencoder: +```python +from mlcast.models.ldcast.autoencoder.autoencoder import Autoencoder +autoencoder = Autoencoder.from_config(config) +``` + +## Tensor shapes + +The autoencoder encodes sequences of radar images (not image by image). The number of radar images encoded at once is given by `autoenc_time_ratio` and was set to 4 in the original code (and kept here). `Conv3d` layers are used for the encoding, so input tensors have shape +``` +(batch_size, n_channels, autoenc_time_ratio,) + spatial shape +``` +`n_channels` is always 1 for radar images. + +In latent space, the tensors have shape `(batch_size, 32, n, 64, 64)`, where 32 is the `hidden_width` of the `autoencoder` and `n` is the number of consecutive encoded radar images divided by `autoenc_time_ratio`. **I should still clarify which of these parameters can be changed freely, and how it affects other shapes. Can `autoencoder.net` encode a e.g. 8 images at once (in which case `n` is 2) ?** + + +## Encoding and decoding + +Doing the following +```python +import torch +inputs = torch.randn(1, 1, 4, 256, 256, device = 'cuda') # fake sample +autoencoder(inputs) +``` +is equivalent to `autoencoder.net(inputs)` and computes the whole forward pass through the `net` (encoding + decoding). To encode only, one needs to do +```python +autoencoder.net.encode(inputs). +``` +If `encoded` is an encoded sample, it can be decoded as +```python +autoencoder.net.decode(encoded) +``` + +## Laoding original weights + +The original weights can be loaded directly as +```python +autoenc_weights_fn = '/path/to/original/autoencoder/weights' +autoencoder.net.load_state_dict(torch.load(autoenc_weights_fn)) +``` + +## Antialiasing + +As in the original code, antialiasing is applied by default (by an Antialiaser object) to the inputs before being fed to the `net`. + +## Autoencoder training dataset + +Gabriele's code produces a dataset whose samples are sequences of `steps` images (`steps` is usually set to 24, to have 4 input images and 20 ground truth images). + +But the autoencoder needs samples which are sequences of only 4 images, so each sample in `SampledRadarDataset` needs to be divided in 6 samples. This is done by the `AutoencoderDataset`. Its samples are tuple `(x, y)` where `y = x` since we want the autoencoder to reconstruct the sequences. + +**The current implementation of this class is not the most efficient since, when going through the `AutoencoderDataset`, each sample of the `SampledRadarDataset` is loaded 6 times.** + +## Background on variational autoencoders + +The autoencoder used in LDCast is a variational autoencoder. Here is some background on that kind of autoencoder. + +Source https://medium.com/@jpark7/finally-a-clear-derivation-of-the-vae-kl-loss-4cb38d2e47b3. + +Variational autoencoders encode the data through a normal distribution in latent space: each sample is represented by the mean and the standard deviation of the normal distribution. When decoding the sample, a new sample is created resembling the original sample, but is not quite the same. The degree to which we force the decoded samples to resemble the original ones is tuned by the `kl_weight` parameter of the KL loss function. + +When using the encoded sample (for example to produce a condition with the conditioner), only the mean is used. In the original code, `autoencoder.net.decode` was returning a tuple `(mean, log_var)`, so that one had to select the mean with `autoencoder.net.decode(x)[0]`, which is not very clear. I replaced this by adding a keyword `return_log_var` in `autoencoder.net.decode`. \ No newline at end of file diff --git a/docs/ldcast.md b/docs/ldcast.md new file mode 100644 index 0000000..a7554db --- /dev/null +++ b/docs/ldcast.md @@ -0,0 +1,57 @@ +# Main LDCast class documentation + +1. [LDCast class](#ldcast-class) +2. [Inference](#inference) +3. [Loading/saving weights](#loading/saving-weights) +4. [Training](#training) + +## LDCast class + +The `LDCast` class is a subclass of `NowcastingModelBase` and takes three arguments + - the `ldm` (typically, an instance of `LatentDiffusion`) + - the `autoencoder` (typically, an instance of `Autoencoder`) + - the `sampler` + +An instance can be created from a `dict` containing the configuration, based on the architecture of LDCast: +```python +from mlcast.models.ldcast.ldcast import LDCast +ldcast = LDCast.from_config(config) +``` +A config very close to what was used in the original code is in 'config.yaml'. It should be loaded as +```python +from omegaconf import OmageConf +OmegaConf.register_new_resolver("as_class", lambda class_name: eval(class_name)) +config = OmegaConf.load('config.yaml') +``` + +## Inference + +Predictions can be produced with +```python +import torch +inputs = torch.randn(1, 1, 4, 256, 256, device = 'cuda') # fake data +ldcast.predict(inputs) +``` +**Do not use for the moment, since the EMA weights (if used) are not automatically used for inference** + +## Loading/saving weights +To load from a folder containing in different files the weights of the autoencoder, of the denoiser and of the conditioner (and possibly ema weights): +```python +ldcast.load('/path/to/folder') +``` +To save in a folder: +```python +ldcast.save('/path/to/folder') +``` + +## Training + +If `sampled_radar_dataset` is a `SampledRadarDataset` built with Gabriele's code (https://github.com/DSIP-FBK/ConvGRU-Ensemble/blob/main/convgru_ensemble/datamodule.py), the autoencoder can be trained with +```python +ldcast.fit_autoencoder(sampled_radar_dataset) +``` +and the ldm can be trained +```python +ldcast.fit_ldm(sampled_radar_dataset) +``` +Keyword arguments can be passed to the trainer and the dataloader through the `trainer_kwargs` and `dataloader_kwargs` keywords. \ No newline at end of file diff --git a/docs/ldm.md b/docs/ldm.md new file mode 100644 index 0000000..18fab0f --- /dev/null +++ b/docs/ldm.md @@ -0,0 +1,106 @@ +# Latent diffusion documentation + +1. [LatentDiffusion class](#latentdiffusion-class) +2. [LatentDiffusionNet](#latentdiffusionnet) +3. [Training vs inference modes](#training-vs-inference-modes) +4. [Loading original weights](#loading-original-weights) +5. [Exponential Moving Average](#exponential-moving-average) +6. [LatentDataset](#latentdataset) +7. [Background on diffusion models](#background-on-diffusion-models) + +## LatentDiffusion class + +The `LatentDiffusion` class is a subclass of `NowcastingLightningModule` and takes three arguments + - the `net` (typically, an instance of `LatentDiffusionNet`) + - the `loss` (a `torch.nn.MSELoss` for LDCast) + - the `scheduler`, scheduling the diffusion process +Options for the optimizer and the learning rate scheduler can be passed as well. + +An instance can be created from a `dict` containing the configuration, based on the corresponding part of the architecture of LDCast: +```python +from mlcast.models.ldcast.diffusion.diffusion import LatentDiffusion +ldm = LatentDiffusion.from_config(config) +``` + +## LatentDiffusionNet + +The `LatentDiffusionNet` class combines two elements: the `conditioner` and the `denoiser`. + +The `denoiser` takes some noise and performs the backward diffusion process to produce samples (in latent space). Since we want a nowcast based on input images, the denoiser needs with some condition based on the input images. + +The role of the `conditioner` is to provide this condition to the denoiser. It takes input images (encoded in latent space) and returns a condition (also called context ?) to help the denoiser to produce relevant predictions. The `conditioner` could also be called a forecaster. + +In the original LDCast code, the `conditioner` was called `analysis_net` and the `denoiser` was called `denoiser` or `model`. + +As in the original code, the `conditioner` is an instance of `AFNONowcastNet` and the `denoiser` is an instance of `UNetModel`. + +**check the output shape of the conditioner and the input shape of the denoiser** + +## Training vs inference modes + +The `net` combines the `conditioner` and the `denoiser` in the way they should be trained (i.e. the `denoiser` is always called after the `conditioner`). The inference process is however different: once the input has been converted in latent space, it is passed to the conditioner to produce a condition; then, the denoiser is repeatedly called to iteratively denoise a completely noisy image according to a scheme defined by a sampler (see [Background on diffusion models](background-on-diffusion-models)). During the inference, `net.forward` is thus never called. + +## Loading original weights + +The structure of this part of LDCast has changed a little with respect to the original code, so the weights need to be reorganized before being loaded. The `convert_original_weights` function does this: +```python +from mlcast.models.ldcast.original_weights import convert_original_weights +ldm_weights_fn = '/path/to/original/ldm/genforecast/weights' +state_dict = convert_original_weights(ldm_weights_fn) +``` +`state_dict`is a `dict` whose keys are `conditioner`, `denoiser`, `ema` and `unmatched`. `state_dict['unmatched']` contains the elements which were not matched in `convert_original_weights` (should be empty). The weights of the conditioner, of the denoiser and the EMA can then be loaded as +``` +ldm.net.conditioner.load_state_dict(state_dict['conditioner']) +ldm.net.denoiser.load_state_dict(state_dict['denoiser']) +ldm.ema.load(state_dict['ema']) +``` + +## Exponential Moving Average + +The original code included an Exponential Moving Average (EMA) of the weights of the denoiser. This seems quite common for diffusion models. + +The idea is two versions of the weights of the models: + 1. the usual weights, which are updated through the optimization of the training loss + 2. the EMA weights, which are computed as an average of the last values of the usual weights + +The average is exponentially weighted, so that the latest weights are more taken into account. + +This is useful because the usual weights are quite unstable, while the EMA weights are more stable because of the average. + +The EMA is switched on by default, and can be switched off when creating the `LatentDiffusion` class (setting the keyword `ema_config` to `{'use': False}`). When switched on, the original weights have to be loaded into the model for the computation of the training loss, but the loss with the EMA weights should be computed during validation (and also after each training step ?). The EMA weights should also be used during inference. Everything is handled automatically if a `pl.Trainer` is used (through lightning hooks on the `LatentDiffusion` class). This means that one also needs to use a `pl.Trainer` at inference. + +## LatentDataset + +Gabriele's `SampledRadarDataset` returns sequence of `steps` radar images (with `steps` usually set to 24). However, the training of the `ldm` requires samples in latent space. This is handled by the `LatentDataset` class: it returns samples `(x, y)` with `x` being the latent encoding of the input radar images, and `y` being the latent encoding of the radar images to predict. The `LatentDataset` thus needs the trained `autoencoder.net` ! + +## Background on diffusion models + +See https://huggingface.co/blog/annotated-diffusion for some notations and formulas. + +During training, we start from a sample $x_0$ and create a series of samples $x_0, x_1, ..., x_T$ according to the formula (forward diffusion) + +$$ +x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha_t}}\epsilon_t, \quad t = 1, ..., T +$$ + +where $\epsilon_t \sim \mathcal{N}(0, 1)$. The constants $\bar{\alpha}_t$ are chosen, but they need the property that $\bar{\alpha}_t \to 0$ as $t \to T$, so that $x_T \sim \mathcal{N}(0, 1)$. These constants are computed with an algorithm called a scheduler. + +From a given $x_t$, the model can either be trained to predict $x_0$, $\epsilon_t$ or the velocity $v_t$. The latter is defined as + +$$ +v_t = \sqrt{\bar{\alpha}_t} \epsilon_t - \sqrt{1-\bar{\alpha}_t} x_0, +$$ + +which is equivalent to + +$$ +x_0 = \sqrt{\bar{\alpha}_t} x_t - \sqrt{1-\bar{\alpha}_t} v_t. +$$ + +The model is also given the timestep $t$. The loss is computed by comparing the target quantity ($\epsilon_t$, $x_0$ or $v_t$) with the predicted quantity by the model. Choosing to predict $x_0$, $\epsilon_t$ or $v_t$ is conceptually equivalent, the difference is in the numerical properties of the scheme (like in ODE integration schemes). + +The validation and test steps are done in the same way. + +So the model is trained to predict $x_0$ (or something from which we can compute $x_0$) from $x_T\sim \mathcal{N}(0, 1)$. But for large values of $t$, this prediction is actually quite bad. During actual prediction, the prediction is usually iteratively refined with sampler schemes. The idea is that, from the noise predicted based on $x_T$, the sampler scheme allows to compute $x_{T - \Delta t}$. The model is then used to predict the noise based on this estimation of $x_{T - \Delta t}$, and the sampler scheme allows to deduce $x_{T - 2\Delta t}$, etc. $\Delta t$ is usually taken of the order of 50 (while $T$ is usually 1000). + +In Hugging Face Diffusers library, the scheduler and the sampler parts are often combined in one object called a scheduler, but the sampler part is only used during inference. \ No newline at end of file diff --git a/src/mlcast/models/base.py b/src/mlcast/models/base.py index d526d63..5d2a832 100644 --- a/src/mlcast/models/base.py +++ b/src/mlcast/models/base.py @@ -113,15 +113,16 @@ def __init__( loss: nn.Module, optimizer_class: Any | None = None, optimizer_kwargs: dict | None = None, - **kwargs: Any, + lr_scheduler_config: dict | None = None, ): super().__init__() self.save_hyperparameters(ignore=["net", "loss"]) self.net = net self.loss = loss self.optimizer_class = torch.optim.Adam if optimizer_class is None else optimizer_class + self.lr_scheduler_config = lr_scheduler_config - def forward(self, x: torch.Tensor, n_timesteps: int) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the model. Args: x: Input tensor with shape (batch, seq_len, channels, height, width) @@ -129,8 +130,16 @@ def forward(self, x: torch.Tensor, n_timesteps: int) -> torch.Tensor: Returns: Output tensor with shape (batch, n_timesteps, channels, height, width) """ - return self.net(x, n_timesteps) # Assuming net is a callable model + return self.net(x) # Assuming net is a callable model + def training_logic(self, batch, batch_idx): + """Can be overwritten if needed""" + x, y = batch + predictions = self.forward(x) + loss = self.loss(predictions, y) + + return loss + def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> torch.Tensor: """Generic model step for training or validation. @@ -141,20 +150,24 @@ def model_step(self, batch: Any, batch_idx: int, step_name: str = "train") -> to Returns: Loss value for the current batch """ - x, y = batch - predictions = self.forward(x, n_timesteps=y.shape[1]) - loss = self.loss(predictions, y) + + loss = self.training_logic(batch, batch_idx) + loss = self.print_log_loss(loss, step_name) + + return loss + + def print_log_loss(self, loss, step_name): if isinstance(loss, dict): # append step name to loss keys for logging - loss = {f"{step_name}/{k}": v.item() for k, v in loss} - self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) - loss = loss.get("loss", loss.get("total_loss", None)) + loss = {f"{step_name}/{k}": v for k, v in loss.items()} + self.log_dict(loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + loss = loss.get(f"{step_name}/total_loss", None) if loss is None: raise ValueError(f"Loss is None for step {step_name}. Ensure loss function returns a valid tensor.") else: self.log(f"{step_name}/loss", loss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True) return loss - + def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """Training step for a single batch. @@ -184,5 +197,22 @@ def configure_optimizers(self) -> torch.optim.Optimizer: Returns: Optimizer instance to use for training + + following https://lightning.ai/docs/pytorch/stable/common/optimization.html for the scheduler part """ - return self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + optimizer = self.optimizer_class(self.parameters(), **(self.hparams.optimizer_kwargs or {})) + + if self.lr_scheduler_config is None: + return optimizer + + else: + cls = self.lr_scheduler_config['class'] + kwargs = self.lr_scheduler_config['kwargs'] + extra = self.lr_scheduler_config['extra'] + + # up to here, extra might be a omegaconf dictconfig object, but we need add the scheduler to it + # so convert it to a dict + extra = {k: v for k, v in extra.items()} + extra['scheduler'] = cls(optimizer, **kwargs) + + return {'optimizer': optimizer, 'lr_scheduler': extra} diff --git a/src/mlcast/models/ldcast/autoenc/autoenc.py b/src/mlcast/models/ldcast/autoenc/autoenc.py new file mode 100644 index 0000000..d42df86 --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/autoenc.py @@ -0,0 +1,110 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + +import torch +from torch import nn +from .encoder import SimpleConvEncoder, SimpleConvDecoder +from ...base import NowcastingLightningModule +from ..transforms.antialiasing import Antialiaser + +from ..distributions import ( + ensemble_nll_normal, + kl_from_standard_normal, + sample_from_standard_normal, +) + +class AutoencoderLoss(nn.Module): + def __init__(self, kl_weight = 0.01): + super().__init__() + self.kl_weight = kl_weight + + def forward(self, predictions, y): + (y_pred, mean, log_var) = predictions + + rec_loss = (y - y_pred).abs().mean() + kl_loss = kl_from_standard_normal(mean, log_var) + + total_loss = rec_loss + self.kl_weight * kl_loss + + return {'total_loss': total_loss, 'rec_loss': rec_loss, 'kl_loss': kl_loss} + +class Autoencoder(NowcastingLightningModule): + def __init__(self, net, loss, antialiaser = Antialiaser(), **kwargs): + super().__init__(net, loss, **kwargs) + self.save_hyperparameters(ignore=['net', 'loss', 'antialiaser']) + self.antialiaser = antialiaser + + def forward(self, x): + if self.antialiaser is not None: + x = self.antialiaser(x) + return self.net(x) + + def encode(self, x): + if self.antialiaser is not None: + x = self.antialiaser(x) + return self.net.encode(x) + + def training_logic(self, batch, batch_idx): + x, y = batch + predictions = self.forward(x) + if self.antialiaser is not None: + y = self.antialiser(y) + loss = self.loss(predictions, y) + + return loss + + @classmethod + def from_config(cls, config): + + antialiaser = Antialiaser(**config['antialiaser']['kwargs']) if config['antialiaser']['use'] else None + encoder = SimpleConvEncoder(**config['encoder']) + decoder = SimpleConvDecoder(**config['decoder']) + net = AutoencoderKLNet(encoder = encoder, decoder = decoder, **config['net_kwargs']) + loss = AutoencoderLoss(**config['loss']) + + return cls(net, loss, + antialiaser = antialiaser, + optimizer_class = config['optimizer_class'], + optimizer_kwargs = config['optimizer_kwargs'], + lr_scheduler_config = config['lr_scheduler'] + ) + +class AutoencoderKLNet(nn.Module): + def __init__( + self, + encoder = SimpleConvEncoder(), + decoder = SimpleConvDecoder(), + encoded_channels=64, + hidden_width=32, + **kwargs, + ): + super().__init__(**kwargs) + self.encoder = encoder + self.decoder = decoder + self.hidden_width = hidden_width + self.to_moments = nn.Conv3d(encoded_channels, 2 * hidden_width, kernel_size=1) + self.to_decoder = nn.Conv3d(hidden_width, encoded_channels, kernel_size=1) + + def encode(self, x, return_log_var = False): + if len(x.shape) < 5: + x = x[None] + + h = self.encoder(x) + (mean, log_var) = torch.chunk(self.to_moments(h), 2, dim=1) + if return_log_var: + return (mean, log_var) + else: + return mean + + def decode(self, z): + z = self.to_decoder(z) + dec = self.decoder(z) + return dec + + def forward(self, x, sample_posterior=True): + (mean, log_var) = self.encode(x, return_log_var = True) + if sample_posterior: + z = sample_from_standard_normal(mean, log_var) + else: + z = mean + dec = self.decode(z) + return (dec, mean, log_var) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/autoenc/encoder.py b/src/mlcast/models/ldcast/autoenc/encoder.py new file mode 100644 index 0000000..157af11 --- /dev/null +++ b/src/mlcast/models/ldcast/autoenc/encoder.py @@ -0,0 +1,59 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/autoenc/autoenc.py + +import numpy as np +import torch.nn as nn + +from ..blocks.resnet import ResBlock3D +from ..utils import activation, normalization + + +class SimpleConvEncoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in range(levels): + in_channels = int(channels[i]) + out_channels = int(channels[i+1]) + res_kernel_size = (3,3,3) if i == 0 else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + downsample = nn.Conv3d(out_channels, out_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(downsample) + in_channels = out_channels + + super().__init__(*sequence) + + +class SimpleConvDecoder(nn.Sequential): + def __init__(self, in_dim=1, levels=2, min_ch=64): + sequence = [] + channels = np.hstack([ + in_dim, + (8**np.arange(1,levels+1)).clip(min=min_ch) + ]) + + for i in reversed(list(range(levels))): + in_channels = int(channels[i+1]) + out_channels = int(channels[i]) + upsample = nn.ConvTranspose3d(in_channels, in_channels, + kernel_size=(2,2,2), stride=(2,2,2)) + sequence.append(upsample) + res_kernel_size = (3,3,3) if (i == 0) else (1,3,3) + res_block = ResBlock3D( + in_channels, out_channels, + kernel_size=res_kernel_size, + norm_kwargs={"num_groups": 1} + ) + sequence.append(res_block) + in_channels = out_channels + + super().__init__(*sequence) diff --git a/src/mlcast/models/ldcast/blocks/afno.py b/src/mlcast/models/ldcast/blocks/afno.py new file mode 100644 index 0000000..84c73d0 --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/afno.py @@ -0,0 +1,350 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/afno.py + +#reference: https://github.com/NVlabs/AFNO-transformer +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .attention import TemporalAttention + +class Mlp(nn.Module): + def __init__( + self, + in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class AFNO2D(nn.Module): + def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, H, W, C = x.shape + + x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho") + x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, H, W // 2 + 1, C) + x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0 + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.double_skip = double_skip + + def forward(self, x): + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + return x + + + +class AFNO3D(nn.Module): + def __init__( + self, hidden_size, num_blocks=8, sparsity_threshold=0.01, + hard_thresholding_fraction=1, hidden_size_factor=1 + ): + super().__init__() + assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + + self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor)) + self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor)) + self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size)) + self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) + + def forward(self, x): + bias = x + + dtype = x.dtype + x = x.float() + B, D, H, W, C = x.shape + + x = torch.fft.rfftn(x, dim=(1, 2, 3), norm="ortho") + x = x.reshape(B, D, H, W // 2 + 1, self.num_blocks, self.block_size) + + o1_real = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o1_imag = torch.zeros([B, D, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device) + o2_real = torch.zeros(x.shape, device=x.device) + o2_imag = torch.zeros(x.shape, device=x.device) + + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \ + self.b1[0] + ) + + o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu( + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \ + torch.einsum('...bi,bio->...bo', x[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \ + self.b1[1] + ) + + o2_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \ + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[0] + ) + + o2_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = ( + torch.einsum('...bi,bio->...bo', o1_imag[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \ + torch.einsum('...bi,bio->...bo', o1_real[:, :, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \ + self.b2[1] + ) + + x = torch.stack([o2_real, o2_imag], dim=-1) + x = F.softshrink(x, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, D, H, W // 2 + 1, C) + x = torch.fft.irfftn(x, s=(D, H, W), dim=(1,2,3), norm="ortho") + x = x.type(dtype) + + return x + bias + + +class AFNOBlock3d(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + mlp_out_features=None, + ): + super().__init__() + self.norm_layer = norm_layer + self.norm1 = norm_layer(dim) + self.filter = AFNO3D(dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, out_features=mlp_out_features, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.double_skip = double_skip + self.channels_first = (data_format == "channels_first") + + def forward(self, x): + if self.channels_first: + # AFNO natively uses a channels-last data format + x = x.permute(0,2,3,4,1) + + residual = x + x = self.norm1(x) + x = self.filter(x) + + if self.double_skip: + x = x + residual + residual = x + + x = self.norm2(x) + x = self.mlp(x) + x = x + residual + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x + + +class PatchEmbed3d(nn.Module): + def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.proj(x) + x = x.permute(0,2,3,4,1) # convert to BHWC + return x + + +class PatchExpand3d(nn.Module): + def __init__(self, patch_size=(4,4,4), out_chans=1, embed_dim=256): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear(embed_dim, out_chans*np.prod(patch_size)) + + def forward(self, x): + x = self.proj(x) + x = rearrange( + x, + "b d h w (p0 p1 p2 c_out) -> b c_out (d p0) (h p1) (w p2)", + p0=self.patch_size[0], + p1=self.patch_size[1], + p2=self.patch_size[2], + d=x.shape[1], + h=x.shape[2], + w=x.shape[3], + ) + return x + + +class AFNOCrossAttentionBlock3d(nn.Module): + """ AFNO 3D Block with channel mixing from two sources, used (only?) in the Unet denoiser + """ + def __init__( + self, + dim, + context_dim, + mlp_ratio=2., + drop=0., + act_layer=nn.GELU, + norm_layer=nn.Identity, + double_skip=True, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1.0, + data_format="channels_last", + timesteps=None + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim+context_dim) + mlp_hidden_dim = int((dim+context_dim) * mlp_ratio) + self.pre_proj = nn.Linear(dim+context_dim, dim+context_dim) + self.filter = AFNO3D(dim+context_dim, num_blocks, sparsity_threshold, + hard_thresholding_fraction) + self.mlp = Mlp( + in_features=dim+context_dim, + out_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop + ) + self.channels_first = (data_format == "channels_first") + + def forward(self, x, y): + if self.channels_first: + # AFNO natively uses a channels-last order + x = x.permute(0,2,3,4,1) + y = y.permute(0,2,3,4,1) + + xy = torch.concat((self.norm1(x),y), axis=-1) + xy = self.pre_proj(xy) + xy + xy = self.filter(self.norm2(xy)) + xy # AFNO filter + x = self.mlp(xy) + x # feed-forward + + if self.channels_first: + x = x.permute(0,4,1,2,3) + + return x \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/attention.py b/src/mlcast/models/ldcast/blocks/attention.py new file mode 100644 index 0000000..b8b3149 --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/attention.py @@ -0,0 +1,106 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/attention.py + +import math + +import torch +from torch import nn +import torch.nn.functional as F + + +class TemporalAttention(nn.Module): + def __init__( + self, channels, context_channels=None, + head_dim=32, num_heads=8 + ): + super().__init__() + self.channels = channels + if context_channels is None: + context_channels = channels + self.context_channels = context_channels + self.head_dim = head_dim + self.num_heads = num_heads + self.inner_dim = head_dim * num_heads + self.attn_scale = self.head_dim ** -0.5 + if channels % num_heads: + raise ValueError("channels must be divisible by num_heads") + self.KV = nn.Linear(context_channels, self.inner_dim*2) + self.Q = nn.Linear(channels, self.inner_dim) + self.proj = nn.Linear(self.inner_dim, channels) + + def forward(self, x, y=None): + if y is None: + y = x + + (K,V) = self.KV(y).chunk(2, dim=-1) + (B, Dk, H, W, C) = K.shape + shape = (B, Dk, H, W, self.num_heads, self.head_dim) + K = K.reshape(shape) + V = V.reshape(shape) + + Q = self.Q(x) + (B, Dq, H, W, C) = Q.shape + shape = (B, Dq, H, W, self.num_heads, self.head_dim) + Q = Q.reshape(shape) + + K = K.permute((0,2,3,4,5,1)) # K^T + V = V.permute((0,2,3,4,1,5)) + Q = Q.permute((0,2,3,4,1,5)) + + attn = torch.matmul(Q, K) * self.attn_scale + attn = F.softmax(attn, dim=-1) + y = torch.matmul(attn, V) + y = y.permute((0,4,1,2,3,5)) + y = y.reshape((B,Dq,H,W,C)) + y = self.proj(y) + return y + + +class TemporalTransformer(nn.Module): + def __init__(self, + channels, + mlp_dim_mul=1, + **kwargs + ): + super().__init__() + self.attn1 = TemporalAttention(channels, **kwargs) + self.attn2 = TemporalAttention(channels, **kwargs) + self.norm1 = nn.LayerNorm(channels) + self.norm2 = nn.LayerNorm(channels) + self.norm3 = nn.LayerNorm(channels) + self.mlp = MLP(channels, dim_mul=mlp_dim_mul) + + def forward(self, x, y): + x = self.attn1(self.norm1(x)) + x # self attention + x = self.attn2(self.norm2(x), y) + x # cross attention + return self.mlp(self.norm3(x)) + x # feed-forward + + +class MLP(nn.Sequential): + def __init__(self, dim, dim_mul=4): + inner_dim = dim * dim_mul + sequence = [ + nn.Linear(dim, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, dim) + ] + super().__init__(*sequence) + + +def positional_encoding(position, dims, add_dims=()): + div_term = torch.exp( + torch.arange(0, dims, 2, device=position.device) * + (-math.log(10000.0) / dims) + ) + if position.ndim == 1: + arg = position[:,None] * div_term[None,:] + else: + arg = position[:,:,None] * div_term[None,None,:] + + pos_enc = torch.concat( + [torch.sin(arg), torch.cos(arg)], + dim=-1 + ) + if add_dims: + for dim in add_dims: + pos_enc = pos_enc.unsqueeze(dim) + return pos_enc \ No newline at end of file diff --git a/src/mlcast/models/ldcast/blocks/resnet.py b/src/mlcast/models/ldcast/blocks/resnet.py new file mode 100644 index 0000000..983092d --- /dev/null +++ b/src/mlcast/models/ldcast/blocks/resnet.py @@ -0,0 +1,91 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/blocks/resnet.py + +from torch import nn +from torch.nn.utils.parametrizations import spectral_norm as sn + +from ..utils import activation, normalization + + +class ResBlock3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + resample=None, + resample_factor=(1, 1, 1), + kernel_size=(3, 3, 3), + act="swish", + norm="group", + norm_kwargs=None, + spectral_norm=False, + **kwargs, + ): + super().__init__(**kwargs) + if in_channels != out_channels: + self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1) + else: + self.proj = nn.Identity() + + padding = tuple(k // 2 for k in kernel_size) + if resample == "down": + self.resample = nn.AvgPool3d(resample_factor, ceil_mode=True) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + elif resample == "up": + self.resample = nn.Upsample(scale_factor=resample_factor, mode="trilinear") + self.conv1 = nn.ConvTranspose3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + output_padding = tuple( + 2 * p + s - k + for (p, s, k) in zip(padding, resample_factor, kernel_size) + ) + self.conv2 = nn.ConvTranspose3d( + out_channels, + out_channels, + kernel_size=kernel_size, + stride=resample_factor, + padding=padding, + output_padding=output_padding, + ) + else: + self.resample = nn.Identity() + self.conv1 = nn.Conv3d( + in_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + self.conv2 = nn.Conv3d( + out_channels, out_channels, kernel_size=kernel_size, padding=padding + ) + + if isinstance(act, str): + act = (act, act) + self.act1 = activation(act_type=act[0]) + self.act2 = activation(act_type=act[1]) + + if norm_kwargs is None: + norm_kwargs = {} + self.norm1 = normalization(in_channels, norm_type=norm, **norm_kwargs) + self.norm2 = normalization(out_channels, norm_type=norm, **norm_kwargs) + if spectral_norm: + self.conv1 = sn(self.conv1) + self.conv2 = sn(self.conv2) + if not isinstance(self.proj, nn.Identity): + self.proj = sn(self.proj) + + def forward(self, x): + x_in = self.resample(self.proj(x)) + x = self.norm1(x) + x = self.act1(x) + x = self.conv1(x) + x = self.norm2(x) + x = self.act2(x) + x = self.conv2(x) + return x + x_in \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/context.py b/src/mlcast/models/ldcast/context/context.py new file mode 100644 index 0000000..caedbb6 --- /dev/null +++ b/src/mlcast/models/ldcast/context/context.py @@ -0,0 +1,38 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/analysis.py + +import torch +from torch import nn +from torch.nn import functional as F + +from .nowcast import AFNONowcastNetBase +from ..blocks.resnet import ResBlock3D + + +class AFNONowcastNetCascade(AFNONowcastNetBase): + def __init__(self, *args, cascade_depth=4, **kwargs): + super().__init__(*args, **kwargs) + self.cascade_depth = cascade_depth + self.resnet = nn.ModuleList() + ch = self.embed_dim_out + self.cascade_dims = [ch] + for i in range(cascade_depth-1): + ch_out = 2*ch + self.cascade_dims.append(ch_out) + self.resnet.append( + ResBlock3D(ch, ch_out, kernel_size=(1,3,3), norm=None) + ) + ch = ch_out + + def forward(self, x): + # the past timesteps are needed here, but they are always the same... + # need to expand timesteps because of the AFNONowcastNetBase.add_pos_enc method, not sure why + past_timesteps = torch.tensor([-3, -2, -1, 0], device = 'cuda', dtype = torch.float32).unsqueeze(0).expand(1,-1) + x = super().forward(x, past_timesteps) + img_shape = tuple(x.shape[-2:]) + cascade = {img_shape: x} + for i in range(self.cascade_depth-1): + x = F.avg_pool3d(x, (1,2,2)) + x = self.resnet[i](x) + img_shape = tuple(x.shape[-2:]) + cascade[img_shape] = x + return cascade \ No newline at end of file diff --git a/src/mlcast/models/ldcast/context/nowcast.py b/src/mlcast/models/ldcast/context/nowcast.py new file mode 100644 index 0000000..b066816 --- /dev/null +++ b/src/mlcast/models/ldcast/context/nowcast.py @@ -0,0 +1,126 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/nowcast/nowcast.py, but removed the Nowcaster, AFNONowcastNetBasic and AFNONowcastNet classes because they were not used. Reworked also the two remaining classes (FusionBlock3D and AFNONowcastNetBase) to simplify the code by removing the unused parts. + +import collections + +import torch +from torch import nn +from torch.nn import functional as F +import pytorch_lightning as pl + +from ..blocks.afno import AFNOBlock3d +from ..blocks.attention import positional_encoding, TemporalTransformer + +class FusionBlock3d(nn.Module): + def __init__(self, dim, size_ratio, dim_out=None, afno_fusion=False): + super().__init__() + + if dim_out is None: + dim_out = dim + + if size_ratio == 1: + scale = nn.Identity() + else: + scale = [] + while size_ratio > 1: + scale.append(nn.ConvTranspose3d( + dim[i], dim_out if size_ratio==2 else dim[i], + kernel_size=(1,3,3), stride=(1,2,2), + padding=(0,1,1), output_padding=(0,1,1) + )) + size_ratio //= 2 + scale = nn.Sequential(*scale) + self.scale = scale + + self.afno_fusion = afno_fusion + + if self.afno_fusion: + self.fusion = nn.Identity() + + def resize_proj(self, x): + x = x.permute(0,4,1,2,3) + x = self.scale(x) + x = x.permute(0,2,3,4,1) + return x + + def forward(self, x): + x = self.resize_proj(x) + return x + + +class AFNONowcastNetBase(nn.Module): + def __init__( + self, + autoencoder_dim, + embed_dim=128, + embed_dim_out=None, + analysis_depth=4, + forecast_depth=4, + input_patches=1, + input_size_ratios=1, + output_patches=2, + afno_fusion=False + ): + super().__init__() + + if embed_dim_out is None: + embed_dim_out = embed_dim + self.embed_dim = embed_dim + self.embed_dim_out = embed_dim_out + self.output_patches = output_patches + + self.proj = nn.Conv3d(autoencoder_dim, embed_dim, kernel_size=1) + + self.analysis = nn.Sequential( + *(AFNOBlock3d(embed_dim) for _ in range(analysis_depth)) + ) + + # temporal transformer + self.use_temporal_transformer = input_patches != output_patches + if self.use_temporal_transformer: + self.temporal_transformer = TemporalTransformer(embed_dim) + + # data fusion + self.fusion = FusionBlock3d(embed_dim, input_size_ratios, + afno_fusion=afno_fusion, dim_out=embed_dim_out) + + # forecast + self.forecast = nn.Sequential( + *(AFNOBlock3d(embed_dim_out) for _ in range(forecast_depth)) + ) + + def add_pos_enc(self, x, t): + '''not sure the this does what it was supposed to do in the original LDCast code''' + if t.shape[1] != x.shape[1]: + # this can happen if x has been compressed + # by the autoencoder in the time dimension + ds_factor = t.shape[1] // x.shape[1] + t = F.avg_pool1d(t.unsqueeze(1), ds_factor)[:,0,:] + + pos_enc = positional_encoding(t, x.shape[-1], add_dims=(2,3)) + return x + pos_enc + + def forward(self, z, timesteps): + '''z is the latent representation of the conditioning and timesteps is contains the timesteps of the input frames (it is [-3, -2, -1, 0])''' + z = self.proj(z) + z = z.permute(0,2,3,4,1) + z = self.analysis(z) + + if self.use_temporal_transformer: + # add positional encoding + z = self.add_pos_enc(z, timesteps) + + # transform to output shape and coordinates + expand_shape = z.shape[:1] + (-1,) + z.shape[2:] + pos_enc_output = positional_encoding( + torch.arange(1,self.output_patches+1, device=z.device), + self.embed_dim, add_dims=(0,2,3) + ) + pe_out = pos_enc_output.expand(*expand_shape) + z = self.temporal_transformer(pe_out, z) + + + # merge inputs + z = self.fusion(z) + # produce prediction + z = self.forecast(z) + return z.permute(0,4,1,2,3) # to channels-first order \ No newline at end of file diff --git a/src/mlcast/models/ldcast/data.py b/src/mlcast/models/ldcast/data.py new file mode 100644 index 0000000..9ea69a3 --- /dev/null +++ b/src/mlcast/models/ldcast/data.py @@ -0,0 +1,94 @@ +from torch.utils.data import Dataset, random_split, DataLoader +import torch +import pytorch_lightning as pl +from tqdm import tqdm + +class LatentDataset(Dataset): + def __init__(self, sampled_radar_dataset, autoencoder, autoenc_time_ratio = 4): + super().__init__() + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.autoencoder = autoencoder.to(self.device) + self.dataset = sampled_radar_dataset + self.autoenc_time_ratio = autoenc_time_ratio + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + + with torch.no_grad(): + sequence = self.dataset[idx] + x = sequence[:, :self.autoenc_time_ratio] + y = sequence[:, self.autoenc_time_ratio:] + + # for some reason, Gabriele put the time axis before the channel axis, change this + #x = x.swapaxes(0, 1).to(self.device) + #y = y.swapaxes(0, 1).to(self.device) + + #latent_x = self.autoencoder.encode(x) + #latent_y = self.autoencoder.encode(y) + + return x, y #latent_x, latent_y + +class AutoencoderDataset(Dataset): + ''' + shape of one sample of sampled_radar_dataset = (1, 24, 1,) + spatial_shape + But, for the LDCast autoencoder, we want to have samples of (1, 4, 1,) + spatial_shape + So 1 sample of sampled_radar_dataset is partitioned in 6 samples for the autoencoder + ''' + def __init__(self, sampled_radar_dataset, autoenc_time_ratio = 4): + super().__init__() + self.srd = sampled_radar_dataset + self.autoenc_time_ratio = autoenc_time_ratio + self.samples_ratio = int(self.srd.steps / self.autoenc_time_ratio) # is 6 in the usual case where steps = 24 and autoenc_time_ratio = 4 + + def __len__(self): + return self.samples_ratio * len(self.srd) + + def __getitem__(self, idx): + ''' + when given idx between 0 and 6 * len(srd) - 1, one has first to find in which sample of srd we are (the index of this sample is index_srd) + then, within this sample, one has to find in which partition of this sample we are (this is given by index_in_srd_sample) + ''' + index_srd = idx // self.samples_ratio + index_in_srd_sample = idx - index_srd * self.samples_ratio + x = self.srd[index_srd].reshape(self.samples_ratio, 1, self.autoenc_time_ratio, self.srd.w, self.srd.h)[index_in_srd_sample] + + # for some reason, Gabriele put the time axis before the channel axis, change this + # x = x.swapaxes(0, 1) + + # for the autoencoder, y is equal to x + y = x + + return x, y + +class DataModule(pl.LightningDataModule): + def __init__(self, dataset, train_ratio = 0.6, val_ratio = 0.2, **dataloader_kwargs): + super().__init__() + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = 1 - self.train_ratio - self.val_ratio + + train_ds, val_ds, test_ds = random_split(dataset, [self.train_ratio, self.val_ratio, self.test_ratio]) + self.train_dataset = train_ds + self.val_dataset = val_ds + self.test_dataset = test_ds + + self.dataloader_kwargs = dataloader_kwargs + + def train_dataloader(self): + return DataLoader(self.train_dataset, shuffle = True, **self.dataloader_kwargs) + + def val_dataloader(self): + return DataLoader(self.val_dataset, shuffle = False, **self.dataloader_kwargs) + + def test_dataloader(self): + return DataLoader(self.test_dataset, shuffle = False, **self.dataloader_kwargs) + +def load_in_memory(dataset): + ds = [] + for i in tqdm(range(len(dataset)), desc = 'Loading data in memory'): + # append the sample and create an extra dimension (to be the batch dimension) + ds.append(dataset[i][None].to('cpu')) + return torch.cat(ds, axis = 0) \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/diffusion.py b/src/mlcast/models/ldcast/diffusion/diffusion.py new file mode 100644 index 0000000..2e33b47 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/diffusion.py @@ -0,0 +1,142 @@ +import torch +import torch.nn as nn +from ...base import NowcastingLightningModule +import numpy as np +from .utils import extract_into_tensor +from .ema import EMA + +class LatentDiffusionNet(nn.Module): + def __init__(self, conditioner, denoiser, parametrization = "eps"): + super().__init__() + self.conditioner = conditioner + self.denoiser = denoiser + self.parametrization = parametrization + + def forward(self, x): + # during training, noisy should be x_t + # during inference, noisy should be noise + t, noisy, latent_inputs = x + condition = self.conditioner(latent_inputs) + + # if parametrization is eps, out is the predicted noise + # if parametrization is x0, out is the guessed x0 + # if parametrization is v, out is the guessed v + out = self.denoiser(noisy, t, context = condition) + return out + + +class LatentDiffusion(NowcastingLightningModule): + def __init__(self, net, loss, scheduler, autoencoder, ema_config = {'use': True}, **kwargs): + super().__init__(net, loss, **kwargs) + self.save_hyperparameters(ignore=['net', 'loss', 'autoencoder']) + self.scheduler = scheduler + self.autoencoder = autoencoder # the LatentDiffusion class needs the autoencoder so that, in a parallel training setup, Lightning creates one instance of the autoencoder on each GPU + + # Freeze autoencoder + for param in self.autoencoder.parameters(): + param.requires_grad = False + + # register the schedules (i.e. the values of alpha, beta etc) + self.register_schedule() + + if ema_config['use']: + self.ema = EMA(self.net.denoiser, **ema_config['kwargs']) + + def register_schedule(self): + + schedule = self.scheduler.schedule(torch.float32, next(self.net.parameters()).device) + + # check if the ldm has already some saved buffers + saved_buffers = dict(self.net.named_buffers()) + already_saved_and_different = [name for name in schedule.keys() + if (name in saved_buffers.keys() and (schedule[name] != saved_buffers[name]).any()) + ] + if len(already_saved_and_different) > 0: + raise AttributeError(f'The denoiser has already some different values for {already_saved_and_different}') + + for k in schedule.keys(): + self.net.denoiser.register_buffer(k, schedule[k]) + + def training_logic(self, batch, batch_idx): + + inputs, true = batch + with torch.no_grad(): + latent_inputs = self.autoencoder.encode(inputs) + latent_true = self.autoencoder.encode(true) + x0 = latent_true + + t, noise, x_t = self.q_sample(x0) + + if self.net.parametrization == 'eps': + target = noise + if self.net.parametrization == 'x0': + target = x0 + + model_output = self.net((t, x_t, latent_inputs)) + + return self.loss(model_output, target) + + def q_sample(self, x0, noise = None, t = None): + '''generate noise target for training''' + if noise is None: + noise = torch.randn_like(x0) + if t is None: + t = torch.randint(0, self.scheduler.timesteps, (x0.shape[0],), device=x0.device).long() + + x_noisy = extract_into_tensor(self.net.denoiser.sqrt_alphas_cumprod, t, x0.shape) * x0 + \ + extract_into_tensor(self.net.denoiser.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + + return t, noise, x_noisy + + def on_fit_start(self): + if hasattr(self, 'ema'): + self.ema.register() + + def on_train_batch_end(self, outputs, batch, batch_idx): + if hasattr(self, 'ema'): + self.ema.update() + + def on_validation_start(self): + if hasattr(self, 'ema'): + self.ema.apply_shadow() + + def on_validation_end(self): + if hasattr(self, 'ema'): + self.ema.restore() + + def on_test_start(self): + if hasattr(self, 'ema'): + self.ema.apply_shadow() + + def on_test_end(self): + if hasattr(self, 'ema'): + self.ema.restore() + + def on_predict_start(self): + if hasattr(self, 'ema'): + self.ema.apply_shadow() + + def on_predict_end(self): + if hasattr(self, 'ema'): + self.ema.restore() + + @classmethod + def from_config(cls, config, autoencoder): + + from .scheduler import Scheduler + from .unet import UNetModel + from ..context.context import AFNONowcastNetCascade + + conditioner = AFNONowcastNetCascade(**config['conditioner']) + denoiser = UNetModel(**config['denoiser']) + net = LatentDiffusionNet(conditioner, denoiser) + loss = nn.MSELoss() + scheduler = Scheduler(**config['scheduler']) + + return cls(net, loss, scheduler, autoencoder, + ema_config = config['ema'], + optimizer_class = config['optimizer_class'], + optimizer_kwargs = config['optimizer_kwargs'], + lr_scheduler_config = config['lr_scheduler'] + ) + diff --git a/src/mlcast/models/ldcast/diffusion/ema.py b/src/mlcast/models/ldcast/diffusion/ema.py new file mode 100644 index 0000000..fd7e3d4 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/ema.py @@ -0,0 +1,75 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/ema.py +''' +modifications following https://medium.com/@heyamit10/exponential-moving-average-ema-in-pytorch-eb8b6f1718eb +In the original code, EMA was a subclass of nn.Module, in order to register the parameters as buffers and (I guess) to have them saved automatically when saving the model. This made things a bit tricky and messy, because '.'-characters naturally appear in the names of model parameters, while they cannot appear in buffers names... Instead, the EMA weights will have to be saved with torch.save(ema.shadow) and loaded with ema.shadow = torch.load('ema_weights')''' + +import torch +from torch import nn + +class EMA(): + def __init__(self, model, decay = 0.9999, use_num_updates = True, store_device = 'cuda'): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.model = model + self.decay = decay + self.shadow = {} # to store EMA weights + self.backup = {} # to store the model weights when we replace them by ema weights + self.num_updates = 0 if use_num_updates else -1 # for dynamical decay + self.store_device = store_device # device on which to store the weights + + def register(self): + '''initialize the ema weights with the model weights''' + print(next(self.model.parameters()).device) + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.detach().to(self.store_device) + + def update(self): + '''update the shadow parameters''' + + # use dynamical decay if use_num_updates was true in __init__ + decay = self.decay + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + for name, param in self.model.named_parameters(): + if param.requires_grad: + new_average = (1.0 - self.decay) * param.data.detach().to(self.store_device) + self.decay * self.shadow[name] + self.shadow[name] = new_average + + def apply_shadow(self): + '''apply shadow (EMA) weights to the model''' + model_device = next(self.model.parameters()).device + for name, param in self.model.named_parameters(): + if param.requires_grad: + self.backup[name] = param.data.detach().to(self.store_device) + param.data = self.shadow[name].to(model_device) + + def restore(self): + '''restore original model weights from backup''' + model_device = next(self.model.parameters()).device + for name, param in self.model.named_parameters(): + if param.requires_grad: + param.data = self.backup[name].to(model_device) + + def load(self, filename): + '''load the ema (shadow) weights parameters''' + self.shadow = torch.load(filename) + self.decay = self.shadow.pop('decay') + self.num_updates = self.shadow.pop('num_updates') + + # put the shadow tensors on the correct device + for k in self.shadow.keys(): + self.shadow[k] = self.shadow[k].to(self.store_device) + + def save(self, filename): + '''save the ema (shadow) weights parameters''' + self.shadow['decay'] = self.decay + self.shadow['num_updates'] = self.num_updates + torch.save(self.shadow, filename) + + self.shadow.pop('decay') + self.shadow.pop('num_updates') \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/plms.py b/src/mlcast/models/ldcast/diffusion/plms.py new file mode 100644 index 0000000..ce27e0c --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/plms.py @@ -0,0 +1,345 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/plms.py, but changed model.apply_model into model.forward + + +""" +From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py +""" + +"""SAMPLING ONLY.""" + +import numpy as np +import torch +from tqdm import tqdm + +from .utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler: + def __init__(self, model, timesteps = 1000, schedule = "linear", **kwargs): + self.model = model + self.ddpm_num_timesteps = timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + # if type(attr) == torch.Tensor: + # if attr.device != torch.device("cuda"): + # attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + if ddim_eta != 0: + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + device = next(self.model.parameters()).device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(torch.sqrt(alphas_cumprod)) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(torch.sqrt(1.0 - alphas_cumprod)), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(torch.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(torch.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(torch.sqrt(1.0 / alphas_cumprod - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod, + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer( + "ddim_sqrt_one_minus_alphas", torch.sqrt(1.0 - ddim_alphas) + ) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + x0=None, + temperature=1.0, + noise_dropout=0.0, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + """ + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + """ + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + size = (batch_size,) + shape + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + progbar=progbar, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + progbar=True, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + list(reversed(range(0, timesteps))) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = time_range + if progbar: + iterator = tqdm(iterator, desc="PLMS Sampler", total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full( + (b,), + time_range[min(i + 1, len(time_range) - 1)], + device=device, + dtype=torch.long, + ) + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms( + self, + x, + condition, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if ( + unconditional_conditioning is None + or unconditional_guidance_scale == 1.0 + ): + e_t = self.model(x, t, condition) + else: + pass + '''is never used + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, condition]) + e_t_uncond, e_t = self.model.apply_denoiser(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + ''' + + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + param_shape = (b,) + (1,) * (x.ndim - 1) + a_t = torch.full(param_shape, alphas[index], device=device) + a_prev = torch.full(param_shape, alphas_prev[index], device=device) + sigma_t = torch.full(param_shape, sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + param_shape, sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = ( + 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] + ) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/scheduler.py b/src/mlcast/models/ldcast/diffusion/scheduler.py new file mode 100644 index 0000000..95cb1a3 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/scheduler.py @@ -0,0 +1,39 @@ +from functools import partial +from .utils import make_beta_schedule +import numpy as np +import torch + +class Scheduler(): + def __init__(self, + timesteps = 1000, + beta_schedule = "linear", + linear_start = 1e-4, + linear_end = 2e-2, + cosine_s = 8e-3, + ): + self.timesteps = timesteps + self.beta_schedule = beta_schedule + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + def schedule(self, dtype, device): + + betas = make_beta_schedule( + self.beta_schedule, self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, + cosine_s=self.cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + assert alphas_cumprod.shape[0] == self.timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype = dtype, device = device) + + return {'betas': to_torch(betas), + 'alphas_cumprod': to_torch(alphas_cumprod), + 'alphas_cumprod_prev': to_torch(alphas_cumprod_prev), + 'sqrt_alphas_cumprod': to_torch(np.sqrt(alphas_cumprod)), + 'sqrt_one_minus_alphas_cumprod': to_torch(np.sqrt(1. - alphas_cumprod))} \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/simple_sampler.py b/src/mlcast/models/ldcast/diffusion/simple_sampler.py new file mode 100644 index 0000000..ad7dd45 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/simple_sampler.py @@ -0,0 +1,84 @@ +# from https://github.com/mfroelund/ldcast-dmi-public/blob/master/ldcast/models/diffusion/diffusion.py, but reworked + +""" +From https://github.com/CompVis/latent-diffusion/main/ldm/models/diffusion/ddpm.py +Pared down to simplify code. + +The original file acknowledges: +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +""" + +import torch +import torch.nn as nn +import pytorch_lightning as L +from functools import partial +import numpy as np + +from .utils import make_beta_schedule, extract_into_tensor + + +class SimpleSampler(L.LightningModule): + '''Sampler used for training (the PLMSSampler is used for inference). The sample method is not consistent with the sample method of PLMSSampler''' + def __init__(self, + timesteps=1000, + beta_schedule="linear", + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + parameterization="eps", # all assuming fixed variance schedules + ): + super().__init__() + + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.beta_schedule = beta_schedule + self.timesteps = timesteps + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + + def register_schedule(self, denoiser): + + # check if the denoiser has already some saved buffers + buffer_names = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod'] + already_saved = [n for n in buffer_names if n in dict(denoiser.named_buffers()).keys()] + if len(already_saved) > 0: + raise AttributeError(f'The denoiser has already some saved values for {already_saved}') + + betas = make_beta_schedule( + self.beta_schedule, self.timesteps, + linear_start=self.linear_start, linear_end=self.linear_end, + cosine_s=self.cosine_s + ) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + assert alphas_cumprod.shape[0] == self.timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32, device = next(denoiser.parameters()).device) + + denoiser.register_buffer('betas', to_torch(betas)) + denoiser.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + denoiser.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + denoiser.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + denoiser.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + + + def q_sample(self, denoiser, x_start, noise=None): + '''generate noise target for training''' + if noise is None: + noise = torch.randn_like(x_start) + t = torch.randint(0, self.timesteps, (x_start.shape[0],), device=x_start.device).long() + x_noisy = extract_into_tensor(denoiser.sqrt_alphas_cumprod, t, x_start.shape) * x_start + \ + extract_into_tensor(denoiser.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + return t, noise, x_noisy + + def sample(self, denoiser, conditioning, num_diffusion_iters = 50): + '''sampling for inference, should maybe be implemented to be consistent with the PLMSSampler class''' + pass + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/diffusion/unet.py b/src/mlcast/models/ldcast/diffusion/unet.py new file mode 100644 index 0000000..f63d2d4 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/unet.py @@ -0,0 +1,492 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/genforecast/unet.py: weirdly, this part (which is the denoiser) was in not in the diffusion folder in the original code + +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl + +from .utils import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ..blocks.afno import AFNOCrossAttentionBlock3d +SpatialTransformer = type(None) +#from ldm.modules.attention import SpatialTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AFNOCrossAttentionBlock3d): + img_shape = tuple(x.shape[-2:]) + x = layer(x, context[img_shape]) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + + """ + + def __init__( + self, + model_channels, + in_channels=1, + out_channels=1, + num_res_blocks=2, + attention_resolutions=(1,2,4), + context_ch=128, + dropout=0, + channel_mult=(1, 2, 4, 4), + conv_resample=True, + dims=3, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + legacy=True, + num_timesteps=1 + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + timesteps = th.arange(1, num_timesteps+1) + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[-1], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = num_head_channels + layers.append( + AFNOCrossAttentionBlock3d( + ch, context_dim=context_ch[level], num_blocks=num_heads, + data_format="channels_first", timesteps=timesteps + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps=None, context=None): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :return: an [N x C x ...] Tensor of outputs. + """ + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + return self.out(h) + diff --git a/src/mlcast/models/ldcast/diffusion/utils.py b/src/mlcast/models/ldcast/diffusion/utils.py new file mode 100644 index 0000000..e908cd1 --- /dev/null +++ b/src/mlcast/models/ldcast/diffusion/utils.py @@ -0,0 +1,249 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/diffusion/utils.py + +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps].cpu() + numpied = [alphacums[0].cpu()] + alphacums[ddim_timesteps[:-1]].cpu().tolist() + alphas_prev = np.asarray(numpied) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return nn.Identity() #GroupNorm32(32, channels) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") \ No newline at end of file diff --git a/src/mlcast/models/ldcast/distributions.py b/src/mlcast/models/ldcast/distributions.py new file mode 100644 index 0000000..3dcb183 --- /dev/null +++ b/src/mlcast/models/ldcast/distributions.py @@ -0,0 +1,31 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/distributions.py + +import numpy as np +import torch + + +def kl_from_standard_normal(mean, log_var): + kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var) + return kl.mean() + + +def sample_from_standard_normal(mean, log_var, num=None): + std = (0.5 * log_var).exp() + shape = mean.shape + if num is not None: + # expand channel 1 to create several samples + shape = shape[:1] + (num,) + shape[1:] + mean = mean[:, None, ...] + std = std[:, None, ...] + return mean + std * torch.randn(shape, device=mean.device) + + +def ensemble_nll_normal(ensemble, sample, epsilon=1e-5): + mean = ensemble.mean(dim=1) + var = ensemble.var(dim=1, unbiased=True) + epsilon + logvar = var.log() + + diff = sample[:, None, ...] - mean + logtwopi = np.log(2 * np.pi) + nll = (logtwopi + logvar + diff.square() / var).mean() + return nll \ No newline at end of file diff --git a/src/mlcast/models/ldcast/ldcast.py b/src/mlcast/models/ldcast/ldcast.py new file mode 100644 index 0000000..ed94d42 --- /dev/null +++ b/src/mlcast/models/ldcast/ldcast.py @@ -0,0 +1,102 @@ +# new file with respect to original code + +from ..base import NowcastingModelBase +import pytorch_lightning as L +from .data import LatentDataset, AutoencoderDataset, DataModule, load_in_memory +import torch +import contextlib + +#torch.multiprocessing.set_start_method('spawn') + +class LDCast(NowcastingModelBase): + def __init__(self, ldm, autoencoder, sampler): + super().__init__() + self.ldm = ldm + self.autoencoder = autoencoder + self.sampler = sampler + + def fit(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): + '''dataset should contains pairs of (inputs, true), with + inputs.shape = (batch_size, 1, 4, 256, 256) + true.shape = (batch_size, 1, 20, 256, 256) + ''' + print('Training autoencoder') + self.fit_autoencoder(sampled_radar_dataset, dataloader_kwargs = dataloader_kwargs, trainer_kwargs = trainer_kwargs) + + print('Training ldm') + self.fit_ldm(sampled_radar_dataset, dataloader_kwargs = dataloader_kwargs, trainer_kwargs = trainer_kwargs) + + def fit_ldm(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): + + #assert False, 'need to add a trainer instance in the LatentDataset class to automatically move the autoencoder to cuda etc.' + self.autoencoder.net.eval() + + dataset = LatentDataset(sampled_radar_dataset, self.autoencoder) + datamodule = DataModule(dataset, **dataloader_kwargs) + trainer = L.Trainer(**trainer_kwargs) + trainer.fit(self.ldm, datamodule) + + def fit_autoencoder(self, sampled_radar_dataset, dataloader_kwargs = {}, trainer_kwargs = {}): + + dataset = AutoencoderDataset(sampled_radar_dataset) + datamodule = DataModule(dataset, **dataloader_kwargs) + trainer = L.Trainer(**trainer_kwargs) + trainer.fit(self.autoencoder, datamodule) + + def predict(self, inputs, num_diffusion_iters = 50, verbose = True): + '''inputs.shape = (batch_size, 1, 4, 256, 256)''' + + assert False, 'prediction should be implemented with a trainer, to take into account the switches of ema weights for example''' + + latent_inputs = self.autoencoder.encode(inputs) + condition = self.ldm.net.conditioner(latent_inputs) + + gen_shape = (32, 5, 256//4, 256//4) + batch_size = len(latent_inputs) + + # this could also be put in the LatentDiffusion class, by overriding the predict_step method (https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#inference) + with contextlib.redirect_stdout(None): + (s, intermediates) = self.sampler.sample( + num_diffusion_iters, + batch_size, + gen_shape, + condition, + progbar = verbose) + + return s + + latent_pred = self.ldm(latent_inputs) + return self.autoencoder.net.decode(latent_pred) + + def save(self, folder): + torch.save(self.autoencoder.net.state_dict(), f'{folder}/autoencoder.pt') + torch.save(self.ldm.net.conditioner.state_dict(), f'{folder}/conditioner.pt') + torch.save(self.ldm.net.denoiser.state_dict(), f'{folder}/denoiser.pt') + + if hasattr(self.ldm, 'ema'): + self.ldm.ema.save(f'{folder}/ema.pt') + + def load(self, folder): + self.autoencoder.net.load_state_dict(torch.load(f'{folder}/autoencoder.pt')) + self.ldm.net.conditioner.load_state_dict(torch.load(f'{folder}/conditioner.pt')) + self.ldm.net.denoiser.load_state_dict(torch.load(f'{folder}/denoiser.pt')) + + if hasattr(self.ldm, 'ema'): + self.ldm.ema.load(f'{folder}/ema.pt') + + @classmethod + def from_config(cls, config): + + from .autoenc.autoenc import Autoencoder + from .diffusion.diffusion import LatentDiffusion + from .diffusion.plms import PLMSSampler + + autoencoder = Autoencoder.from_config(config['autoencoder']) + ldm = LatentDiffusion.from_config(config['ldm'], autoencoder) + sampler = PLMSSampler(ldm.net.denoiser) + + return cls(ldm, autoencoder, sampler) + + + + \ No newline at end of file diff --git a/src/mlcast/models/ldcast/original_weights.py b/src/mlcast/models/ldcast/original_weights.py new file mode 100644 index 0000000..38ddb7a --- /dev/null +++ b/src/mlcast/models/ldcast/original_weights.py @@ -0,0 +1,89 @@ +import torch +import re + +def convert_original_weights(ldm_weights_fn): + ''' + returns the original weights of the denoiser and the conditioner from the way they were saved originally + at the moment, the ema scope is not taken into account + the unmatched_keys are the ema keys and the buffer keys for the schedule (at the moment) + ''' + ldm_state_dict = torch.load(ldm_weights_fn) + + # track unmatched keys + unmatched_keys = list(ldm_state_dict.keys()) + + # remove the weights of the autoencoder + for k in unmatched_keys.copy(): + if k.startswith('autoencoder.') or k.startswith('context_encoder.autoencoder.'): + unmatched_keys.remove(k) + + # extract the keys of the denoiser (it was called 'model' in the original code) + denoiser_state_dict = {} + for k in unmatched_keys.copy(): + if k.startswith('model.'): + new_key = k.replace('model.', '') + denoiser_state_dict[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + + denoiser_buffers_keys = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod'] + for k in unmatched_keys.copy(): + if k in denoiser_buffers_keys: + denoiser_state_dict[k] = ldm_state_dict[k] + unmatched_keys.remove(k) + + # extract the keys of the conditioner (it was called 'context_encoder' in the original code) + conditioner_state_dict = {} + for k in unmatched_keys.copy(): + if k.startswith('context_encoder.'): + new_key = k.replace('context_encoder.', '') + conditioner_state_dict[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + + # proj, temporal_transformer and analysis were lists with one only element, I simplified this + # the keys have to be adapted + new_conditioner_state_dict = {} + for k, v in conditioner_state_dict.items(): + new_key = k + if k.startswith('proj.0.'): + new_key = k.replace('proj.0.', 'proj.') + if k.startswith('temporal_transformer.0.'): + new_key = k.replace('temporal_transformer.0.', 'temporal_transformer.') + if k.startswith('analysis.0.'): + new_key = k.replace('analysis.0.', 'analysis.') + new_conditioner_state_dict[new_key] = v + conditioner_state_dict = new_conditioner_state_dict + + ema = {} + for k in unmatched_keys.copy(): + if k.startswith('model_ema.'): + new_key = restore_name(k.replace('model_ema.', '')) + ema[new_key] = ldm_state_dict[k] + unmatched_keys.remove(k) + + # create dict with unmatched keys + unmatched = {key: ldm_state_dict[key] for key in unmatched_keys} + + return {'denoiser': denoiser_state_dict, + 'conditioner': conditioner_state_dict, + 'ema': ema, + 'unmatched': unmatched} + +def restore_name(s): + '''for the EMA, all the dots were removed from the parameters names in the original code, so they should be added again to match during swapping''' + # add dots before and after every digit + res = re.sub(r'(\d)', r'.\1.', s) + # if the digit was in 'fc1', 'fc2', it should not be preceded by a dot + res = res.replace('fc.1', 'fc1') + res = res.replace('fc.2', 'fc2') + # same, but there should be in addition a dot before w1, w2, b1 and b2 + res = res.replace('w.1.', '.w1') + res = res.replace('w.2.', '.w2') + res = res.replace('b.1.', '.b1') + res = res.replace('b.2.', '.b2') + # add a dot before each 'weights' and each 'bias' + res = res.replace('weight', '.weight').replace('bias', '.bias') + # add dot after mlp + res = res.replace('mlp', 'mlp.') + # if two dots are inserted, replace them by one (happens if two digits follow each other, or if a digit is followed by b or w) + res = res.replace('..', '.') + return res \ No newline at end of file diff --git a/src/mlcast/models/ldcast/transforms/antialiasing.py b/src/mlcast/models/ldcast/transforms/antialiasing.py new file mode 100644 index 0000000..c6ab4c9 --- /dev/null +++ b/src/mlcast/models/ldcast/transforms/antialiasing.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +class Antialiaser(nn.ModuleDict): + def __init__(self): + + super().__init__() + + # construct the kernel (symmetric in both directions), shape = (5, 5) + (x, y) = torch.meshgrid(torch.arange(-2, 3), torch.arange(-2, 3), indexing = 'ij') + kernel = torch.exp(-0.5*(x**2+y**2)/(0.5**2)) + kernel /= kernel.sum() + + # the convolution will be done on x (shape = (1, autoenc_time_ratio) + spatial_shape) + # so treat the autoenc_time_ratio as one axis of the convolution -> Conv3d + # but we do not want to convolve on this axis, so use kernl_size = 1 and padding = 0 on this axis + self.conv = nn.Conv3d(1, 1, bias = False, kernel_size = (1, 5, 5), padding = (0, 2, 2)) + + # set the weights to be those of the kernel + self.conv.weight = nn.Parameter(kernel[None, None, None], requires_grad = False) + + def forward(self, x): + + # factor is 1 in the bulk of x, but is greater than 1 near the border to accounr that less values were used in the convolution + # recomputed each time because the image shape could change + factor = self.conv(torch.ones(x.shape, device = x.device)) + + return self.conv(x) / factor \ No newline at end of file diff --git a/src/mlcast/models/ldcast/utils.py b/src/mlcast/models/ldcast/utils.py new file mode 100644 index 0000000..f38bd29 --- /dev/null +++ b/src/mlcast/models/ldcast/utils.py @@ -0,0 +1,30 @@ +# from https://github.com/MeteoSwiss/ldcast/blob/master/ldcast/models/utils.py + +import torch +from torch import nn + + +def normalization(channels, norm_type="group", num_groups=32): + if norm_type == "batch": + return nn.BatchNorm3d(channels) + elif norm_type == "group": + return nn.GroupNorm(num_groups=num_groups, num_channels=channels) + elif (not norm_type) or (norm_type.tolower() == "none"): + return nn.Identity() + else: + raise NotImplementedError(norm) + + +def activation(act_type="swish"): + if act_type == "swish": + return nn.SiLU() + elif act_type == "gelu": + return nn.GELU() + elif act_type == "relu": + return nn.ReLU() + elif act_type == "tanh": + return nn.Tanh() + elif not act_type: + return nn.Identity() + else: + raise NotImplementedError(act_type) \ No newline at end of file