8 changes: 4 additions & 4 deletions experiments/configs/denoiser/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
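The deep U-Net loses one resolution stage here (hid_channels drops from 7 to 6 entries), so the attention_heads key moves from 6 to 5, presumably to keep attention on the deepest stage, and the boundary handling switches to periodic. A minimal consistency check for this kind of config, assuming the attention_heads keys are stage indices (plain PyYAML sketch, not code from this repository):

import yaml

# Load the updated denoiser config and verify that every attention_heads
# key still refers to an existing U-Net stage.
with open("experiments/configs/denoiser/unet_deep_medium.yaml") as f:
    cfg = yaml.safe_load(f)

n_stages = len(cfg["hid_channels"])        # 6 stages after this change
assert len(cfg["hid_blocks"]) == n_stages  # one block count per stage

for stage in cfg["attention_heads"]:       # {5: 8} -> attention at the last stage
    assert 0 <= stage < n_stages, f"stage {stage} does not exist"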
8 changes: 4 additions & 4 deletions experiments/configs/denoiser/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
4 changes: 2 additions & 2 deletions experiments/configs/denoiser/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
4 changes: 2 additions & 2 deletions experiments/configs/denoiser/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
12 changes: 12 additions & 0 deletions experiments/configs/denoiser/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false
2 changes: 1 addition & 1 deletion experiments/configs/optim/adamw.yaml
@@ -1,7 +1,7 @@
 name: "${.optimizer}_${.learning_rate}_${.scheduler}"
 optimizer: "adamw"
 betas: [0.9, 0.999]
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
2 changes: 1 addition & 1 deletion experiments/configs/optim/psgd.yaml
@@ -3,7 +3,7 @@ optimizer: "psgd"
 betas: [0.9]
 precondition_frequency: 64
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
2 changes: 1 addition & 1 deletion experiments/configs/optim/soap.yaml
@@ -3,7 +3,7 @@ optimizer: "soap"
 betas: [0.9, 0.999, 0.999]
 precondition_frequency: 16
 precondition_dim: 4096
-learning_rate: 1e-5
+learning_rate: 3e-5
 weight_decay: 0.0
 warmup: 0
 scheduler: "cosine"
8 changes: 4 additions & 4 deletions experiments/configs/surrogate/unet_deep_medium.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_medium"
 arch: "unet"
 emb_features: 256
-hid_channels: [64, 128, 256, 384, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [64, 128, 256, 512, 768, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
8 changes: 4 additions & 4 deletions experiments/configs/surrogate/unet_deep_small.yaml
@@ -1,12 +1,12 @@
 name: "unet_deep_small"
 arch: "unet"
 emb_features: 256
-hid_channels: [32, 64, 128, 256, 512, 768, 1024]
-hid_blocks: [1, 2, 3, 3, 3, 3, 3]
+hid_channels: [32, 64, 128, 256, 512, 1024]
+hid_blocks: [1, 2, 3, 3, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {6: 8}
-periodic: false
+attention_heads: {5: 8}
+periodic: true
 dropout: 0.05
 checkpointing: true
 identity_init: true
4 changes: 2 additions & 2 deletions experiments/configs/surrogate/unet_large.yaml
@@ -2,10 +2,10 @@ name: "unet_large"
 arch: "unet"
 emb_features: 256
 hid_channels: [256, 512, 768, 1024]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 6, 3: 8}
+attention_heads: {3: 8}
 periodic: false
 dropout: 0.05
 checkpointing: false
4 changes: 2 additions & 2 deletions experiments/configs/surrogate/unet_medium.yaml
@@ -2,10 +2,10 @@ name: "unet_medium"
 arch: "unet"
 emb_features: 256
 hid_channels: [128, 256, 384, 512]
-hid_blocks: [3, 3, 3, 3]
+hid_blocks: [1, 2, 3, 3]
 stride: [1, 2, 2]
 norm: "layer"
-attention_heads: {2: 3, 3: 4}
+attention_heads: {3: 4}
 periodic: false
 dropout: 0.05
 checkpointing: false
12 changes: 12 additions & 0 deletions experiments/configs/surrogate/vit_huge.yaml
@@ -0,0 +1,12 @@
+name: "vit_huge"
+arch: "vit"
+emb_features: 256
+hid_channels: 2048
+hid_blocks: 16
+attention_heads: 16
+qk_norm: true
+rope: true
+patch_size: [1, 1, 1]
+window_size: null
+dropout: 0.05
+checkpointing: false
5 changes: 3 additions & 2 deletions experiments/configs/train_ae.yaml
@@ -1,19 +1,20 @@
 defaults:
   - ae: f32c64_medium
-  - ae/loss: mae
+  - ae/loss: mse
   - dataset: euler_all
   - optim: psgd
   - server: rusty

 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1

 fork:
   run: null
   target: "state"
+  strict: true

 compute:
   cpus_per_gpu: 8
5 changes: 3 additions & 2 deletions experiments/configs/train_dm.yaml
@@ -17,13 +17,14 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999

 fork:
   run: null
   target: "state"
+  strict: true

 compute:
   nodes: 1
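Besides the larger batch, ema_decay drops from 0.9999 to 0.999, which shortens the effective averaging window of the EMA weights by roughly a factor of ten (the time constant of an exponential moving average is about 1/(1 - decay) updates). A rough illustration of that relationship, not code from this repository:

# Effective horizon of an exponential moving average of the weights:
# contributions older than ~1/(1 - decay) optimizer updates are heavily discounted.
for decay in (0.9999, 0.999):
    horizon = 1.0 / (1.0 - decay)
    print(f"decay={decay}: ~{horizon:.0f} updates")
# decay=0.9999: ~10000 updates
# decay=0.999: ~1000 updates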
5 changes: 3 additions & 2 deletions experiments/configs/train_ldm.yaml
@@ -19,13 +19,14 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1
-  ema_decay: 0.9999
+  ema_decay: 0.999

 fork:
   run: null
   target: "state"
+  strict: true

 compute:
   cpus_per_gpu: 8
3 changes: 2 additions & 1 deletion experiments/configs/train_lsm.yaml
@@ -17,12 +17,13 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1

 fork:
   run: null
   target: "state"
+  strict: true

 compute:
   cpus_per_gpu: 8
3 changes: 2 additions & 1 deletion experiments/configs/train_sm.yaml
@@ -15,12 +15,13 @@ trajectory:
 train:
   epochs: 1024
   epoch_size: 16384
-  batch_size: 64
+  batch_size: 256
   accumulation: 1

 fork:
   run: null
   target: "state"
+  strict: true

 compute:
   nodes: 1
2 changes: 1 addition & 1 deletion experiments/get_stats.py
@@ -90,7 +90,7 @@ def get_stats(cfg: DictConfig):
     # Job
     dawgz.schedule(
         dawgz.job(
-            f=partial(get_stats, cfg.dataset, args.samples),
+            f=partial(get_stats, cfg),
             name="stats",
             cpus=cfg.compute.cpus,
            gpus=cfg.compute.gpus,

Collaborator commented on f=partial(get_stats, cfg):
Good catch! Introduced in a51de35.
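Since get_stats is declared as get_stats(cfg: DictConfig), the scheduled callable must bind the whole config object rather than (cfg.dataset, args.samples). A minimal, self-contained sketch of that binding with functools.partial (the stand-in objects are hypothetical, not this repository's code):

from functools import partial
from types import SimpleNamespace

def get_stats(cfg):
    # Stand-in for the real function, which takes the full config.
    print("computing stats for", cfg.dataset.name)

cfg = SimpleNamespace(dataset=SimpleNamespace(name="euler_all"))

job = partial(get_stats, cfg)  # bind the single argument the signature expects
job()                          # -> computing stats for euler_all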
23 changes: 13 additions & 10 deletions experiments/train_ae.py
@@ -3,6 +3,7 @@
 import argparse
 import dawgz
 import wandb
+import os

 from functools import partial
 from omegaconf import DictConfig
@@ -41,8 +42,8 @@ def train(runid: str, cfg: DictConfig):
     torch.cuda.set_device(device)

     # Config
-    assert cfg.train.batch_size % world_size == 0
-    assert cfg.train.epoch_size % (cfg.train.batch_size * cfg.train.accumulation) == 0
+    assert cfg.train.epoch_size % cfg.train.batch_size == 0
+    assert cfg.train.batch_size % (cfg.train.accumulation * world_size) == 0

     runname = f"{runid}_{cfg.dataset.name}_{cfg.ae.name}"

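Under the new convention, cfg.train.batch_size is the global batch per optimizer update (the sample counter later in this file adds cfg.train.batch_size once per update), and the per-device micro-batch becomes batch_size // accumulation // world_size, matching the dataloader change further down. A worked check with the values from train_ae.yaml; the world size of 4 is an illustrative assumption:

# Values from train_ae.yaml; a world_size of 4 GPUs is hypothetical.
epoch_size = 16384    # samples per epoch
batch_size = 256      # global batch per optimizer update
accumulation = 1      # gradient accumulation steps
world_size = 4        # number of GPUs

assert epoch_size % batch_size == 0                   # whole number of updates per epoch
assert batch_size % (accumulation * world_size) == 0  # splits evenly across devices

micro_batch = batch_size // accumulation // world_size  # 64 samples per device per step
updates_per_epoch = epoch_size // batch_size            # 64 optimizer updates per epoch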
@@ -86,7 +87,7 @@ def train(runid: str, cfg: DictConfig):
             path=cfg.server.datasets,
             physics=cfg.dataset.physics,
             split=split,
-            steps=1,
+            steps=cfg.dataset.steps,
             include_filters=cfg.dataset.include_filters,
             augment=cfg.dataset.augment,
         )
@@ -96,7 +97,7 @@ def train(runid: str, cfg: DictConfig):
     train_loader, valid_loader = [
         get_dataloader(
             dataset=dataset[split],
-            batch_size=cfg.train.batch_size // world_size,
+            batch_size=cfg.train.batch_size // cfg.train.accumulation // world_size,
             shuffle=True if split == "train" else False,
             infinite=True,
             num_workers=cfg.compute.cpus_per_gpu,
@@ -119,11 +120,13 @@ def train(runid: str, cfg: DictConfig):
         pix_channels=dataset["train"].metadata.n_fields,
         **cfg.ae,
     ).to(device)
+    if rank == 0:
+        print(autoencoder)

     autoencoder_loss = WeightedLoss(**cfg.ae.loss).to(device)

     if cfg.fork.run is not None:
-        autoencoder.load_state_dict(stem_state)
+        autoencoder.load_state_dict(stem_state, strict=cfg.fork.strict)
         del stem_state

     autoencoder = DistributedDataParallel(
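Forwarding strict=cfg.fork.strict exposes the strict flag of torch.nn.Module.load_state_dict: with strict=False, missing or unexpected keys are reported instead of raising, which allows forking from a checkpoint whose architecture only partially matches. A toy illustration of the flag's behavior (not this repository's code):

import torch.nn as nn

src = nn.Linear(8, 8)                 # module that produced the checkpoint
dst = nn.Sequential(nn.Linear(8, 8))  # differently structured module

state = src.state_dict()              # keys: "weight", "bias"
result = dst.load_state_dict(state, strict=False)
print(result.missing_keys)            # ['0.weight', '0.bias']
print(result.unexpected_keys)         # ['weight', 'bias']
# With strict=True the same call would raise a RuntimeError.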
@@ -138,7 +141,7 @@ def train(runid: str, cfg: DictConfig):
     )

     # W&B
-    if rank == 0:
+    if rank == 0 and not cfg.wandb.dry_run:
         run = wandb.init(
             entity=cfg.wandb.entity,
             project="mpp-ae",
@@ -162,7 +165,7 @@ def train(runid: str, cfg: DictConfig):
         for i in range(cfg.train.epoch_size // cfg.train.batch_size):
             x, _ = get_well_inputs(next(train_loader), device=device)
             x = preprocess(x)
-            x = rearrange(x, "B 1 H W C -> B C H W")
+            x = rearrange(x, "B T H W C -> B C T H W")

             if (i + 1) % cfg.train.accumulation == 0:
                 y, z = autoencoder(x)
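The autoencoder now consumes space-time blocks instead of single frames, so the einops pattern keeps the time axis and moves the channel axis in front of it. A shape-only sketch with illustrative tensor sizes:

import torch
from einops import rearrange

# Hypothetical batch: 8 samples, 4 time steps, a 64x64 grid, 5 physical fields.
x = torch.randn(8, 4, 64, 64, 5)       # B T H W C

y = rearrange(x, "B T H W C -> B C T H W")
print(y.shape)                         # torch.Size([8, 5, 4, 64, 64])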
@@ -173,7 +176,7 @@ def train(runid: str, cfg: DictConfig):
                 grad_norm = safe_gd_step(optimizer, grad_clip=cfg.optim.grad_clip)
                 grads.append(grad_norm)

-                counter["update_samples"] += cfg.train.batch_size * cfg.train.accumulation
+                counter["update_samples"] += cfg.train.batch_size
                 counter["update_steps"] += 1
             else:
                 with autoencoder.no_sync():
@@ -221,7 +224,7 @@ def train(runid: str, cfg: DictConfig):
         for _ in range(cfg.train.epoch_size // cfg.train.batch_size):
             x, _ = get_well_inputs(next(valid_loader), device=device)
             x = preprocess(x)
-            x = rearrange(x, "B 1 H W C -> B C H W")
+            x = rearrange(x, "B T H W C -> B C T H W")

             y, z = autoencoder(x)
             loss = autoencoder_loss(x, y)
@@ -281,7 +284,7 @@ def train(runid: str, cfg: DictConfig):

     # Config
     cfg = compose(
-        config_file="./configs/train_ae.yaml",
+        config_file=os.path.join(os.path.dirname(__file__), "configs", "train_ae.yaml"),
         overrides=args.overrides,
     )


Collaborator commented on the config_file change:
This should not be necessary.