From 6e84ad85b56e0a32c881f333ecd19d674db573ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Wed, 29 Jan 2025 16:50:46 +0100
Subject: [PATCH 1/6] Increase default batch size and learning rate

---
 experiments/configs/optim/adamw.yaml | 2 +-
 experiments/configs/optim/psgd.yaml | 2 +-
 experiments/configs/optim/soap.yaml | 2 +-
 experiments/configs/train_ae.yaml | 2 +-
 experiments/configs/train_dm.yaml | 4 ++--
 experiments/configs/train_ldm.yaml | 4 ++--
 experiments/configs/train_lsm.yaml | 2 +-
 experiments/configs/train_sm.yaml | 2 +-
 experiments/train_ae.py | 8 ++++----
 experiments/train_dm.py | 8 ++++----
 experiments/train_ldm.py | 8 ++++----
 experiments/train_lsm.py | 8 ++++----
 experiments/train_sm.py | 8 ++++----
 13 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/experiments/configs/optim/adamw.yaml b/experiments/configs/optim/adamw.yaml
index e86301ea..49dec497 100644
--- a/experiments/configs/optim/adamw.yaml
+++ b/experiments/configs/optim/adamw.yaml
@@ -1,7 +1,7 @@
 name: "${.optimizer}_${.learning_rate}_${.scheduler}"
 optimizer: "adamw"
 betas: [0.9, 0.999]
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/optim/psgd.yaml b/experiments/configs/optim/psgd.yaml
index 6c9d220a..8bbdf262 100644
--- a/experiments/configs/optim/psgd.yaml
+++ b/experiments/configs/optim/psgd.yaml
@@ -3,7 +3,7 @@ optimizer: "psgd"
 betas: [0.9]
 precondition_frequency: 64
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/optim/soap.yaml b/experiments/configs/optim/soap.yaml
index b7c6b890..63ff6e6a 100644
--- a/experiments/configs/optim/soap.yaml
+++ b/experiments/configs/optim/soap.yaml
@@ -3,7 +3,7 @@ optimizer: "soap"
 betas: [0.9, 0.999, 0.999]
 precondition_frequency: 16
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
diff --git a/experiments/configs/train_ae.yaml b/experiments/configs/train_ae.yaml
index 9c1caeb3..b4abf02a 100644
--- a/experiments/configs/train_ae.yaml
+++ b/experiments/configs/train_ae.yaml
@@ -8,7 +8,7 @@ defaults:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
diff --git a/experiments/configs/train_dm.yaml b/experiments/configs/train_dm.yaml
index 34c93105..12516a6f 100644
--- a/experiments/configs/train_dm.yaml
+++ b/experiments/configs/train_dm.yaml
@@ -17,9 +17,9 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999
 
 fork:
   run: null
diff --git a/experiments/configs/train_ldm.yaml b/experiments/configs/train_ldm.yaml
index 1df4dd1a..c16d4721 100644
--- a/experiments/configs/train_ldm.yaml
+++ b/experiments/configs/train_ldm.yaml
@@ -19,9 +19,9 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999
 
 fork:
   run: null
diff --git a/experiments/configs/train_lsm.yaml b/experiments/configs/train_lsm.yaml
index 4af391a2..5bf89695 100644
--- a/experiments/configs/train_lsm.yaml
+++ b/experiments/configs/train_lsm.yaml
@@ -17,7 +17,7 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
diff --git a/experiments/configs/train_sm.yaml b/experiments/configs/train_sm.yaml
index 0336e350..41587964 100644
--- a/experiments/configs/train_sm.yaml
+++ b/experiments/configs/train_sm.yaml
@@ -15,7 +15,7 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
 
 fork:
diff --git a/experiments/train_ae.py b/experiments/train_ae.py
index ddc69e20..eded0efc 100644
--- a/experiments/train_ae.py
+++ b/experiments/train_ae.py
@@ -41,8 +41,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.ae.name}"
 
@@ -96,7 +96,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -173,7 +173,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with autoencoder.no_sync():
diff --git a/experiments/train_dm.py b/experiments/train_dm.py
index 4614099d..f374b477 100644
--- a/experiments/train_dm.py
+++ b/experiments/train_dm.py
@@ -45,8 +45,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.denoiser.name}"
 
@@ -102,7 +102,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -194,7 +194,7 @@ def train(runid: str, cfg: DictConfig):
 
                 average.update_parameters(denoiser.module)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with denoiser.no_sync():
diff --git a/experiments/train_ldm.py b/experiments/train_ldm.py
index c3e4ee0d..f395bbb0 100644
--- a/experiments/train_ldm.py
+++ b/experiments/train_ldm.py
@@ -40,8 +40,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.denoiser.name}"
 
@@ -109,7 +109,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -193,7 +193,7 @@ def train(runid: str, cfg: DictConfig):
 
                 average.update_parameters(denoiser.module)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with denoiser.no_sync():
diff --git a/experiments/train_lsm.py b/experiments/train_lsm.py
index 65b48314..0baae14b 100644
--- a/experiments/train_lsm.py
+++ b/experiments/train_lsm.py
@@ -40,8 +40,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.surrogate.name}"
 
@@ -109,7 +109,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -185,7 +185,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with surrogate.no_sync():
diff --git a/experiments/train_sm.py b/experiments/train_sm.py
index 79ef1974..bb681086 100644
--- a/experiments/train_sm.py
+++ b/experiments/train_sm.py
@@ -45,8 +45,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)
 
     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0
 
     runname = f"{runid}_{cfg.dataset.name}_{cfg.surrogate.name}"
 
@@ -102,7 +102,7 @@ def train(runid: str, cfg: DictConfig):
    train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -186,7 +186,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)
 
-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with surrogate.no_sync():

From 0b5345d543c79619430fae67728798963044429a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Wed, 29 Jan 2025 16:50:49 +0100
Subject: [PATCH 2/6] Update U-Net architecture

---
 experiments/configs/denoiser/unet_deep_medium.yaml | 8 ++++----
 experiments/configs/denoiser/unet_deep_small.yaml | 8 ++++----
 experiments/configs/denoiser/unet_large.yaml | 4 ++--
 experiments/configs/denoiser/unet_medium.yaml | 4 ++--
 experiments/configs/surrogate/unet_deep_medium.yaml | 8 ++++----
 experiments/configs/surrogate/unet_deep_small.yaml | 8 ++++----
 experiments/configs/surrogate/unet_large.yaml | 4 ++--
 experiments/configs/surrogate/unet_medium.yaml | 4 ++--
 lpdm/nn/unet.py | 1 +
 9 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/experiments/configs/denoiser/unet_deep_medium.yaml b/experiments/configs/denoiser/unet_deep_medium.yaml
index 303bb6b4..1c9f834a 100644
--- a/experiments/configs/denoiser/unet_deep_medium.yaml
+++ b/experiments/configs/denoiser/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/denoiser/unet_deep_small.yaml b/experiments/configs/denoiser/unet_deep_small.yaml
index 761f5827..b60b0ec1 100644
--- a/experiments/configs/denoiser/unet_deep_small.yaml
+++ b/experiments/configs/denoiser/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/denoiser/unet_large.yaml b/experiments/configs/denoiser/unet_large.yaml
index 9446ea69..72c1e4be 100644
--- a/experiments/configs/denoiser/unet_large.yaml
+++ b/experiments/configs/denoiser/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/denoiser/unet_medium.yaml b/experiments/configs/denoiser/unet_medium.yaml
index dd47cfac..24138189 100644
--- a/experiments/configs/denoiser/unet_medium.yaml
+++ b/experiments/configs/denoiser/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/surrogate/unet_deep_medium.yaml b/experiments/configs/surrogate/unet_deep_medium.yaml
index 303bb6b4..1c9f834a 100644
--- a/experiments/configs/surrogate/unet_deep_medium.yaml
+++ b/experiments/configs/surrogate/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/surrogate/unet_deep_small.yaml b/experiments/configs/surrogate/unet_deep_small.yaml
index 761f5827..b60b0ec1 100644
--- a/experiments/configs/surrogate/unet_deep_small.yaml
+++ b/experiments/configs/surrogate/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
diff --git a/experiments/configs/surrogate/unet_large.yaml b/experiments/configs/surrogate/unet_large.yaml
index 9446ea69..72c1e4be 100644
--- a/experiments/configs/surrogate/unet_large.yaml
+++ b/experiments/configs/surrogate/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/experiments/configs/surrogate/unet_medium.yaml b/experiments/configs/surrogate/unet_medium.yaml
index dd47cfac..24138189 100644
--- a/experiments/configs/surrogate/unet_medium.yaml
+++ b/experiments/configs/surrogate/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
diff --git a/lpdm/nn/unet.py b/lpdm/nn/unet.py
index c2393ec5..c107971e 100644
--- a/lpdm/nn/unet.py
+++ b/lpdm/nn/unet.py
@@ -62,6 +62,7 @@ def __init__(
             self.attn = nn.Identity()
         else:
             self.attn = Residual(
+                LayerNorm(dim=-spatial - 1),
                 SelfAttentionNd(channels, heads=attention_heads),
             )
 

From 64f642ca9398af3b1da6d14a2dca2020206fae86 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Wed, 29 Jan 2025 15:30:16 +0100
Subject: [PATCH 3/6] Add ViT-H config

---
 experiments/configs/denoiser/vit_huge.yaml | 12 ++++++++++++
 experiments/configs/surrogate/vit_huge.yaml | 12 ++++++++++++
 2 files changed, 24 insertions(+)
 create mode 100644 experiments/configs/denoiser/vit_huge.yaml
 create mode 100644 experiments/configs/surrogate/vit_huge.yaml

diff --git a/experiments/configs/denoiser/vit_huge.yaml b/experiments/configs/denoiser/vit_huge.yaml
new file mode 100644
index 00000000..fec99e90
--- /dev/null
+++ b/experiments/configs/denoiser/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false
diff --git a/experiments/configs/surrogate/vit_huge.yaml b/experiments/configs/surrogate/vit_huge.yaml
new file mode 100644
index 00000000..fec99e90
--- /dev/null
+++ b/experiments/configs/surrogate/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false

From e70544cd854e9bfcc01acea5966acc0e2639eb04 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Wed, 29 Jan 2025 14:58:27 +0100
Subject: [PATCH 4/6] Use MSE objective by default

---
 experiments/configs/train_ae.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/experiments/configs/train_ae.yaml b/experiments/configs/train_ae.yaml
index b4abf02a..9ee9c57e 100644
--- a/experiments/configs/train_ae.yaml
+++ b/experiments/configs/train_ae.yaml
@@ -1,6 +1,6 @@
 defaults:
   - ae: f32c64_medium
-  - ae/loss: mae
+  - ae/loss: mse
   - dataset: euler_all
   - optim: psgd
   - server: rusty

From ae0b8c10a8eece7d43fe4d529cf778489864ff9e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?=
Date: Wed, 29 Jan 2025 15:28:10 +0100
Subject: [PATCH 5/6] Enable loose forking

---
 experiments/configs/train_ae.yaml | 1 +
 experiments/configs/train_dm.yaml | 1 +
 experiments/configs/train_ldm.yaml | 1 +
 experiments/configs/train_lsm.yaml | 1 +
 experiments/configs/train_sm.yaml | 1 +
 experiments/train_ae.py | 2 +-
 experiments/train_dm.py | 2 +-
 experiments/train_ldm.py | 2 +-
 experiments/train_lsm.py | 2 +-
 experiments/train_sm.py | 2 +-
 10 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/experiments/configs/train_ae.yaml b/experiments/configs/train_ae.yaml
index 9ee9c57e..59cc3afc 100644
--- a/experiments/configs/train_ae.yaml
+++ b/experiments/configs/train_ae.yaml
@@ -14,6 +14,7 @@ train:
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_dm.yaml b/experiments/configs/train_dm.yaml
index 12516a6f..e89b609b 100644
--- a/experiments/configs/train_dm.yaml
+++ b/experiments/configs/train_dm.yaml
@@ -24,6 +24,7 @@ train:
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   nodes: 1
diff --git a/experiments/configs/train_ldm.yaml b/experiments/configs/train_ldm.yaml
index c16d4721..4e8d61ac 100644
--- a/experiments/configs/train_ldm.yaml
+++ b/experiments/configs/train_ldm.yaml
@@ -26,6 +26,7 @@ train:
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_lsm.yaml b/experiments/configs/train_lsm.yaml
index 5bf89695..176a0f01 100644
--- a/experiments/configs/train_lsm.yaml
+++ b/experiments/configs/train_lsm.yaml
@@ -23,6 +23,7 @@ train:
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   cpus_per_gpu: 8
diff --git a/experiments/configs/train_sm.yaml b/experiments/configs/train_sm.yaml
index 41587964..8c94ce88 100644
--- a/experiments/configs/train_sm.yaml
+++ b/experiments/configs/train_sm.yaml
@@ -21,6 +21,7 @@ train:
 fork:
   run: null
   target: "state"
+  strict: true
 
 compute:
   nodes: 1
diff --git a/experiments/train_ae.py b/experiments/train_ae.py
index eded0efc..697968a8 100644
--- a/experiments/train_ae.py
+++ b/experiments/train_ae.py
@@ -123,7 +123,7 @@ def train(runid: str, cfg: DictConfig):
     autoencoder_loss = WeightedLoss(**cfg.ae.loss).to(device)
 
     if cfg.fork.run is not None:
-        autoencoder.load_state_dict(stem_state)
+        autoencoder.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     autoencoder = DistributedDataParallel(
diff --git a/experiments/train_dm.py b/experiments/train_dm.py
index f374b477..435bf7cf 100644
--- a/experiments/train_dm.py
+++ b/experiments/train_dm.py
@@ -134,7 +134,7 @@ def train(runid: str, cfg: DictConfig):
     denoiser_loss = DenoiserLoss(**cfg.denoiser.loss).to(device)
 
     if cfg.fork.run is not None:
-        denoiser.load_state_dict(stem_state)
+        denoiser.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     denoiser = DistributedDataParallel(
diff --git a/experiments/train_ldm.py b/experiments/train_ldm.py
index f395bbb0..437cf061 100644
--- a/experiments/train_ldm.py
+++ b/experiments/train_ldm.py
@@ -134,7 +134,7 @@ def train(runid: str, cfg: DictConfig):
     denoiser_loss = DenoiserLoss(**cfg.denoiser.loss).to(device)
 
     if cfg.fork.run is not None:
-        denoiser.load_state_dict(stem_state)
+        denoiser.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     denoiser = DistributedDataParallel(
diff --git a/experiments/train_lsm.py b/experiments/train_lsm.py
index 0baae14b..f874e2fe 100644
--- a/experiments/train_lsm.py
+++ b/experiments/train_lsm.py
@@ -131,7 +131,7 @@ def train(runid: str, cfg: DictConfig):
     ).to(device)
 
     if cfg.fork.run is not None:
-        surrogate.load_state_dict(stem_state)
+        surrogate.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     surrogate = DistributedDataParallel(
diff --git a/experiments/train_sm.py b/experiments/train_sm.py
index bb681086..490b72d9 100644
--- a/experiments/train_sm.py
+++ b/experiments/train_sm.py
@@ -131,7 +131,7 @@ def train(runid: str, cfg: DictConfig):
     ).to(device)
 
     if cfg.fork.run is not None:
-        surrogate.load_state_dict(stem_state)
+        surrogate.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state
 
     surrogate = DistributedDataParallel(

From 24517a5957d542ded24e0fe9eb0e7279ab3eb6c4 Mon Sep 17 00:00:00 2001
From: Helen Qu
Date: Wed, 29 Jan 2025 14:11:40 -0500
Subject: [PATCH 6/6] add support for time dim

---
 experiments/get_stats.py | 2 +-
 experiments/train_ae.py | 13 ++++++++-----
 lpdm/nn/attention.py | 3 ++-
 lpdm/nn/layers.py | 4 ++--
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/experiments/get_stats.py b/experiments/get_stats.py
index e36306a1..2ad134f0 100644
--- a/experiments/get_stats.py
+++ b/experiments/get_stats.py
@@ -90,7 +90,7 @@ def get_stats(cfg: DictConfig):
     # Job
     dawgz.schedule(
         dawgz.job(
-            f=partial(get_stats, cfg.dataset, args.samples),
+            f=partial(get_stats, cfg),
             name="stats",
             cpus=cfg.compute.cpus,
             gpus=cfg.compute.gpus,
diff --git a/experiments/train_ae.py b/experiments/train_ae.py
index 697968a8..b236e5ce 100644
--- a/experiments/train_ae.py
+++ b/experiments/train_ae.py
@@ -3,6 +3,7 @@
 import argparse
 import dawgz
 import wandb
+import os
 
 from functools import partial
 from omegaconf import DictConfig
@@ -86,7 +87,7 @@ def train(runid: str, cfg: DictConfig):
             path=cfg.server.datasets,
             physics=cfg.dataset.physics,
             split=split,
-            steps=1,
+            steps=cfg.dataset.steps,
             include_filters=cfg.dataset.include_filters,
             augment=cfg.dataset.augment,
         )
@@ -119,6 +120,8 @@ def train(runid: str, cfg: DictConfig):
         pix_channels=dataset["train"].metadata.n_fields,
         **cfg.ae,
     ).to(device)
+    if rank == 0:
+        print(autoencoder)
 
     autoencoder_loss = WeightedLoss(**cfg.ae.loss).to(device)
 
@@ -138,7 +141,7 @@ def train(runid: str, cfg: DictConfig):
     )
 
     # W&B
-    if rank == 0:
+    if rank == 0 and not cfg.wandb.dry_run:
         run = wandb.init(
             entity=cfg.wandb.entity,
             project="mpp-ae",
@@ -162,7 +165,7 @@ def train(runid: str, cfg: DictConfig):
         for i in range(cfg.train.epoch_size // cfg.train.batch_size):
             x, _ = get_well_inputs(next(train_loader), device=device)
             x = preprocess(x)
-            x = rearrange(x, "B 1 H W C -> B C H W")
+            x = rearrange(x, "B T H W C -> B C T H W")
 
             if (i + 1) % cfg.train.accumulation == 0:
                 y, z = autoencoder(x)
@@ -221,7 +224,7 @@ def train(runid: str, cfg: DictConfig):
             for _ in range(cfg.train.epoch_size // cfg.train.batch_size):
                 x, _ = get_well_inputs(next(valid_loader), device=device)
                 x = preprocess(x)
-                x = rearrange(x, "B 1 H W C -> B C H W")
+                x = rearrange(x, "B T H W C -> B C T H W")
 
                 y, z = autoencoder(x)
                 loss = autoencoder_loss(x, y)
@@ -281,7 +284,7 @@ def train(runid: str, cfg: DictConfig):
 
     # Config
     cfg = compose(
-        config_file="./configs/train_ae.yaml",
+        config_file=os.path.join(os.path.dirname(__file__), "configs", "train_ae.yaml"),
         overrides=args.overrides,
     )
 
diff --git a/lpdm/nn/attention.py b/lpdm/nn/attention.py
index a068f375..10e19399 100644
--- a/lpdm/nn/attention.py
+++ b/lpdm/nn/attention.py
@@ -8,7 +8,8 @@
 import torch.nn as nn
 import warnings
 
-with warnings.catch_warnings(action="ignore"):
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
     import xformers.components.attention.core as xfa
     import xformers.sparse as xfs
 
diff --git a/lpdm/nn/layers.py b/lpdm/nn/layers.py
index 2036d252..0729bb15 100644
--- a/lpdm/nn/layers.py
+++ b/lpdm/nn/layers.py
@@ -194,8 +194,8 @@ def forward(self, x: Tensor) -> Tensor:
             h, w = self.patch_size
             return rearrange(x, "... C (H h) (W w) -> ... (C h w) H W", h=h, w=w)
         elif len(self.patch_size) == 3:
-            l, h, w = self.patch_size
-            return rearrange(x, "... C (L l) (H h) (W w) -> ... (C l h w) L H W", l=l, h=h, w=w)
+            t, h, w = self.patch_size
+            return rearrange(x, "... C (T t) (H h) (W w) -> ... (C t h w) T H W", t=t, h=h, w=w)
         else:
             raise NotImplementedError()
 
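
A note on the batch-size semantics introduced by PATCH 1/6: after this change, cfg.train.batch_size denotes the effective number of samples per optimizer update, and the per-device data-loader batch size is derived from it by dividing by the accumulation factor and the number of GPUs. The sketch below is not part of the patches; it uses illustrative values, and world_size stands in for the distributed process count rather than a config entry.

    # Illustrative values; cfg.train.batch_size is now the effective batch size per update.
    batch_size = 256      # cfg.train.batch_size
    accumulation = 1      # cfg.train.accumulation
    world_size = 8        # number of GPUs (from the distributed setup)
    epoch_size = 16384    # cfg.train.epoch_size

    # The two assertions added in PATCH 1/6.
    assert epoch_size % batch_size == 0
    assert batch_size % (accumulation * world_size) == 0

    # Per-GPU loader batch size and optimizer updates per epoch implied by these settings.
    loader_batch_size = batch_size // accumulation // world_size
    updates_per_epoch = epoch_size // batch_size
    print(loader_batch_size, updates_per_epoch)  # 32 64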
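PATCH 5/6 makes forking "loose" by threading a cfg.fork.strict flag into load_state_dict, so a checkpoint whose keys no longer match the current architecture (for example after the U-Net changes in PATCH 2/6) can still seed a new run. A minimal sketch of what strict=False buys, using a toy torch module made up for illustration rather than the repository's models:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    # Partial checkpoint: only the first layer's parameters are present.
    stem_state = {"0.weight": torch.zeros(8, 4), "0.bias": torch.zeros(8)}

    # strict=False tolerates missing and unexpected keys instead of raising.
    result = model.load_state_dict(stem_state, strict=False)
    print(result.missing_keys)  # ['1.weight', '1.bias']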
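The change to lpdm/nn/attention.py in PATCH 6/6 swaps warnings.catch_warnings(action="ignore"), a keyword that only exists on Python 3.11 and later, for the older portable idiom. A small stand-alone sketch of that idiom; the warning below is a stand-in for whatever the xformers import emits:

    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        warnings.warn("noisy import")  # silenced inside the block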
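The lpdm/nn/layers.py change in PATCH 6/6 extends patchification to a leading time axis. The following sketch applies the same einops pattern outside the repository's patching layer; the tensor shape and patch size are made up for illustration:

    import torch
    from einops import rearrange

    x = torch.randn(2, 4, 8, 16, 16)  # (batch, C, T, H, W)
    t, h, w = 2, 4, 4                 # patch_size = [2, 4, 4]

    # Fold each (t, h, w) block into the channel dimension.
    y = rearrange(x, "... C (T t) (H h) (W w) -> ... (C t h w) T H W", t=t, h=h, w=w)
    print(y.shape)  # torch.Size([2, 128, 4, 4, 4])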