diff --git a/experiments/configs/denoiser/unet_deep_medium.yaml b/experiments/configs/denoiser/unet_deep_medium.yaml
index 303bb6b4..1c9f834a 100644
--- a/experiments/configs/denoiser/unet_deep_medium.yaml
+++ b/experiments/configs/denoiser/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/denoiser/unet_deep_small.yaml b/experiments/configs/denoiser/unet_deep_small.yaml
index 761f5827..b60b0ec1 100644
--- a/experiments/configs/denoiser/unet_deep_small.yaml
+++ b/experiments/configs/denoiser/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/denoiser/unet_large.yaml b/experiments/configs/denoiser/unet_large.yaml
index 9446ea69..72c1e4be 100644
--- a/experiments/configs/denoiser/unet_large.yaml
+++ b/experiments/configs/denoiser/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/denoiser/unet_medium.yaml b/experiments/configs/denoiser/unet_medium.yaml
index dd47cfac..24138189 100644
--- a/experiments/configs/denoiser/unet_medium.yaml
+++ b/experiments/configs/denoiser/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/denoiser/vit_huge.yaml b/experiments/configs/denoiser/vit_huge.yaml
new file mode 100644
index 00000000..fec99e90
--- /dev/null
+++ b/experiments/configs/denoiser/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false
diff --git a/experiments/configs/optim/adamw.yaml b/experiments/configs/optim/adamw.yaml
index e86301ea..49dec497 100644
--- a/experiments/configs/optim/adamw.yaml
+++ b/experiments/configs/optim/adamw.yaml
@@ -1,7 +1,7 @@
 name: "${.optimizer}_${.learning_rate}_${.scheduler}"
 optimizer: "adamw"
 betas: [0.9, 0.999]
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/optim/psgd.yaml b/experiments/configs/optim/psgd.yaml
index 6c9d220a..8bbdf262 100644
--- a/experiments/configs/optim/psgd.yaml
+++ b/experiments/configs/optim/psgd.yaml
@@ -3,7 +3,7 @@ optimizer: "psgd"
 betas: [0.9]
 precondition_frequency: 64
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/optim/soap.yaml b/experiments/configs/optim/soap.yaml
index b7c6b890..63ff6e6a 100644
--- a/experiments/configs/optim/soap.yaml
+++ b/experiments/configs/optim/soap.yaml
@@ -3,7 +3,7 @@ optimizer: "soap"
 betas: [0.9, 0.999, 0.999]
 precondition_frequency: 16
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/surrogate/unet_deep_medium.yaml b/experiments/configs/surrogate/unet_deep_medium.yaml
index 303bb6b4..1c9f834a 100644
--- a/experiments/configs/surrogate/unet_deep_medium.yaml
+++ b/experiments/configs/surrogate/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/surrogate/unet_deep_small.yaml b/experiments/configs/surrogate/unet_deep_small.yaml
index 761f5827..b60b0ec1 100644
--- a/experiments/configs/surrogate/unet_deep_small.yaml
+++ b/experiments/configs/surrogate/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/surrogate/unet_large.yaml b/experiments/configs/surrogate/unet_large.yaml
index 9446ea69..72c1e4be 100644
--- a/experiments/configs/surrogate/unet_large.yaml
+++ b/experiments/configs/surrogate/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/surrogate/unet_medium.yaml b/experiments/configs/surrogate/unet_medium.yaml
index dd47cfac..24138189 100644
--- a/experiments/configs/surrogate/unet_medium.yaml
+++ b/experiments/configs/surrogate/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/surrogate/vit_huge.yaml b/experiments/configs/surrogate/vit_huge.yaml
new file mode 100644
index 00000000..fec99e90
--- /dev/null
+++ b/experiments/configs/surrogate/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false
diff --git a/experiments/configs/train_ae.yaml b/experiments/configs/train_ae.yaml
index 9c1caeb3..59cc3afc 100644
--- a/experiments/configs/train_ae.yaml
+++ b/experiments/configs/train_ae.yaml
@@ -1,6 +1,6 @@
 defaults:
   - ae: f32c64_medium
-  - ae/loss: mae
+  - ae/loss: mse
   - dataset: euler_all
   - optim: psgd
   - server: rusty
@@ -8,12 +8,13 @@ defaults:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_dm.yaml b/experiments/configs/train_dm.yaml
index 34c93105..e89b609b 100644
--- a/experiments/configs/train_dm.yaml
+++ b/experiments/configs/train_dm.yaml
@@ -17,13 +17,14 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999
 
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   nodes: 1
diff --git a/experiments/configs/train_ldm.yaml b/experiments/configs/train_ldm.yaml
index 1df4dd1a..4e8d61ac 100644
--- a/experiments/configs/train_ldm.yaml
+++ b/experiments/configs/train_ldm.yaml
@@ -19,13 +19,14 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999
 
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_lsm.yaml b/experiments/configs/train_lsm.yaml
index 4af391a2..176a0f01 100644
--- a/experiments/configs/train_lsm.yaml
+++ b/experiments/configs/train_lsm.yaml
@@ -17,12 +17,13 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_sm.yaml b/experiments/configs/train_sm.yaml
index 0336e350..8c94ce88 100644
--- a/experiments/configs/train_sm.yaml
+++ b/experiments/configs/train_sm.yaml
@@ -15,12 +15,13 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   nodes: 1
diff --git a/experiments/get_stats.py b/experiments/get_stats.py
index e36306a1..2ad134f0 100644
--- a/experiments/get_stats.py
+++ b/experiments/get_stats.py
@@ -90,7 +90,7 @@ def get_stats(cfg: DictConfig):
     # Job
     dawgz.schedule(
         dawgz.job(
-            f=partial(get_stats, cfg.dataset, args.samples),
+            f=partial(get_stats, cfg),
            name="stats",
            cpus=cfg.compute.cpus,
            gpus=cfg.compute.gpus,
diff --git a/experiments/train_ae.py b/experiments/train_ae.py
index ddc69e20..b236e5ce 100644
--- a/experiments/train_ae.py
+++ b/experiments/train_ae.py
@@ -3,6 +3,7 @@
 import argparse
 import dawgz
 import wandb
+import os
 
 from functools import partial
 from omegaconf import DictConfig
@@ -41,8 +42,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.ae.name}"
 
@@ -86,7 +87,7 @@ def train(runid: str, cfg: DictConfig):
             path=cfg.server.datasets,
             physics=cfg.dataset.physics,
             split=split,
-            steps=1,
+            steps=cfg.dataset.steps,
             include_filters=cfg.dataset.include_filters,
             augment=cfg.dataset.augment,
         )
@@ -96,7 +97,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -119,11 +120,13 @@ def train(runid: str, cfg: DictConfig):
         pix_channels=dataset["train"].metadata.n_fields,
         **cfg.ae,
     ).to(device)
+    if rank == 0:
+        print(autoencoder)
 
     autoencoder_loss = WeightedLoss(**cfg.ae.loss).to(device)
 
     if cfg.fork.run is not None:
-        autoencoder.load_state_dict(stem_state)
+        autoencoder.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     autoencoder = DistributedDataParallel(
@@ -138,7 +141,7 @@ def train(runid: str, cfg: DictConfig):
     )
 
     # W&B
-    if rank == 0:
+    if rank == 0 and not cfg.wandb.dry_run:
         run = wandb.init(
             entity=cfg.wandb.entity,
             project="mpp-ae",
@@ -162,7 +165,7 @@ def train(runid: str, cfg: DictConfig):
         for i in range(cfg.train.epoch_size // cfg.train.batch_size):
             x, _ = get_well_inputs(next(train_loader), device=device)
             x = preprocess(x)
-            x = rearrange(x, "B 1 H W C -> B C H W")
+            x = rearrange(x, "B T H W C -> B C T H W")
 
             if (i + 1) % cfg.train.accumulation == 0:
                 y, z = autoencoder(x)
@@ -173,7 +176,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with autoencoder.no_sync():
@@ -221,7 +224,7 @@ def train(runid: str, cfg: DictConfig):
             for _ in range(cfg.train.epoch_size // cfg.train.batch_size):
                 x, _ = get_well_inputs(next(valid_loader), device=device)
                 x = preprocess(x)
-                x = rearrange(x, "B 1 H W C -> B C H W")
+                x = rearrange(x, "B T H W C -> B C T H W")
 
                 y, z = autoencoder(x)
                 loss = autoencoder_loss(x, y)
@@ -281,7 +284,7 @@ def train(runid: str, cfg: DictConfig):
 
     # Config
     cfg = compose(
-        config_file="./configs/train_ae.yaml",
+        config_file=os.path.join(os.path.dirname(__file__), "configs", "train_ae.yaml"),
         overrides=args.overrides,
     )
 
diff --git a/experiments/train_dm.py b/experiments/train_dm.py
index 4614099d..435bf7cf 100644
--- a/experiments/train_dm.py
+++ b/experiments/train_dm.py
@@ -45,8 +45,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.denoiser.name}"
 
@@ -102,7 +102,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -134,7 +134,7 @@ def train(runid: str, cfg: DictConfig):
     denoiser_loss = DenoiserLoss(**cfg.denoiser.loss).to(device)
 
     if cfg.fork.run is not None:
-        denoiser.load_state_dict(stem_state)
+        denoiser.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     denoiser = DistributedDataParallel(
@@ -194,7 +194,7 @@ def train(runid: str, cfg: DictConfig):
 
                 average.update_parameters(denoiser.module)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with denoiser.no_sync():
diff --git a/experiments/train_ldm.py b/experiments/train_ldm.py
index c3e4ee0d..437cf061 100644
--- a/experiments/train_ldm.py
+++ b/experiments/train_ldm.py
@@ -40,8 +40,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.denoiser.name}"
 
@@ -109,7 +109,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -134,7 +134,7 @@ def train(runid: str, cfg: DictConfig):
     denoiser_loss = DenoiserLoss(**cfg.denoiser.loss).to(device)
 
     if cfg.fork.run is not None:
-        denoiser.load_state_dict(stem_state)
+        denoiser.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     denoiser = DistributedDataParallel(
@@ -193,7 +193,7 @@ def train(runid: str, cfg: DictConfig):
 
                 average.update_parameters(denoiser.module)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with denoiser.no_sync():
diff --git a/experiments/train_lsm.py b/experiments/train_lsm.py
index 65b48314..f874e2fe 100644
--- a/experiments/train_lsm.py
+++ b/experiments/train_lsm.py
@@ -40,8 +40,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.surrogate.name}"
 
@@ -109,7 +109,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -131,7 +131,7 @@ def train(runid: str, cfg: DictConfig):
     ).to(device)
 
     if cfg.fork.run is not None:
-        surrogate.load_state_dict(stem_state)
+        surrogate.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     surrogate = DistributedDataParallel(
@@ -185,7 +185,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with surrogate.no_sync():
diff --git a/experiments/train_sm.py b/experiments/train_sm.py
index 79ef1974..490b72d9 100644
--- a/experiments/train_sm.py
+++ b/experiments/train_sm.py
@@ -45,8 +45,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.surrogate.name}"
 
@@ -102,7 +102,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -131,7 +131,7 @@ def train(runid: str, cfg: DictConfig):
     ).to(device)
 
     if cfg.fork.run is not None:
-        surrogate.load_state_dict(stem_state)
+        surrogate.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     surrogate = DistributedDataParallel(
@@ -186,7 +186,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with surrogate.no_sync():
diff --git a/lpdm/nn/attention.py b/lpdm/nn/attention.py
index a068f375..10e19399 100644
--- a/lpdm/nn/attention.py
+++ b/lpdm/nn/attention.py
@@ -8,7 +8,8 @@
 import torch.nn as nn
 import warnings
 
-with warnings.catch_warnings(action="ignore"):
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
     import xformers.components.attention.core as xfa
     import xformers.sparse as xfs
 
diff --git a/lpdm/nn/layers.py b/lpdm/nn/layers.py
index 2036d252..0729bb15 100644
--- a/lpdm/nn/layers.py
+++ b/lpdm/nn/layers.py
@@ -194,8 +194,8 @@ def forward(self, x: Tensor) -> Tensor:
             h, w = self.patch_size
             return rearrange(x, "... C (H h) (W w) -> ... (C h w) H W", h=h, w=w)
         elif len(self.patch_size) == 3:
-            l, h, w = self.patch_size
-            return rearrange(x, "... C (L l) (H h) (W w) -> ... (C l h w) L H W", l=l, h=h, w=w)
+            t, h, w = self.patch_size
+            return rearrange(x, "... C (T t) (H h) (W w) -> ... (C t h w) T H W", t=t, h=h, w=w)
         else:
             raise NotImplementedError()
 
diff --git a/lpdm/nn/unet.py b/lpdm/nn/unet.py
index c2393ec5..c107971e 100644
--- a/lpdm/nn/unet.py
+++ b/lpdm/nn/unet.py
@@ -62,6 +62,7 @@ def __init__(
             self.attn = nn.Identity()
         else:
             self.attn = Residual(
+                LayerNorm(dim=-spatial - 1),
                 SelfAttentionNd(channels, heads=attention_heads),
             )
 