diff --git a/README.md b/README.md
index 1841881..f3c220f 100644
--- a/README.md
+++ b/README.md
@@ -64,8 +64,9 @@ Developers should set up `pre-commit` as well with `pre-commit install`.
### Running Test Cases
```
-> pytest # will run all test cases - including ones that require a gpu
-> pytest -m "not gpu" # run test cases that can work with just cpu
+> pytest # run test cases that can work with just cpu
+> pytest -m '' # will run all test cases - including ones that require a gpu
+> pytest -m gpu # run only gpu test cases
```
@@ -99,6 +100,30 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port
## Codebase
+
+### 1. /configs
+
+| module | description |
+| - | - |
+| `configs.dataset_creation` | Configuration file for dataset splitting into train-eval-val pipeline |
+| `configs.datasets` | Datasets for training and evaluation phases of the model |
+| `configs.models` | Configuration files for different resolution models |
+
+
+### 2. /data
+
+| module | description |
+| - | - |
+| `data` | - bert.vocab: BERT-trained dictionary containing tokens and their associated vector representations<br>- c4_wpm.vocab: C4-trained dictionary containing tokens and their associated vector representations<br>- cifar10.vocab: CIFAR10-trained dictionary containing tokens and their associated vector representations<br>- imagenet.vocab: Prompts associated with Imagenet dataset<br>- prompts_cc12m-64x64.tsv: Prompts associated with cc12m dataset for the 64x64 res. model<br>- prompts_cc12m-256x256.tsv: Prompts associated with cc12m dataset for the 256x256 res. model<br>- prompts_cifar10-32x32.tsv: Prompts associated with cifar10 dataset for the 32x32 res. model<br>- prompts_cifar10-64x64.tsv: Prompts associated with cifar10 dataset for the 64x64 res. model<br>- prompts_demo.tsv: Extra demo prompts<br>- prompts_imagenet-64px.tsv: Prompts associated with imagenet dataset for the 64x64 res. model<br>- prompts_WebImage-ALIGN-64px.tsv: Prompts associated with WebImage-ALIGN dataset for the 64x64 res. model<br>- t5.vocab: t5-trained dictionary containing tokens and their associated vector representations<br>- tokenizer_spm_32000_50m.vocab: SPM-trained dictionary containing tokens and their associated vector representations |
+
+### 3. /docs
+
+| module | description |
+| - | - |
+| `docs` | - web_demo.png: Screenshot of the web demo of the model |
+
+### 4. /ml_mdm
+
| module | description |
| - | - |
| `ml_mdm.models` | The core model implementations |
@@ -107,7 +132,11 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port
| `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` |
| `tests/` | Unit tests and sample training files |
+### 5. /tests
+| module | description |
+| - | - |
+| `tests.test_files` | Sample files for testing |
# Concepts
@@ -125,6 +154,22 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of:
> In essence, `simple_parsing` will convert all passed cli arguments and yaml files into clean configuration classes like `ml_mdm.reader.ReaderConfig`, `ml_mdm.diffusion.DiffusionConfig`.
+`ml_mdm.config` stores a global mapping of names to classes in `MODEL_REGISTRY`, `MODEL_CONFIG_REGISTRY`, `PIPELINE_REGISTRY`, and `PIPELINE_CONFIG_REGISTRY`.
+
+`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example:
+
+> *_CONFIG_REGISTRY[architecture name]["model"] = model name
+
+> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class
+
+`MODEL_CONFIG_REGISTRY` and `PIPELINE_CONFIG_REGISTRY` store information as shown in the following example:
+> *_CONFIG_REGISTRY[architecture name]["model"] = model name
+
+> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class
+
+
+The architecture name and model name are passed into `ml_mdm.config` through the function parameter `*names`, where `*names` holds ("architecture name", "model name").
+
# Tutorials
@@ -263,11 +308,11 @@ reader_config:
Then you can use our dataset download helper:
```console
python -m ml_mdm.clis.download_tar_from_index \
- --dataset-config-file configs/datasets/cc12m.yaml \
+ --dataset_config_file configs/datasets/cc12m.yaml \
--subset train --download_tar
python -m ml_mdm.clis.download_tar_from_index \
- --dataset-config-file configs/datasets/cc12m.yaml \
+ --dataset_config_file configs/datasets/cc12m.yaml \
--subset eval --download_tar
```
diff --git a/configs/dataset_creation/sample_cc12m.yaml b/ml-mdm-matryoshka/configs/dataset_creation/sample_cc12m.yaml
similarity index 100%
rename from configs/dataset_creation/sample_cc12m.yaml
rename to ml-mdm-matryoshka/configs/dataset_creation/sample_cc12m.yaml
diff --git a/configs/datasets/cc12m.yaml b/ml-mdm-matryoshka/configs/datasets/cc12m.yaml
similarity index 100%
rename from configs/datasets/cc12m.yaml
rename to ml-mdm-matryoshka/configs/datasets/cc12m.yaml
diff --git a/configs/models/cc12m_1024x1024.yaml b/ml-mdm-matryoshka/configs/models/cc12m_1024x1024.yaml
similarity index 100%
rename from configs/models/cc12m_1024x1024.yaml
rename to ml-mdm-matryoshka/configs/models/cc12m_1024x1024.yaml
diff --git a/configs/models/cc12m_256x256.yaml b/ml-mdm-matryoshka/configs/models/cc12m_256x256.yaml
similarity index 100%
rename from configs/models/cc12m_256x256.yaml
rename to ml-mdm-matryoshka/configs/models/cc12m_256x256.yaml
diff --git a/configs/models/cc12m_64x64.yaml b/ml-mdm-matryoshka/configs/models/cc12m_64x64.yaml
similarity index 100%
rename from configs/models/cc12m_64x64.yaml
rename to ml-mdm-matryoshka/configs/models/cc12m_64x64.yaml
diff --git a/data/bert.vocab b/ml-mdm-matryoshka/data/bert.vocab
similarity index 100%
rename from data/bert.vocab
rename to ml-mdm-matryoshka/data/bert.vocab
diff --git a/data/c4_wpm.vocab b/ml-mdm-matryoshka/data/c4_wpm.vocab
similarity index 100%
rename from data/c4_wpm.vocab
rename to ml-mdm-matryoshka/data/c4_wpm.vocab
diff --git a/data/cifar10.vocab b/ml-mdm-matryoshka/data/cifar10.vocab
similarity index 100%
rename from data/cifar10.vocab
rename to ml-mdm-matryoshka/data/cifar10.vocab
diff --git a/data/imagenet.vocab b/ml-mdm-matryoshka/data/imagenet.vocab
similarity index 100%
rename from data/imagenet.vocab
rename to ml-mdm-matryoshka/data/imagenet.vocab
diff --git a/data/prompts_WebImage-ALIGN-64px.tsv b/ml-mdm-matryoshka/data/prompts_WebImage-ALIGN-64px.tsv
similarity index 100%
rename from data/prompts_WebImage-ALIGN-64px.tsv
rename to ml-mdm-matryoshka/data/prompts_WebImage-ALIGN-64px.tsv
diff --git a/data/prompts_cc12m-256x256.tsv b/ml-mdm-matryoshka/data/prompts_cc12m-256x256.tsv
similarity index 100%
rename from data/prompts_cc12m-256x256.tsv
rename to ml-mdm-matryoshka/data/prompts_cc12m-256x256.tsv
diff --git a/data/prompts_cc12m-64x64.tsv b/ml-mdm-matryoshka/data/prompts_cc12m-64x64.tsv
similarity index 100%
rename from data/prompts_cc12m-64x64.tsv
rename to ml-mdm-matryoshka/data/prompts_cc12m-64x64.tsv
diff --git a/data/prompts_cifar10-32x32.tsv b/ml-mdm-matryoshka/data/prompts_cifar10-32x32.tsv
similarity index 100%
rename from data/prompts_cifar10-32x32.tsv
rename to ml-mdm-matryoshka/data/prompts_cifar10-32x32.tsv
diff --git a/data/prompts_cifar10-64x64.tsv b/ml-mdm-matryoshka/data/prompts_cifar10-64x64.tsv
similarity index 100%
rename from data/prompts_cifar10-64x64.tsv
rename to ml-mdm-matryoshka/data/prompts_cifar10-64x64.tsv
diff --git a/data/prompts_demo.tsv b/ml-mdm-matryoshka/data/prompts_demo.tsv
similarity index 100%
rename from data/prompts_demo.tsv
rename to ml-mdm-matryoshka/data/prompts_demo.tsv
diff --git a/data/prompts_imagenet-64px.tsv b/ml-mdm-matryoshka/data/prompts_imagenet-64px.tsv
similarity index 100%
rename from data/prompts_imagenet-64px.tsv
rename to ml-mdm-matryoshka/data/prompts_imagenet-64px.tsv
diff --git a/data/t5.vocab b/ml-mdm-matryoshka/data/t5.vocab
similarity index 100%
rename from data/t5.vocab
rename to ml-mdm-matryoshka/data/t5.vocab
diff --git a/data/tokenizer_spm_32000_50m.vocab b/ml-mdm-matryoshka/data/tokenizer_spm_32000_50m.vocab
similarity index 100%
rename from data/tokenizer_spm_32000_50m.vocab
rename to ml-mdm-matryoshka/data/tokenizer_spm_32000_50m.vocab
diff --git a/ml_mdm/clis/__init__.py b/ml-mdm-matryoshka/ml_mdm/clis/__init__.py
similarity index 100%
rename from ml_mdm/clis/__init__.py
rename to ml-mdm-matryoshka/ml_mdm/clis/__init__.py
diff --git a/ml_mdm/clis/download_tar_from_index.py b/ml-mdm-matryoshka/ml_mdm/clis/download_tar_from_index.py
similarity index 90%
rename from ml_mdm/clis/download_tar_from_index.py
rename to ml-mdm-matryoshka/ml_mdm/clis/download_tar_from_index.py
index faba594..08bb793 100644
--- a/ml_mdm/clis/download_tar_from_index.py
+++ b/ml-mdm-matryoshka/ml_mdm/clis/download_tar_from_index.py
@@ -17,7 +17,7 @@
nodes this data will be distributed over.
"""
-import argparse
+import simple_parsing
import csv
import logging
import os
@@ -33,7 +33,30 @@
import mlx.data
from ml_mdm import helpers, s3_helpers
-
+from dataclasses import dataclass, field
+
+# Options for the tar-download CLI. get_parser() registers this dataclass with
+# simple_parsing under dest="options", so the parsed values arrive as `args.options`.
+# NOTE(review): `pretrained_text_embeddings` is annotated `str` but defaults to None;
+# presumably it should be Optional[str] - confirm how simple_parsing treats it.
+# NOTE(review): the previous argparse-based parser enforced choices=["train", "eval"]
+# for --subset; here "choices" and "help" live in plain dataclass metadata - verify
+# that simple_parsing actually honors them.
+# NOTE(review): documentation is kept above the class on purpose; simple_parsing may
+# pick up field comments/docstrings as help text - TODO confirm before moving it.
+@dataclass
+class DownloadConfig:
+ dataset_config_file: str = field(default="",
+ metadata={"help": "yaml file with dataset names"})
+ worker_id: int = field(default=0,
+ metadata={"help": "current worker in [0, num-downloaders -1]"})
+ num_downloaders: int = field(default=1,
+ metadata={"help": "number of parallel downloaders"})
+ no_bandwidth: bool = field(default=False)
+ download_tar: bool = field(default=False,
+ metadata={"help": "whether or not to download tar files also"})
+ pretrained_text_embeddings: str = field(default=None)
+ endpoint_url: str = field(default="",
+ metadata={"help": "end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise"})
+ subset: str = field(default="train",
+ metadata={"choices": ["train", "eval"],
+ "help": "subset to download [train|eval]"})
+
+def get_parser():
+    """Return the CLI parser; DownloadConfig values are exposed as ``args.options``."""
+    cli_parser = simple_parsing.ArgumentParser(
+        description="Download tar files referred to in index file from mlx"
+    )
+    cli_parser.add_arguments(DownloadConfig, dest="options")
+    return cli_parser
def read_tsv(filename):
# Open the TSV file for reading
@@ -331,44 +354,7 @@ def main(args):
if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Download tar files referred to in index file from mlx"
- )
- parser.add_argument(
- "--dataset-config-file",
- type=str,
- default="",
- help="yaml file with dataset names",
- )
- parser.add_argument(
- "--worker-id",
- type=int,
- default=0,
- help="current worker in [0, num-downloaders -1]",
- )
- parser.add_argument(
- "--num-downloaders", type=int, default=1, help="number of parallel downloaders"
- )
- parser.add_argument("--no_bandwidth", action="store_true")
- parser.add_argument(
- "--download_tar",
- action="store_true",
- help="whether or not to download tar files also",
- )
- parser.add_argument("--pretrained-text-embeddings", type=str, default=None)
- parser.add_argument(
- "--endpoint-url",
- type=str,
- default="",
- help="end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise",
- )
- parser.add_argument(
- "--subset",
- type=str,
- default="train",
- choices=["train", "eval"],
- help="subset to download [train|eval]",
- )
+ parser = get_parser()
args = parser.parse_args()
logging.basicConfig(
level="INFO",
@@ -377,5 +363,5 @@ def main(args):
),
datefmt="%H:%M:%S",
)
- helpers.print_args(args)
- main(args)
+ helpers.print_args(args.options)
+ main(args.options)
\ No newline at end of file
diff --git a/ml_mdm/clis/generate_batch.py b/ml-mdm-matryoshka/ml_mdm/clis/generate_batch.py
similarity index 100%
rename from ml_mdm/clis/generate_batch.py
rename to ml-mdm-matryoshka/ml_mdm/clis/generate_batch.py
diff --git a/ml_mdm/clis/generate_sample.py b/ml-mdm-matryoshka/ml_mdm/clis/generate_sample.py
similarity index 94%
rename from ml_mdm/clis/generate_sample.py
rename to ml-mdm-matryoshka/ml_mdm/clis/generate_sample.py
index 975ab31..b254de2 100644
--- a/ml_mdm/clis/generate_sample.py
+++ b/ml-mdm-matryoshka/ml_mdm/clis/generate_sample.py
@@ -1,11 +1,12 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
+import argparse
import logging
import os
import shlex
import time
from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Tuple
import gradio as gr
import simple_parsing
@@ -16,6 +17,8 @@
import torch
from torchvision.utils import make_grid
+import ml_mdm.language_models.factory
+import ml_mdm.language_models.tokenizer
from ml_mdm import helpers, reader
from ml_mdm.config import get_arguments, get_model, get_pipeline
from ml_mdm.language_models import factory
@@ -36,14 +39,20 @@
)
-def dividable(n):
+def dividable(n: int) -> Tuple[int, int]:
for i in range(int(np.sqrt(n)), 0, -1):
if n % i == 0:
break
return i, n // i
-def generate_lm_outputs(device, sample, tokenizer, language_model, args):
+def generate_lm_outputs(
+ device: torch.device,
+ sample: dict,
+ tokenizer: ml_mdm.language_models.tokenizer.Tokenizer,
+ language_model: ml_mdm.language_models.factory.LanguageModel,
+ args: argparse.Namespace,
+) -> dict:
with torch.no_grad():
lm_outputs, lm_mask = language_model(sample, tokenizer)
sample["lm_outputs"] = lm_outputs
@@ -51,7 +60,7 @@ def generate_lm_outputs(device, sample, tokenizer, language_model, args):
return sample
-def setup_models(args, device):
+def setup_models(args: argparse.Namespace, device: torch.device):
input_channels = 3
# load the language model
@@ -68,7 +77,10 @@ def setup_models(args, device):
return tokenizer, language_model, diffusion_model
-def plot_logsnr(logsnrs, total_steps):
+
+def plot_logsnr(logsnrs: list, total_steps: int) -> np.ndarray:
+ import matplotlib
+ matplotlib.use('Agg')
import matplotlib.pyplot as plt
x = 1 - np.arange(len(logsnrs)) / (total_steps - 1)
@@ -103,39 +115,40 @@ class GLOBAL_DATA:
global_config = GLOBAL_DATA()
-def stop_run():
+def stop_run() -> gr.component:
return (
gr.update(value="Run", variant="primary", visible=True),
gr.update(visible=False),
)
-def get_model_type(config_file):
+
+def get_model_type(config_file: str) -> str:
with open(config_file, "r") as f:
d = yaml.safe_load(f)
return d.get("model", d.get("vision_model", "unet"))
def generate(
- config_file="cc12m_64x64.yaml",
- ckpt_name="vis_model_64x64.pth",
- prompt="a chair",
- input_template="",
- negative_prompt="",
- negative_template="",
- batch_size=20,
- guidance_scale=7.5,
- threshold_function="clip",
- num_inference_steps=250,
- eta=0,
- save_diffusion_path=False,
- show_diffusion_path=False,
- show_xt=False,
- reader_config="",
- seed=10,
- comment="",
- override_args="",
- output_inner=False,
+ config_file: str = "cc12m_64x64.yaml",
+ ckpt_name: str = "vis_model_64x64.pth",
+ prompt: str = "a chair",
+ input_template: str = "",
+ negative_prompt: str = "",
+ negative_template: str = "",
+ batch_size: int = 20,
+ guidance_scale: float = 7.5,
+ threshold_function: str = "clip",
+ num_inference_steps: int = 250,
+ eta: int = 0,
+ save_diffusion_path: bool = False,
+ show_diffusion_path: bool = False,
+ show_xt: bool = False,
+ reader_config: str = "",
+ seed: int = 10,
+ comment: str = "",
+ override_args: str = "",
+ output_inner: bool = False,
):
np.random.seed(seed)
torch.random.manual_seed(seed)
@@ -292,7 +305,7 @@ def generate(
)
-def main(args):
+def main(args: argparse.Namespace):
# get the language model outputs
example_texts = open("data/prompts_demo.tsv").readlines()
diff --git a/ml_mdm/clis/run_torchmetrics.py b/ml-mdm-matryoshka/ml_mdm/clis/run_torchmetrics.py
similarity index 76%
rename from ml_mdm/clis/run_torchmetrics.py
rename to ml-mdm-matryoshka/ml_mdm/clis/run_torchmetrics.py
index ec5a502..9331370 100644
--- a/ml_mdm/clis/run_torchmetrics.py
+++ b/ml-mdm-matryoshka/ml_mdm/clis/run_torchmetrics.py
@@ -1,6 +1,7 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
-import argparse
+
+import simple_parsing
import json
import logging
import os
@@ -14,7 +15,38 @@
import torch
from ml_mdm import helpers
-
+from dataclasses import dataclass, field
+
+# Options for the metrics CLI. get_parser() registers this dataclass with
+# simple_parsing under dest="options", so the parsed values arrive as `args.options`
+# (e.g. `args.options.loglevel` is used to configure logging in the __main__ block).
+# NOTE(review): the previous argparse-based parser enforced
+# choices=[64, 192, 768, 2048] for --inception-layer-fid; here "choices" and "help"
+# live in plain dataclass metadata - verify that simple_parsing actually honors them.
+# NOTE(review): documentation is kept above the class on purpose; simple_parsing may
+# pick up field comments/docstrings as help text - TODO confirm before moving it.
+@dataclass
+class MetricsConfig:
+ loglevel: str = field(default="INFO",
+ metadata={"help": "Logging level"})
+ sample_dir: str = field(default="",
+ metadata={"help": "directory with samples"})
+ metrics: str = field(default="clip,fid",
+ metadata={"help": "Metrics to compute(comma separated)"})
+ reference_dir: str = field(default="",
+ metadata={"help": "directory with reference images"})
+ num_samplers: int = field(default=1,
+ metadata={"help": "Number of jobs generating samples"})
+ num_training_steps: int = field(default=850000,
+ metadata={"help": "# of training steps to train for"})
+ max_caption_length: int = field(default=77,
+ metadata={"help": "Maximum length of caption"})
+ eval_freq: int = field(default=1000,
+ metadata={"help": "Minimum Evaluation interval"})
+ clip_model: str = field(default="openai/clip-vit-base-patch16",
+ metadata={"help": "Model to use for clip scores"})
+ inception_layer_fid: int = field(default=2048,
+ metadata={
+ "choices": [64, 192, 768, 2048],
+ "help": "Which layer of inception to use for fid"
+ })
+
+def get_parser():
+    """Return the CLI parser; MetricsConfig values are exposed as ``args.options``."""
+    cli_parser = simple_parsing.ArgumentParser(
+        description="Compute metrics on samples from diffusion model"
+    )
+    cli_parser.add_arguments(MetricsConfig, dest="options")
+    return cli_parser
def load_captions_and_images(dir_name, args, override_path=None):
map_files = []
@@ -140,54 +172,12 @@ def main(args):
if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description="Compute metrics on samples from diffusion model"
- )
- parser.add_argument("--loglevel", type=str, default="INFO", help="Logging level")
- parser.add_argument(
- "--sample-dir", type=str, default="", help="directory with samples"
- )
- parser.add_argument(
- "--metrics",
- type=str,
- default="clip,fid",
- help="Metrics to compute(comma separated)",
- )
- parser.add_argument(
- "--reference-dir", type=str, default="", help="directory with reference images"
- )
- parser.add_argument(
- "--num-samplers", type=int, default=1, help="Number of jobs generating samples"
- )
- parser.add_argument(
- "--num-training-steps",
- type=int,
- default=850000,
- help="# of training steps to train for",
- )
- parser.add_argument(
- "--max-caption-length", type=int, default=77, help="Maximum length of caption"
- )
- parser.add_argument(
- "--eval-freq", type=int, default=1000, help="Minimum Evaluation interval"
- )
- parser.add_argument(
- "--clip-model",
- type=str,
- default="openai/clip-vit-base-patch16",
- help="Model to use for clip scores",
- )
- parser.add_argument(
- "--inception-layer-fid",
- type=int,
- default=2048,
- choices=[64, 192, 768, 2048],
- help="Which layer of inception to use for fid",
- )
+ parser = get_parser()
args = parser.parse_args()
logging.basicConfig(
- level=getattr(logging, args.loglevel.upper(), None),
+ level=getattr(logging, args.options.loglevel.upper(), None),
format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
datefmt="%H:%M:%S",
)
- main(args)
+ helpers.print_args(args.options)
+ main(args.options)
diff --git a/ml_mdm/clis/scrape_cc12m.py b/ml-mdm-matryoshka/ml_mdm/clis/scrape_cc12m.py
similarity index 100%
rename from ml_mdm/clis/scrape_cc12m.py
rename to ml-mdm-matryoshka/ml_mdm/clis/scrape_cc12m.py
diff --git a/ml_mdm/clis/train_parallel.py b/ml-mdm-matryoshka/ml_mdm/clis/train_parallel.py
similarity index 87%
rename from ml_mdm/clis/train_parallel.py
rename to ml-mdm-matryoshka/ml_mdm/clis/train_parallel.py
index 6a4f0af..07bd4c7 100644
--- a/ml_mdm/clis/train_parallel.py
+++ b/ml-mdm-matryoshka/ml_mdm/clis/train_parallel.py
@@ -7,6 +7,7 @@
import logging
import os
import time
+from contextlib import nullcontext
import numpy as np
import torch
@@ -53,7 +54,11 @@ def main(args):
local_rank, global_rank, world_size = init_distributed_singlenode(timeout=36000)
input_channels = 3
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
tokenizer, language_model = factory.create_lm(args, device=device)
language_model_dim = language_model.embed_dim
@@ -71,7 +76,8 @@ def main(args):
os.makedirs(args.output_dir)
if "MASTER_ADDR" in os.environ:
- dist.barrier()
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
other_items = None
if (
@@ -109,7 +115,8 @@ def main(args):
else:
grad_scaler = None
- dist.barrier()
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
max_lr = args.lr
# Should eps be 1e-4 like for LMs in fp16 ?
if args.use_adamw:
@@ -137,13 +144,18 @@ def main(args):
CLIP = 3
# intialize the model
- model = nn.parallel.DistributedDataParallel(
- diffusion_model.model,
- device_ids=[local_rank],
- )
+ if int(os.environ.get("WORLD_SIZE", "1")) > 1:
+ model = nn.parallel.DistributedDataParallel(
+ diffusion_model.model,
+ device_ids=[local_rank],
+ )
+ else:
+ model = diffusion_model.model
diffusion_model.model = model
- dist.barrier()
- ema_model = ModelEma(diffusion_model.model.module.vision_model)
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
+ # Check if the model is wrapped in DistributedDataParallel
+ ema_model = ModelEma(getattr(diffusion_model.model, "module", diffusion_model.model).vision_model)
# get the dataloader
if args.multinode:
@@ -187,7 +199,8 @@ def main(args):
sample["images"] = images
if accumulate_gradient:
- with diffusion_model.model.no_sync():
+ no_sync_context = diffusion_model.model.no_sync() if hasattr(diffusion_model.model, "no_sync") else nullcontext()
+ with no_sync_context:
loss_val, losses, times, x_t, means, targets = trainer.train_batch(
diffusion_model,
sample,
@@ -220,7 +233,6 @@ def main(args):
num_time_counts += 1
if np.isnan(loss_val):
continue
-
# accumulate loss
if batch_num != 1:
# E[(x-E[x])^2] = E[x^2] - E[x]^2
@@ -239,6 +251,8 @@ def main(args):
exp_avg_loss = loss_val
exp_avg_loss_var = loss_val**2
total_loss_val += loss_val
+ # print(f"Allocated memory: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB", end='')
+ # print(f"Val loss: {loss_val}")
if (not accumulate_gradient) and (global_rank == 0):
metrics = {
@@ -274,12 +288,15 @@ def main(args):
"args": args,
} # save full config.
ema_model.save(vision_model_file, other_items=other_items)
- diffusion_model.model.module.vision_model.save(
+ getattr(diffusion_model.model, "module", diffusion_model.model).vision_model.save(
vision_model_noema_file, other_items=other_items
)
+ torch.cuda.empty_cache()
+ torch.mps.empty_cache()
if (batch_num % args.save_freq == 0) or (batch_num == args.num_training_steps):
- dist.barrier()
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
if batch_num == args.num_training_steps:
break
@@ -302,5 +319,6 @@ def main(args):
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.empty_cache()
+ torch.mps.empty_cache()
helpers.print_args(args)
main(args)
diff --git a/ml_mdm/config.py b/ml-mdm-matryoshka/ml_mdm/config.py
similarity index 100%
rename from ml_mdm/config.py
rename to ml-mdm-matryoshka/ml_mdm/config.py
diff --git a/ml_mdm/diffusion.py b/ml-mdm-matryoshka/ml_mdm/diffusion.py
similarity index 94%
rename from ml_mdm/diffusion.py
rename to ml-mdm-matryoshka/ml_mdm/diffusion.py
index c687e22..9a2766d 100644
--- a/ml_mdm/diffusion.py
+++ b/ml-mdm-matryoshka/ml_mdm/diffusion.py
@@ -13,6 +13,7 @@
import torch.nn.functional as F
from torchvision.utils import save_image
+import ml_mdm.samplers
from ml_mdm import config, samplers
@@ -59,10 +60,10 @@ def __init__(
self.vision_model = vision_model
self.sampler = None
- def set_sampler(self, sampler):
+ def set_sampler(self, sampler: ml_mdm.samplers.Sampler):
self.sampler = sampler
- def load(self, vision_file):
+ def load(self, vision_file: str) -> dict:
return self.vision_model.load(vision_file)
def save(self, vision_file, other_items=None):
@@ -103,7 +104,7 @@ def get_model(self):
return self.model.module
return self.model
- def to(self, device):
+ def to(self, device: torch.device):
self.model = self.model.to(device)
self.sampler = self.sampler.to(device)
return self
@@ -115,7 +116,7 @@ def eval(self):
self.model.eval()
self.sampler.eval()
- def get_xt_minus_1(self, t, x_t, lm_outputs, lm_mask):
+ def get_xt_minus_1(self, t, x_t, lm_outputs: torch.Tensor, lm_mask: torch.Tensor):
self.eval()
return self.sampler.get_xt_minus_1(t, x_t, lm_outputs, lm_mask)
@@ -134,13 +135,13 @@ def get_pred_for_training(self, x_t, pred, g):
)
return pred
- def get_micro_conditioning(self, sample):
+ def get_micro_conditioning(self, sample: dict) -> dict:
micros, conditions = {}, self.get_model().vision_model.conditions
if conditions is not None:
micros = {key: sample[key] for key in conditions if key in sample}
return micros
- def get_loss(self, sample):
+ def get_loss(self, sample: dict):
images, lm_outputs, lm_mask = (
sample["images"],
sample["lm_outputs"],
@@ -166,12 +167,25 @@ def get_loss(self, sample):
loss = self.loss_fn(pred, tgt).mean(axis=(1, 2, 3))
return loss, time, x_t, means, tgt, weights
- def get_noise(self, num_examples, input_channels, image_side, device):
+ def get_noise(
+ self,
+ num_examples: int,
+ input_channels: int,
+ image_side: int,
+ device: torch.device,
+ ) -> torch.Tensor:
return torch.randn(num_examples, input_channels, image_side, image_side).to(
device
)
- def sample(self, num_examples, sample, image_side, device, **kwargs):
+ def sample(
+ self,
+ num_examples: int,
+ sample: dict,
+ image_side: int,
+ device: torch.device,
+ **kwargs: dict,
+ ):
self.eval()
noise = self.get_noise(
num_examples, self.get_model().input_channels, image_side, device
@@ -298,7 +312,7 @@ def __init__(self, denoising_model, diffusion_config: DiffusionConfig):
)
self.mixed_ratio = self.mixed_ratio / self.mixed_ratio[-1]
- def get_loss(self, sample):
+ def get_loss(self, sample: dict):
images, lm_outputs, lm_mask = (
sample["images"],
sample["lm_outputs"],
@@ -370,5 +384,4 @@ def get_loss(self, sample):
loss_ = pred[i].mean() * 0.0
loss_ = loss_ * w[i]
loss = loss + loss_
-
return loss, time, x_t[0], pred[0], tgt[0], weights
diff --git a/ml_mdm/distributed.py b/ml-mdm-matryoshka/ml_mdm/distributed.py
similarity index 97%
rename from ml_mdm/distributed.py
rename to ml-mdm-matryoshka/ml_mdm/distributed.py
index ebd2aa5..99997fc 100644
--- a/ml_mdm/distributed.py
+++ b/ml-mdm-matryoshka/ml_mdm/distributed.py
@@ -32,7 +32,8 @@ def init_distributed_singlenode(timeout=0):
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
- if not "MASTER_ADDR" in os.environ:
+
+ if not "MASTER_ADDR" in os.environ or world_size == 1:
return local_rank, rank, world_size
if timeout == 0:
diff --git a/ml_mdm/generate_html.py b/ml-mdm-matryoshka/ml_mdm/generate_html.py
similarity index 100%
rename from ml_mdm/generate_html.py
rename to ml-mdm-matryoshka/ml_mdm/generate_html.py
diff --git a/ml_mdm/helpers.py b/ml-mdm-matryoshka/ml_mdm/helpers.py
similarity index 100%
rename from ml_mdm/helpers.py
rename to ml-mdm-matryoshka/ml_mdm/helpers.py
diff --git a/ml_mdm/language_models/__init__.py b/ml-mdm-matryoshka/ml_mdm/language_models/__init__.py
similarity index 100%
rename from ml_mdm/language_models/__init__.py
rename to ml-mdm-matryoshka/ml_mdm/language_models/__init__.py
diff --git a/ml_mdm/language_models/factory.py b/ml-mdm-matryoshka/ml_mdm/language_models/factory.py
similarity index 98%
rename from ml_mdm/language_models/factory.py
rename to ml-mdm-matryoshka/ml_mdm/language_models/factory.py
index 180d406..df8f838 100644
--- a/ml_mdm/language_models/factory.py
+++ b/ml-mdm-matryoshka/ml_mdm/language_models/factory.py
@@ -8,7 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from .tokenizer import Tokenizer
+from ml_mdm.language_models.tokenizer import Tokenizer
class T5Encoder(T5ForConditionalGeneration):
diff --git a/ml_mdm/language_models/self_attention.py b/ml-mdm-matryoshka/ml_mdm/language_models/self_attention.py
similarity index 100%
rename from ml_mdm/language_models/self_attention.py
rename to ml-mdm-matryoshka/ml_mdm/language_models/self_attention.py
diff --git a/ml_mdm/language_models/tokenizer.py b/ml-mdm-matryoshka/ml_mdm/language_models/tokenizer.py
similarity index 91%
rename from ml_mdm/language_models/tokenizer.py
rename to ml-mdm-matryoshka/ml_mdm/language_models/tokenizer.py
index 0fb08dd..b3af8a8 100644
--- a/ml_mdm/language_models/tokenizer.py
+++ b/ml-mdm-matryoshka/ml_mdm/language_models/tokenizer.py
@@ -5,11 +5,11 @@
from mlx.data.core import CharTrie
-def read_dictionary_bert(token_file):
+def read_dictionary_bert(vocab_file):
trie_key_scores = []
trie = CharTrie()
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -42,11 +42,11 @@ def read_dictionary_bert(token_file):
return trie, trie_key_scores, eos, bos, pad
-def read_dictionary_t5(token_file):
+def read_dictionary_t5(vocab_file):
trie_key_scores = []
trie = CharTrie()
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -75,7 +75,7 @@ def read_dictionary_t5(token_file):
return trie, trie_key_scores, eos, bos, pad
-def read_dictionary(token_file):
+def read_dictionary(vocab_file):
trie_key_scores = []
trie = CharTrie()
@@ -85,7 +85,7 @@ def read_dictionary(token_file):
trie.insert(token)
trie_key_scores.append(0.0)
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -130,7 +130,7 @@ def read_dictionary(token_file):
class Tokenizer:
- def __init__(self, token_file, mode=None):
+ def __init__(self, vocab_file, mode=None):
if mode == "t5":
(
self._trie,
@@ -138,7 +138,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary_t5(token_file)
+ ) = read_dictionary_t5(vocab_file)
elif mode == "bert":
(
self._trie,
@@ -146,7 +146,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary_bert(token_file)
+ ) = read_dictionary_bert(vocab_file)
else:
(
self._trie,
@@ -154,7 +154,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary(token_file)
+ ) = read_dictionary(vocab_file)
self.vocab_size = self._trie.num_keys()
@property
diff --git a/ml_mdm/language_models/transformer.py b/ml-mdm-matryoshka/ml_mdm/language_models/transformer.py
similarity index 100%
rename from ml_mdm/language_models/transformer.py
rename to ml-mdm-matryoshka/ml_mdm/language_models/transformer.py
diff --git a/ml_mdm/lr_scaler.py b/ml-mdm-matryoshka/ml_mdm/lr_scaler.py
similarity index 100%
rename from ml_mdm/lr_scaler.py
rename to ml-mdm-matryoshka/ml_mdm/lr_scaler.py
diff --git a/ml_mdm/models/__init__.py b/ml-mdm-matryoshka/ml_mdm/models/__init__.py
similarity index 100%
rename from ml_mdm/models/__init__.py
rename to ml-mdm-matryoshka/ml_mdm/models/__init__.py
diff --git a/ml_mdm/models/model_ema.py b/ml-mdm-matryoshka/ml_mdm/models/model_ema.py
similarity index 91%
rename from ml_mdm/models/model_ema.py
rename to ml-mdm-matryoshka/ml_mdm/models/model_ema.py
index b2c0f83..c916633 100644
--- a/ml_mdm/models/model_ema.py
+++ b/ml-mdm-matryoshka/ml_mdm/models/model_ema.py
@@ -10,7 +10,7 @@
class ModelEma(nn.Module):
- def __init__(self, model, decay=0.9999, warmup_steps=0, device=None):
+ def __init__(self, model, decay: float=0.9999, warmup_steps: int = 0, device: torch.device =None):
super(ModelEma, self).__init__()
# make a copy of the model for accumulating moving average of weights
self.module = deepcopy(model)
@@ -33,7 +33,7 @@ def update(self, model):
model_v = model_v.to(device=self.device)
ema_v.mul_(decay).add_(model_v, alpha=(1.0 - decay))
- def save(self, fname, other_items=None):
+ def save(self, fname: str, other_items=None):
logging.info(f"Saving EMA model file: {fname}")
checkpoint = {"state_dict": self.module.state_dict()}
if other_items is not None:
@@ -41,7 +41,7 @@ def save(self, fname, other_items=None):
checkpoint[k] = v
torch.save(checkpoint, fname)
- def load(self, fname):
+ def load(self, fname: str):
logging.info(f"Loading EMA model file: {fname}")
fix_old_checkpoints.mimic_old_modules()
checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
diff --git a/ml_mdm/models/nested_unet.py b/ml-mdm-matryoshka/ml_mdm/models/nested_unet.py
similarity index 99%
rename from ml_mdm/models/nested_unet.py
rename to ml-mdm-matryoshka/ml_mdm/models/nested_unet.py
index b87c20c..3b170fa 100644
--- a/ml_mdm/models/nested_unet.py
+++ b/ml-mdm-matryoshka/ml_mdm/models/nested_unet.py
@@ -75,7 +75,7 @@ class Nested4UNetConfig(Nested3UNetConfig):
)
-def download(vision_model_path):
+def download(vision_model_path: str):
import os
from distributed import get_local_rank
diff --git a/ml_mdm/models/unet.py b/ml-mdm-matryoshka/ml_mdm/models/unet.py
similarity index 99%
rename from ml_mdm/models/unet.py
rename to ml-mdm-matryoshka/ml_mdm/models/unet.py
index 43a8506..2d5ffd1 100644
--- a/ml_mdm/models/unet.py
+++ b/ml-mdm-matryoshka/ml_mdm/models/unet.py
@@ -578,7 +578,7 @@ def forward(
@config.register_model("unet")
class UNet(nn.Module):
- def __init__(self, input_channels, output_channels, config: UNetConfig):
+ def __init__(self, input_channels: int, output_channels: int, config: UNetConfig):
super().__init__()
self.down_blocks = []
self.config = config
@@ -776,7 +776,7 @@ def __init__(self, input_channels, output_channels, config: UNetConfig):
def model_type(self):
return "unet"
- def print_size(self, target_image_size=64):
+ def print_size(self, target_image_size: int =64):
summary(
self,
[
@@ -791,7 +791,7 @@ def print_size(self, target_image_size=64):
depth=4,
)
- def save(self, fname, other_items=None):
+ def save(self, fname: str, other_items=None):
logging.info(f"Saving model file: {fname}")
checkpoint = {"state_dict": self.state_dict()}
if other_items is not None:
@@ -799,7 +799,7 @@ def save(self, fname, other_items=None):
checkpoint[k] = v
torch.save(checkpoint, fname)
- def load(self, fname):
+ def load(self, fname: str):
logging.info(f"Loading model file: {fname}")
fix_old_checkpoints.mimic_old_modules()
# first load to cpu or we will run out of memory.
diff --git a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py
new file mode 100644
index 0000000..ea58505
--- /dev/null
+++ b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py
@@ -0,0 +1,249 @@
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All rights reserved.
+
+import math
+
+import einops.array_api
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def zero_module_mlx(module):
+    """
+    Zero out the parameters of an MLX module and return it.
+
+    Every array in ``module.parameters()`` is replaced by a zero array of the
+    same shape and dtype and written back with ``module.update``. The
+    parameter tree is walked with ``mlx.utils.tree_map`` so that modules whose
+    ``parameters()`` are nested (containers of sub-modules) are handled the
+    same way as flat layers such as ``nn.Linear``.
+
+    Args:
+        module: the MLX module whose parameters are zeroed in place.
+
+    Returns:
+        The same ``module`` object, to allow inline use when building layers.
+    """
+    from mlx.utils import tree_map
+
+    zeroed_params = tree_map(
+        lambda param: mx.zeros(param.shape, dtype=param.dtype),
+        module.parameters(),
+    )
+    # Update the module's parameters with the zeroed parameters
+    module.update(zeroed_params)
+    return module
+
+
+
+class SelfAttention1D_MLX(nn.Module):
+    """
+    Multi-head self-attention over a sequence, built from MLX layers.
+
+    The input is layer-normalised, projected to q/k/v with a single linear
+    layer, attended, passed through a zero-initialised output projection and
+    added back to the input. An optional feed-forward block and optional
+    rotary position embedding can be enabled.
+
+    Args:
+        channels: feature size of the input (last axis).
+        num_heads: number of attention heads, used when ``num_head_channels``
+            is -1.
+        num_head_channels: if not -1, per-head feature size; the number of
+            heads becomes ``channels // num_head_channels``.
+        use_attention_ffn: add a residual LayerNorm/Linear/GELU/Linear block
+            after the attention.
+        pos_emb: apply ``mlx.nn.RoPE`` to queries and keys.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=8,
+        num_head_channels=-1,
+        use_attention_ffn=False,
+        pos_emb=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+
+        self.norm = nn.LayerNorm(channels)
+        self.qkv = nn.Linear(channels, channels * 3)
+        self.proj_out = zero_module_mlx(nn.Linear(channels, channels))
+        if use_attention_ffn:
+            self.ffn = nn.Sequential(
+                nn.LayerNorm(channels),
+                nn.Linear(channels, 4 * channels),
+                nn.GELU(),
+                zero_module_mlx(nn.Linear(4 * channels, channels)),
+            )
+        else:
+            self.ffn = None
+        if pos_emb:
+            # mlx.nn.RoPE names its size argument `dims` (not `dim`), so pass
+            # it positionally.
+            # NOTE(review): rotary-embedding-torch rotates interleaved pairs;
+            # matching PyTorch checkpoints may need `traditional=True` - confirm.
+            self.pos_emb = nn.RoPE(channels // self.num_heads)
+        else:
+            self.pos_emb = None
+
+    def attention(self, q, k, v, mask=None):
+        """
+        Scaled dot-product attention over (batch, length, width) inputs.
+
+        Args:
+            q, k, v: arrays of rank 3; ``width`` is split into
+                ``self.num_heads`` heads.
+            mask: optional 2-D array (batch, key_length); key positions where
+                it equals 0 receive ``-inf`` before the softmax.
+
+        Returns:
+            Array of shape (batch, length, width).
+        """
+        bs, length, width = q.shape
+        ch = width // self.num_heads
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        q = q.reshape(bs, length, self.num_heads, ch)
+        k = k.reshape(bs, length, self.num_heads, ch)
+        if self.pos_emb is not None:
+            # RoPE treats the second-to-last axis as the sequence axis, so move
+            # heads in front of the sequence, rotate, and move them back.
+            q = self.pos_emb(q.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)
+            k = self.pos_emb(k.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)
+        weight = mx.einsum(
+            "bthc,bshc->bhts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        if mask is not None:
+            # Broadcast the (batch, key_length) mask over heads and queries.
+            mask = mask.reshape(mask.shape[0], 1, 1, mask.shape[1])
+            weight = mx.where(mask == 0, float("-inf"), weight)
+        weight = mx.softmax(weight, axis=-1)
+        a = mx.einsum("bhts,bshc->bthc", weight, v.reshape(bs, -1, self.num_heads, ch))
+        return a.reshape(bs, length, -1)
+
+    def forward(self, x, mask):
+        """Apply residual self-attention (and the optional FFN) to ``x``."""
+        # assert (self.cond_dim is not None) == (cond is not None)
+        qkv = self.qkv(self.norm(x))
+        q, k, v = mx.split(qkv, 3, axis=-1)
+        h = self.attention(q, k, v, mask)
+        h = self.proj_out(h)
+        x = x + h
+        if self.ffn is not None:
+            x = x + self.ffn(x)
+        return x
+
+
+class TemporalAttentionBlock_MLX(nn.Module):
+    """
+    Attention across the temporal axis of a (batch * time, c, h, w) input.
+
+    Each spatial location is treated as an independent sequence of ``t``
+    frames that goes through ``SelfAttention1D_MLX`` and ``MLP_MLX``. With
+    ``down=True`` the input is first reduced by a stride-2 convolution and the
+    result is upsampled by 2 and convolved again before the residual add.
+
+    Args:
+        channels: number of feature channels.
+        num_heads: forwarded to ``SelfAttention1D_MLX``.
+        num_head_channels: forwarded to ``SelfAttention1D_MLX``.
+        down: run the attention at half spatial resolution.
+        pos_emb: forwarded to ``SelfAttention1D_MLX``.
+    """
+
+    def __init__(
+        self, channels, num_heads=8, num_head_channels=-1, down=False, pos_emb=False
+    ):
+        super().__init__()
+        self.attn = SelfAttention1D_MLX(
+            channels, num_heads, num_head_channels, pos_emb=pos_emb
+        )
+        self.mlp = MLP_MLX(channels, multiplier=4)
+        self.down = down
+        if down:
+            self.down_conv = nn.Conv2d(
+                channels, channels, kernel_size=3, stride=2, padding=1, bias=True
+            )
+            self.up_conv = nn.Conv2d(
+                channels, channels, kernel_size=3, stride=1, padding=1, bias=True
+            )
+
+    def forward(self, x, temb):
+        """
+        Args:
+            x: array laid out as "(b t) c h w".
+            temb: only ``temb.shape[0]`` is read; it is used as the batch size
+                from which the number of frames ``t`` is derived.
+
+        Returns:
+            Array with the same layout as ``x`` (input plus attention branch).
+        """
+        x_ = x
+        # MLX convolutions are channel-last, so stay in "b h w c" for the
+        # whole branch and convert back once at the end.
+        x = einops.array_api.rearrange(x, "b c h w -> b h w c")
+        if self.down:
+            x = self.down_conv(x)
+
+        # x is channel-last here: axes 1 and 2 are height and width.
+        T, H, W = x.shape[0] // temb.shape[0], x.shape[1], x.shape[2]
+        x = einops.array_api.rearrange(x, "(b t) h w c -> (b h w) t c", t=T)
+        x = self.attn.forward(x, None)
+        x = self.mlp.forward(x)
+        x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) h w c", h=H, w=W)
+
+        if self.down:
+            x = nn.Upsample(scale_factor=2, mode="nearest")(x)
+            x = self.up_conv(x)
+        x = einops.array_api.rearrange(x, "b h w c -> b c h w")
+        x = x + x_
+        return x
+
+
+class MLP_MLX(nn.Module):  # mlx based nn.Module
+    """
+    Residual feed-forward block built from MLX layers.
+
+    The branch is LayerNorm -> Linear -> GELU -> Linear, where the hidden
+    width is ``multiplier * channels`` and the last Linear is passed through
+    ``zero_module_mlx`` before use.
+    """
+
+    def __init__(self, channels, multiplier=4):
+        super().__init__()
+        hidden = multiplier * channels
+        layers = [
+            nn.LayerNorm(channels),
+            nn.Linear(channels, hidden),
+            nn.GELU(),
+            zero_module_mlx(nn.Linear(hidden, channels)),
+        ]
+        self.main = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Return ``x`` plus the output of the feed-forward branch."""
+        branch = self.main(x)
+        return x + branch
+
+
+class SelfAttention_MLX(nn.Module):
+    """
+    Multi-head spatial self-attention with optional conditioning, in MLX.
+
+    The input is group-normalised, projected to q/k/v with a 1x1 convolution
+    and attended over all spatial positions. When ``cond_dim`` is a positive
+    number and a ``cond`` sequence is supplied, keys/values computed from
+    ``cond`` are attended with the same queries and the result is added to
+    the self-attention output.
+
+    Args:
+        channels: number of feature channels.
+        num_heads: number of heads, used when ``num_head_channels`` is -1.
+        num_head_channels: if not -1, per-head channel count; the number of
+            heads becomes ``channels // num_head_channels``.
+        cond_dim: feature size of the conditioning sequence, or None.
+        use_attention_ffn: add a residual GroupNorm/Conv/GELU/Conv block.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=8,
+        num_head_channels=-1,
+        cond_dim=None,
+        use_attention_ffn=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.norm = nn.GroupNorm(32, channels, pytorch_compatible=True)
+        self.qkv = nn.Conv2d(channels, channels * 3, 1)
+        self.cond_dim = cond_dim
+        if cond_dim is not None and cond_dim > 0:
+            self.norm_cond = nn.LayerNorm(cond_dim)
+            self.kv_cond = nn.Linear(cond_dim, channels * 2)
+        self.proj_out = zero_module_mlx(nn.Conv2d(channels, channels, 1))
+        if use_attention_ffn:
+            self.ffn = nn.Sequential(
+                nn.GroupNorm(32, channels, pytorch_compatible=True),
+                nn.Conv2d(channels, 4 * channels, 1),
+                nn.GELU(),
+                zero_module_mlx(nn.Conv2d(4 * channels, channels, 1)),
+            )
+        else:
+            self.ffn = None
+
+    def attention(self, q, k, v, mask=None):
+        """
+        Scaled dot-product attention on channel-first inputs.
+
+        Args:
+            q: array of shape (batch, channels, query_length).
+            k, v: arrays of shape (batch, channels, key_length).
+            mask: optional (batch, key_length) array; key positions where it
+                is false/zero receive ``-inf`` before the softmax.
+
+        Returns:
+            Array of shape (batch, channels, query_length).
+        """
+        bs, width, length = q.shape
+        ch = width // self.num_heads
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = mx.einsum(
+            "bct,bcs->bts",
+            (q * scale).reshape(bs * self.num_heads, ch, length),
+            (k * scale).reshape(bs * self.num_heads, ch, -1),
+        )  # More stable with f16 than dividing afterwards
+        if mask is not None:
+            # Reshape mask to match attention shape
+            # From [bs, seq_len] to [bs * num_heads, 1, seq_len]
+            expanded_mask = einops.array_api.repeat(
+                mask[:, None, :],  # Add dimension for broadcasting
+                "b 1 s -> (b h) 1 s",
+                h=self.num_heads,
+            )
+            # Apply mask
+            weight = mx.where(expanded_mask, weight, float("-inf"))
+
+        weight = mx.softmax(weight, axis=-1)
+
+        return mx.einsum(
+            "bts,bcs->bct", weight, v.reshape(bs * self.num_heads, ch, -1)
+        ).reshape(bs, width, length)
+
+    def forward(self, x, cond=None, cond_mask=None):
+        """
+        Args:
+            x: array laid out as "b c h w".
+            cond: optional conditioning sequence laid out as "b s cond_dim".
+            cond_mask: optional (b, s) mask forwarded to ``attention`` for the
+                conditioning keys.
+
+        Returns:
+            Array laid out as "b c h w".
+        """
+        # MLX GroupNorm / Conv2d operate on channel-last input.
+        x = einops.array_api.rearrange(x, "b c h w -> b h w c")
+        _, height, width, _ = x.shape
+
+        qkv = self.qkv(self.norm(x))
+        # attention() unpacks its inputs as (batch, channels, tokens), so the
+        # channel axis must come before the flattened spatial axis.
+        qkv = einops.array_api.rearrange(
+            qkv, "b h w (three c) -> three b c (h w)", three=3
+        )
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        attn_output = self.attention(q, k, v)
+
+        # kv_cond / norm_cond only exist when cond_dim > 0 (see __init__).
+        if self.cond_dim is not None and self.cond_dim > 0 and cond is not None:
+            kv_cond = self.kv_cond(self.norm_cond(cond))
+            kv_cond = einops.array_api.rearrange(
+                kv_cond, "b s (two c) -> two b c s", two=2
+            )
+            k_cond, v_cond = kv_cond[0], kv_cond[1]
+            attn_cond = self.attention(q, k_cond, v_cond, cond_mask)
+            attn_output = attn_output + attn_cond
+
+        # Back to channel-last for the 1x1 output convolution.
+        attn_output = einops.array_api.rearrange(
+            attn_output, "b c (h w) -> b h w c", h=height, w=width
+        )
+        out = self.proj_out(attn_output)
+
+        x = einops.array_api.rearrange(x, "b h w c -> b c h w")
+        out = einops.array_api.rearrange(out, "b h w c -> b c h w")
+        x = x + out
+
+        if self.ffn is not None:
+            x = einops.array_api.rearrange(x, "b c h w -> b h w c")
+            x = self.ffn(x) + x
+            x = einops.array_api.rearrange(x, "b h w c -> b c h w")
+
+        return x
+
diff --git a/ml_mdm/reader.py b/ml-mdm-matryoshka/ml_mdm/reader.py
similarity index 100%
rename from ml_mdm/reader.py
rename to ml-mdm-matryoshka/ml_mdm/reader.py
diff --git a/ml_mdm/s3_helpers.py b/ml-mdm-matryoshka/ml_mdm/s3_helpers.py
similarity index 86%
rename from ml_mdm/s3_helpers.py
rename to ml-mdm-matryoshka/ml_mdm/s3_helpers.py
index e70479e..a57af49 100644
--- a/ml_mdm/s3_helpers.py
+++ b/ml-mdm-matryoshka/ml_mdm/s3_helpers.py
@@ -12,10 +12,10 @@
def download_object(
- bucket_name,
- file_name,
- download_path=None,
- endpoint_url=ENDPOINT_URL,
+ bucket_name: str,
+ file_name: str,
+ download_path: str =None,
+ endpoint_url: str = ENDPOINT_URL,
max_bandwidth=None,
):
"""Downloads an object from S3 to local."""
@@ -37,7 +37,7 @@ def download_object(
return download_path
-def download_object_from_full_path(path, download_path=None, endpoint_url=ENDPOINT_URL):
+def download_object_from_full_path(path: str, download_path: str =None, endpoint_url: str = ENDPOINT_URL):
bucket_name, parent_path, basename = _parse_path(path)
file_name = os.path.join(parent_path, basename)
return download_object(
@@ -46,10 +46,10 @@ def download_object_from_full_path(path, download_path=None, endpoint_url=ENDPOI
def upload_object(
- bucket_name,
- file_name,
- upload_path,
- endpoint_url=ENDPOINT_URL,
+ bucket_name: str,
+ file_name: str,
+ upload_path: str,
+ endpoint_url: str = ENDPOINT_URL,
):
"""Uload an object from S3 to local."""
@@ -70,7 +70,7 @@ def _parse_path(tsv_pattern):
return bucket, "/".join(parts[3:-1]), pattern
-def get_file_list(tsv_pattern, endpoint_url=ENDPOINT_URL):
+def get_file_list(tsv_pattern: str, endpoint_url: str = ENDPOINT_URL):
bucket_name, parent_path, pattern = _parse_path(tsv_pattern)
resource = boto3.resource("s3", endpoint_url=endpoint_url)
bucket = resource.Bucket(bucket_name)
@@ -84,7 +84,7 @@ def get_file_list(tsv_pattern, endpoint_url=ENDPOINT_URL):
return fnames
-def download_parallel(files, endpoint_url=ENDPOINT_URL):
+def download_parallel(files: str, endpoint_url: str=ENDPOINT_URL):
logging.info("Doing parallel download")
with ProcessPoolExecutor() as executor:
logging.info(f"Submitting {files}")
diff --git a/ml_mdm/samplers.py b/ml-mdm-matryoshka/ml_mdm/samplers.py
similarity index 87%
rename from ml_mdm/samplers.py
rename to ml-mdm-matryoshka/ml_mdm/samplers.py
index c72caac..1870814 100644
--- a/ml_mdm/samplers.py
+++ b/ml-mdm-matryoshka/ml_mdm/samplers.py
@@ -4,6 +4,7 @@
import math
from dataclasses import dataclass, field
from enum import Enum
+from typing import Callable, Tuple
from einops import repeat
from tqdm import tqdm
@@ -26,6 +27,7 @@ class ScheduleType(Type):
COSINE = 0
DDPM = 2
DEEPFLOYD = 3
+ SIGMOID = 4
@staticmethod
def argparse(s):
@@ -112,9 +114,7 @@ class SamplerConfig:
)
schedule_shifted_power: float = field(
default=1,
- metadata={
- "help": "noise shifted ratio, by default using 1."
- },
+ metadata={"help": "noise shifted ratio, by default using 1."},
)
@@ -146,12 +146,12 @@ def schedule_ddpm_defults(
return gammas
-def squaredcos_cap_v2(timesteps: int):
+def squaredcos_cap_v2(timesteps: int) -> np.ndarray:
"""
https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L147
"""
- def alpha_bar(time_step):
+ def alpha_bar(time_step: float) -> float:
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = [0]
@@ -164,6 +164,10 @@ def alpha_bar(time_step):
gammas = np.exp(np.cumsum(log_alphas))
return gammas
+def schedule_sigmoid(timesteps: int, beta_start: float, beta_end: float) -> np.ndarray:
+    """
+    Sigmoid-shaped schedule, see https://arxiv.org/pdf/2301.10972
+
+    Evaluates ``1 / (exp(x) + 1)`` at ``timesteps`` evenly spaced points ``x``
+    in ``[-6, 6]`` and rescales the result to the interval spanned by
+    ``beta_start`` and ``beta_end``. The first entry is therefore close to
+    ``beta_end`` and the last one close to ``beta_start``.
+
+    Args:
+        timesteps: number of schedule entries to produce.
+        beta_start: value approached at the last step.
+        beta_end: value approached at the first step.
+
+    Returns:
+        1-D array of length ``timesteps``.
+    """
+    # linspace takes (start, stop, num); passing the tuple (-6, 6) as `start`
+    # would yield a 2-column array with the default 50 rows.
+    betas = np.linspace(-6, 6, timesteps)
+    return (1 / (np.exp(betas) + 1)) * (beta_end - beta_start) + beta_start
##########################################################################################
# Sampler Class #
@@ -189,12 +193,14 @@ def __init__(self, sampler_config: SamplerConfig):
if self._config.loss_target_type is None:
self._config.loss_target_type = self._config.prediction_type
- def read_gamma(self, time, image):
+ def read_gamma(self, time: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
B, C, H, W = image.size()
time = repeat(time, "b -> b c h w", c=C, h=H, w=W)
return self.gammas[time]
- def get_noise_schedule(self, schedule_type, n_steps, sampler_config):
+ def get_noise_schedule(
+ self, schedule_type: ScheduleType, n_steps: int, sampler_config: SamplerConfig
+ ):
# pre-defined noise schedule functions
if schedule_type == ScheduleType.COSINE:
_gammas = schedule_cosine(n_steps)
@@ -246,10 +252,12 @@ def get_image_rescaled(self, images, scale_factor=None):
images = images / scale_factor
return images
- def get_schedule_shifted(self, gammas, scale_factor=None):
+ def get_schedule_shifted(
+ self, gammas: torch.Tensor, scale_factor: float = None
+ ) -> torch.Tensor:
if (scale_factor is not None) and (scale_factor > 1): # rescale noise schecule
p = self._config.schedule_shifted_power
- scale_factor = scale_factor ** p
+ scale_factor = scale_factor**p
snr = gammas / (1 - gammas)
scaled_snr = snr / scale_factor
gammas = 1 / (1 + 1 / scaled_snr)
@@ -272,17 +280,17 @@ def get_prediction_targets(self, images, eps, g, g_last, prediction_type=None):
def get_prediction_xt_last(
self,
- x_t,
- pred,
- g,
- g_last,
- prediction_type=None,
- clip_fn=None,
- need_noise=False,
- ddim_eta=None,
+ x_t: torch.Tensor,
+ pred: torch.Tensor,
+ g: torch.Tensor,
+ g_last: torch.Tensor,
+ prediction_type: PredictionType = None,
+ clip_fn: Callable = None,
+ need_noise: torch.Tensor = False,
+ ddim_eta: int = None,
input_noise=None,
image_scale=None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
x_t: noisy image
pred: model prediction (can be x0, eps, v, etc)
@@ -293,7 +301,6 @@ def get_prediction_xt_last(
need_noise: use noise or not
ddim_eta: if None, then not using DDIM, otherwise, use DDIM implementation (1==DDPM)
"""
-
if prediction_type is None:
prediction_type = self._config.prediction_type
@@ -338,7 +345,13 @@ def get_prediction_xt_last(
return x0, x_t_last, eps
def get_x0_eps_from_pred(
- self, x_t, pred, g, prediction_type=None, clip_fn=None, return_eps=True
+ self,
+ x_t: torch.Tensor,
+ pred: torch.Tensor,
+ g: torch.Tensor,
+ prediction_type: PredictionType = None,
+ clip_fn=None,
+ return_eps: bool = True,
):
batch_size = x_t.size(0)
if prediction_type is None:
@@ -378,16 +391,16 @@ def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None):
def get_xt_minus_1(
self,
- model,
- time_step,
- x_t,
- lm_outputs,
- lm_mask,
- micros={},
- time_step_last=None,
- guidance_scale=1,
- ddim_eta=None,
- return_details=False,
+ model, # TODO - This is ml_mdm.diffusion.Model but importing diffusion is a circular import
+ time_step: torch.Tensor,
+ x_t: torch.Tensor,
+ lm_outputs: torch.Tensor,
+ lm_mask: torch.Tensor,
+ micros: dict = {},
+ time_step_last: torch.Tensor = None,
+ guidance_scale: float = 1,
+ ddim_eta: int = None,
+ return_details: bool = False,
):
batch_size = x_t.shape[0]
ones = torch.ones(batch_size, dtype=torch.long, device=self.gammas.device)
@@ -420,8 +433,15 @@ def get_xt_minus_1(
return x_s
def forward_model(
- self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1
- ):
+ self,
+ model, # TODO - This is ml_mdm.diffusion.Model but to import diffusion it is a circular import
+ x_t: torch.Tensor,
+ t: torch.Tensor,
+ lm_outputs: torch.Tensor,
+ lm_mask: torch.Tensor,
+ micros: dict = {},
+ guidance_scale: float = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
if guidance_scale != 1:
assert x_t.shape[0] * 2 == lm_outputs.shape[0]
pred, extras = model(
@@ -439,8 +459,11 @@ def forward_model(
return pred, extras
def _threshold_sample(
- self, sample, dynamic_thresholding_ratio=0.995, sample_max_value=100
- ):
+ self,
+ sample: torch.Tensor,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 100,
+ ) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -474,7 +497,7 @@ def _threshold_sample(
return sample
- def clip_sample(self, pred_x0, image_scale=1):
+ def clip_sample(self, pred_x0: torch.Tensor, image_scale: int = 1) -> torch.Tensor:
s = image_scale
if self._config.threshold_function == ThresholdType.CLIP:
return (pred_x0 * s).clip(-1, 1) / s
@@ -492,20 +515,20 @@ def sample(self, *args, **kwargs):
def _sample(
self,
- model,
- x_t,
- lm_outputs,
- lm_mask,
- micros,
- return_sequence=False,
- use_beta_tilde=False,
- t=-1,
- num_inference_steps=2000,
- ddim_eta=None,
- guidance_scale=1,
- resample_steps=False,
- disable_bar=True,
- yield_output=False,
+ model, # TODO - This is ml_mdm.diffusion.Model but to import diffusion it is a circular import
+ x_t: torch.Tensor,
+ lm_outputs: torch.Tensor,
+ lm_mask: torch.Tensor,
+ micros: dict,
+ return_sequence: bool = False,
+ use_beta_tilde: bool = False,
+ t: int = -1,
+ num_inference_steps: int = 2000,
+ ddim_eta: int = None,
+ guidance_scale: float = 1,
+ resample_steps: bool = False,
+ disable_bar: bool = True,
+ yield_output: bool = False,
**post_args,
):
"""
@@ -556,12 +579,12 @@ def _sample(
def _postprocess(
self,
- x_t,
- x0=None,
- extra=None,
- yield_full=False,
- clip=False,
- image_scale=None,
+ x_t: torch.Tensor,
+ x0: torch.Tensor = None,
+ extra: tuple = None,
+ yield_full: bool = False,
+ clip: bool = False,
+ image_scale: float = None,
**unused,
):
if image_scale is None:
@@ -575,7 +598,7 @@ def _postprocess(
return (x0, x_t, extra)
return x_t
- def set_timesteps(self, num_inference_steps=250):
+ def set_timesteps(self, num_inference_steps: int = 250) -> np.ndarray:
step_ratio = (self._config.num_diffusion_steps + 1) / (num_inference_steps + 1)
timesteps = (
(np.arange(0, num_inference_steps + 1) * step_ratio)
@@ -634,8 +657,8 @@ def get_xt_minus_1(
model,
time_step,
x_t,
- lm_outputs,
- lm_mask,
+ lm_outputs: torch.Tensor,
+ lm_mask: torch.Tensor,
micros={},
time_step_last=None,
guidance_scale=1,
@@ -694,9 +717,9 @@ def _postprocess(
x_t,
x0=None,
extra=None,
- yield_full=False,
- clip=False,
- output_inner=False,
+ yield_full: bool=False,
+ clip: bool=False,
+ output_inner: bool =False,
**unused,
):
scales = [
@@ -749,7 +772,7 @@ def cat(x, size):
return out
def forward_model(
- self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1
+ self, model, x_t, t, lm_outputs: torch.Tensor, lm_mask: torch.Tensor, micros={}, guidance_scale=1
):
def cfg(pred):
pred_uncond, pred = pred.chunk(2)
diff --git a/ml_mdm/trainer.py b/ml-mdm-matryoshka/ml_mdm/trainer.py
similarity index 76%
rename from ml_mdm/trainer.py
rename to ml-mdm-matryoshka/ml_mdm/trainer.py
index 1c2b93d..7a33a1b 100644
--- a/ml_mdm/trainer.py
+++ b/ml-mdm-matryoshka/ml_mdm/trainer.py
@@ -1,22 +1,27 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
+
+
+from argparse import Namespace
+from typing import Optional
+
import numpy as np
import torch
import torch.nn as nn
-
+from torch.utils.tensorboard import SummaryWriter
def train_batch(
- model,
- sample,
- optimizer,
- scheduler,
- logger,
- args,
- grad_scaler=None,
- accumulate_gradient=False,
- num_grad_accumulations=1,
- ema_model=None,
- loss_factor=1,
+ model: torch.nn.Module,
+ sample: dict,
+ optimizer: torch.optim.Optimizer,
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
+ logger: Optional[torch.utils.tensorboard.SummaryWriter],
+ args: Namespace,
+ grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+ accumulate_gradient: bool = False,
+ num_grad_accumulations: int = 1,
+ ema_model: Optional[nn.Module] = None,
+ loss_factor: float = 1.0,
):
model.train()
lr = scheduler.get_last_lr()[0]
@@ -50,7 +55,9 @@ def train_batch(
grad_scaler.step(optimizer)
grad_scaler.update()
if ema_model is not None:
- ema_model.update(model.model.module.vision_model)
+ ema_model.update(
+ getattr(model.model, "module", model.model).vision_model
+ )
else:
losses, times, x_t, means, targets, weights = model.get_loss(sample)
if weights is None:
@@ -74,7 +81,9 @@ def train_batch(
).item()
optimizer.step()
if ema_model is not None:
- ema_model.update(model.model.module.vision_model)
+ ema_model.update(
+ getattr(model.model, "module", model.model).vision_model
+ )
if logger is not None and not accumulate_gradient:
logger.add_scalar("train/Loss", loss_val)
diff --git a/ml_mdm/utils/__init__.py b/ml-mdm-matryoshka/ml_mdm/utils/__init__.py
similarity index 100%
rename from ml_mdm/utils/__init__.py
rename to ml-mdm-matryoshka/ml_mdm/utils/__init__.py
diff --git a/ml_mdm/utils/fix_old_checkpoints.py b/ml-mdm-matryoshka/ml_mdm/utils/fix_old_checkpoints.py
similarity index 100%
rename from ml_mdm/utils/fix_old_checkpoints.py
rename to ml-mdm-matryoshka/ml_mdm/utils/fix_old_checkpoints.py
diff --git a/ml_mdm/utils/simple_logger.py b/ml-mdm-matryoshka/ml_mdm/utils/simple_logger.py
similarity index 100%
rename from ml_mdm/utils/simple_logger.py
rename to ml-mdm-matryoshka/ml_mdm/utils/simple_logger.py
diff --git a/pyproject.toml b/ml-mdm-matryoshka/pyproject.toml
similarity index 76%
rename from pyproject.toml
rename to ml-mdm-matryoshka/pyproject.toml
index 43aaeeb..d3d16d2 100644
--- a/pyproject.toml
+++ b/ml-mdm-matryoshka/pyproject.toml
@@ -5,9 +5,10 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["."]
exclude = ["tests*", "*clis*"]
+namespaces = true
[project]
-name = "ml_mdm"
+name = "ml-mdm-matryoshka"
authors = [{name = "Apple"}]
readme = "README.md"
version = "1.0"
@@ -17,33 +18,53 @@ description = "A python package to simplify the creation of text conditioned ima
dependencies = [
"dataclass-wizard",
"einops",
- "fastapi>=0.109.1", # Required due to CVE-2024-24762
- "gradio>=4.14", # Required due to CVE-2023-6572
"httpx==0.24.1",
- "imageio[ffmpeg]",
- "matplotlib",
"mlx-data",
"numpy<2",
- "pytorch-model-summary",
- "rotary-embedding-torch",
"simple-parsing==0.1.5",
- "tensorboardX==2.6.2.2",
- "tensorboard==2.16.2",
- "torchinfo",
- "torchmetrics[image]",
"torchvision",
"transformers",
"sentencepiece",
- "boto3",
"torch==2.2.2",
- "pytest",
- "pytest-cov",
- "pre-commit"
+ "matplotlib",
+ "gradio",
+ "boto3",
+ "torchmetrics",
+ "img2dataset",
+ "torchinfo"
]
[project.optional-dependencies]
+cpu = [
+ "torch==2.2.2+cpu",
+ "tensorflow==2.5.0",
+]
+gpu = [
+ "torch==2.2.2+cu111",
+ "tensorflow-gpu==2.5.0",
+]
data_prep = [
- "img2dataset"
+ "img2dataset",
+ "boto3",
+]
+web_demo = [
+ "fastapi>=0.109.1", # Required due to CVE-2024-24762
+ "gradio>=4.14", # Required due to CVE-2023-6572
+ "matplotlib",
+ "imageio[ffmpeg]",
+]
+training = [
+ "tensorboard==2.16.2",
+ "tensorboardX==2.6.2.2",
+ "torchmetrics[image]",
+ "rotary-embedding-torch",
+ "pytorch-model-summary",
+ "torchinfo",
+]
+dev = [
+ "pytest",
+ "pytest-cov",
+ "pre-commit",
]
[tool.isort]
@@ -51,9 +72,8 @@ profile = "black"
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "FIRSTPARTY", "LOCALFOLDER"]
known_numeric = ["torch", "torchvision", "numpy", "jax", "flax", "mlx"]
-
[tool.pytest.ini_options]
-addopts = "--cov=ml_mdm"
+addopts = " -m 'not gpu'"
markers = [
"gpu" # tests that require a gpu
]
diff --git a/tests/test_configs.py b/ml-mdm-matryoshka/tests/test_configs.py
similarity index 80%
rename from tests/test_configs.py
rename to ml-mdm-matryoshka/tests/test_configs.py
index 2d3fe83..a61a8dd 100644
--- a/tests/test_configs.py
+++ b/ml-mdm-matryoshka/tests/test_configs.py
@@ -7,16 +7,19 @@
def test_unet_in_registry():
+ """Check that 'nested_unet' and 'unet' models are correctly registered in the Model Registry."""
assert config.get_model("nested_unet") is not None
assert config.get_model("unet") is not None
def test_unet_in_pipeline():
+ """Check that 'nested_unet' and 'unet' models have corresponding pipelines defined."""
assert config.get_pipeline("unet") is not None
assert config.get_pipeline("nested_unet") is not None
def test_config_cc12m_64x64():
+ """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_64x64.yaml"
args = config.get_arguments(
mode="trainer",
@@ -44,6 +47,7 @@ def test_config_cc12m_64x64():
def test_config_cc12m_256x256():
+ """Check that the 'cc12m_256x256' configuration loads with 'nested_unet' as model in all modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_256x256.yaml"
args = config.get_arguments(
args=["--model=nested_unet"],
@@ -75,6 +79,7 @@ def test_config_cc12m_256x256():
def test_config_cc12m_1024x1024():
+ """Check that the 'cc12m_1024x1024' configuration loads with 'nested2_unet' model in all modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_1024x1024.yaml"
args = config.get_arguments(
args=["--model=nested2_unet"],
@@ -102,4 +107,4 @@ def test_config_cc12m_1024x1024():
mode="demo",
additional_config_paths=[f],
)
- assert args
+ assert args
\ No newline at end of file
diff --git a/tests/test_files/c12m_10samples.tsv b/ml-mdm-matryoshka/tests/test_files/c12m_10samples.tsv
similarity index 100%
rename from tests/test_files/c12m_10samples.tsv
rename to ml-mdm-matryoshka/tests/test_files/c12m_10samples.tsv
diff --git a/tests/test_files/images_00000.tar b/ml-mdm-matryoshka/tests/test_files/images_00000.tar
similarity index 100%
rename from tests/test_files/images_00000.tar
rename to ml-mdm-matryoshka/tests/test_files/images_00000.tar
diff --git a/tests/test_files/images_00000.tsv b/ml-mdm-matryoshka/tests/test_files/images_00000.tsv
similarity index 100%
rename from tests/test_files/images_00000.tsv
rename to ml-mdm-matryoshka/tests/test_files/images_00000.tsv
diff --git a/tests/test_files/sample_training_0.tsv b/ml-mdm-matryoshka/tests/test_files/sample_training_0.tsv
similarity index 100%
rename from tests/test_files/sample_training_0.tsv
rename to ml-mdm-matryoshka/tests/test_files/sample_training_0.tsv
diff --git a/tests/test_generate_batch.py b/ml-mdm-matryoshka/tests/test_generate_batch.py
similarity index 80%
rename from tests/test_generate_batch.py
rename to ml-mdm-matryoshka/tests/test_generate_batch.py
index 2334ac2..4edb759 100644
--- a/tests/test_generate_batch.py
+++ b/ml-mdm-matryoshka/tests/test_generate_batch.py
@@ -8,6 +8,10 @@
def test_small_batch():
+ """
+ Test small batch generation with T5 model.
+ Check that basic data generation pipeline works with minimal settings.
+ """
args = Namespace(
batch_size=10,
test_file_list="tests/test_files/sample_training_0.tsv",
@@ -33,6 +37,12 @@ def test_small_batch():
def test_generate_batch():
+ """
+ Test batch generation with default config settings.
+
+ Note: This test currently only sets up the configuration but doesn't execute
+ the generation (ends with pass statement).
+ """
args = config.get_arguments(mode="sampler")
args.batch_size = 10
args.test_file_list = "tests/test_files/sample_training_0.tsv"
diff --git a/tests/test_generate_sample.py b/ml-mdm-matryoshka/tests/test_generate_sample.py
similarity index 72%
rename from tests/test_generate_sample.py
rename to ml-mdm-matryoshka/tests/test_generate_sample.py
index d36c889..02c70e5 100644
--- a/tests/test_generate_sample.py
+++ b/ml-mdm-matryoshka/tests/test_generate_sample.py
@@ -4,6 +4,10 @@
def test_load_flick_config():
+ """
+ Test loading of cc12m_64x64.yaml config file.
+ Checks image dimensions are correctly loaded in reader config.
+ """
args = config.get_arguments(
"",
mode="demo",
diff --git a/tests/test_imports.py b/ml-mdm-matryoshka/tests/test_imports.py
similarity index 72%
rename from tests/test_imports.py
rename to ml-mdm-matryoshka/tests/test_imports.py
index 19138c2..04474a8 100644
--- a/tests/test_imports.py
+++ b/ml-mdm-matryoshka/tests/test_imports.py
@@ -1,6 +1,7 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
def test_top_level_imports_work():
+ """Checks that all top-level ml_mdm module imports are accessible."""
from ml_mdm import (
config,
diffusion,
@@ -16,6 +17,7 @@ def test_top_level_imports_work():
def test_cli_imports_work():
+ """Checks that all CLI module imports are accessible."""
from ml_mdm.clis import (
download_tar_from_index,
generate_batch,
@@ -25,8 +27,10 @@ def test_cli_imports_work():
def test_model_imports_work():
+ """Checks that all model module imports are accessible."""
from ml_mdm.models import model_ema, nested_unet, unet
def test_lm_imports_work():
+ """Checks that all language model module imports are accessible."""
from ml_mdm.language_models import factory, tokenizer
diff --git a/tests/test_models.py b/ml-mdm-matryoshka/tests/test_models.py
similarity index 88%
rename from tests/test_models.py
rename to ml-mdm-matryoshka/tests/test_models.py
index b945fb5..2a96dad 100644
--- a/tests/test_models.py
+++ b/ml-mdm-matryoshka/tests/test_models.py
@@ -13,6 +13,7 @@
def test_initialize_unet():
+ """Test UNet model and EMA initialization with default configs."""
unet_config = models.unet.UNetConfig()
diffusion_config = diffusion.DiffusionConfig(
use_vdm_loss_weights=True, model_output_scale=0.1
@@ -30,6 +31,7 @@ def test_initialize_unet():
def test_all_registered_models():
+ """Test instantiation of all models in the registry with default configs."""
for config_name, additional_info in config.MODEL_CONFIG_REGISTRY.items():
model_name = additional_info["model"]
config_cls = additional_info["config"]
@@ -44,6 +46,7 @@ def test_all_registered_models():
@pytest.mark.gpu
def test_initialize_pretrained():
+ """Test loading pretrained 64x64 model on GPU if available."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = config.get_arguments(
diff --git a/tests/test_reader.py b/ml-mdm-matryoshka/tests/test_reader.py
similarity index 88%
rename from tests/test_reader.py
rename to ml-mdm-matryoshka/tests/test_reader.py
index bcc3ef7..3ad3fe8 100644
--- a/tests/test_reader.py
+++ b/ml-mdm-matryoshka/tests/test_reader.py
@@ -10,6 +10,7 @@
def test_get_dataset():
+ """Test dataset loading and verify sample format and dimensions."""
tokenizer = factory.create_tokenizer("data/t5.vocab")
dataset = reader.get_dataset(
tokenizer=tokenizer,
@@ -31,6 +32,7 @@ def test_get_dataset():
def test_get_dataset_partition():
+ """Test dataset partitioning and iteration."""
tokenizer = factory.create_tokenizer("data/t5.vocab")
train_loader = reader.get_dataset_partition(
partition_num=0,
@@ -46,6 +48,7 @@ def test_get_dataset_partition():
def test_process_text():
+ """Test text tokenization with default reader config."""
line = "A bicycle on top of a boat."
tokenizer = factory.create_tokenizer("data/t5.vocab")
tokens = reader.process_text(
@@ -53,3 +56,6 @@ def test_process_text():
)
assert len(tokens) > 0
assert len(tokens[0]) > 0
+
+if __name__ == "__main__":  # direct-run convenience only; pytest collects and runs the test itself
+    test_get_dataset()
\ No newline at end of file
diff --git a/ml-mdm-matryoshka/tests/test_tokenizer.py b/ml-mdm-matryoshka/tests/test_tokenizer.py
new file mode 100644
index 0000000..30b090c
--- /dev/null
+++ b/ml-mdm-matryoshka/tests/test_tokenizer.py
@@ -0,0 +1,23 @@
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All rights reserved.
+
+import logging
+
+from pathlib import Path
+from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py
+
+def test_tokenizer_bert():
+    """Build a Tokenizer in "bert" mode from data/bert.vocab, located relative to this test file."""
+    assert Tokenizer(Path(__file__).parent.parent / "data" / "bert.vocab", mode="bert")
+
+def test_tokenizer_t5():
+    """Build a Tokenizer in "tf" mode from data/t5.vocab, located relative to this test file."""
+    assert Tokenizer(Path(__file__).parent.parent / "data" / "t5.vocab", mode="tf")
+
+def test_tokenizer():
+    """Build a Tokenizer with its default mode from data/imagenet.vocab, located relative to this test file."""
+    assert Tokenizer(Path(__file__).parent.parent / "data" / "imagenet.vocab")
+
+if __name__ == "__main__":  # direct-run convenience only; pytest collects and runs the tests itself
+    for _check in (test_tokenizer_bert, test_tokenizer_t5, test_tokenizer):
+        _check()
diff --git a/tests/test_train.py b/ml-mdm-matryoshka/tests/test_train.py
similarity index 94%
rename from tests/test_train.py
rename to ml-mdm-matryoshka/tests/test_train.py
index c1f2f52..073fcb4 100644
--- a/tests/test_train.py
+++ b/ml-mdm-matryoshka/tests/test_train.py
@@ -23,6 +23,7 @@
reason="more effective to test this with torchrun, just here for documentation"
)
def test_small():
+ """Test minimal training run with single process setup."""
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
diff --git a/ml-mdm-matryoshka/tests/test_unet_mlx.py b/ml-mdm-matryoshka/tests/test_unet_mlx.py
new file mode 100644
index 0000000..90c12c5
--- /dev/null
+++ b/ml-mdm-matryoshka/tests/test_unet_mlx.py
@@ -0,0 +1,249 @@
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All rights reserved.
+
+import mlx.core as mx
+import numpy as np
+import torch
+
+from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock, SelfAttention
+from ml_mdm.models.unet_mlx import (
+ MLP_MLX,
+ SelfAttention1D_MLX,
+ TemporalAttentionBlock_MLX,
+ SelfAttention_MLX
+)
+
+
+def test_pytorch_mlp():
+    """
+    Run the PyTorch MLP and the MLX MLP_MLX on one shared random batch.
+
+    Asserts that each output keeps the input shape, that the PyTorch output
+    stays close to its input (atol=1e-5), and that both backends agree
+    numerically (atol=1e-5).
+    """
+    n_channels = 8  # passed to both models as `channels`
+    hidden_mult = 4  # passed to both models as `multiplier`
+    n_samples = 2  # rows in the dummy batch
+
+    # Build both implementations from identical hyper-parameters.
+    torch_model = MLP(channels=n_channels, multiplier=hidden_mult)
+    mlx_model = MLP_MLX(channels=n_channels, multiplier=hidden_mult)
+
+    # ---------- PyTorch side ----------
+    torch_model.eval()
+
+    x_torch = torch.randn(n_samples, n_channels)
+    y_torch = torch_model(x_torch)
+
+    assert y_torch.shape == x_torch.shape, "Output shape mismatch"
+    # The assertion message records why identity is expected here.
+    assert torch.allclose(
+        y_torch, x_torch, atol=1e-5
+    ), "Output should be close to input as the final layer is zero-initialized"
+
+    # ---------- MLX side ----------
+    # Reuse the exact same values so the two outputs are comparable.
+    x_mlx = mx.array(x_torch.numpy())
+
+    mlx_model.eval()
+
+    y_mlx = mlx_model.forward(x_mlx)
+
+    assert isinstance(y_mlx, mx.array)
+    assert y_mlx.shape == x_torch.shape, "MLX MLP: Output shape mismatch"
+
+    # ---------- Cross-backend comparison ----------
+    # Both results are converted to NumPy before the element-wise check.
+    y_torch_np = y_torch.detach().numpy()
+    y_mlx_np = np.array(mx.stop_gradient(y_mlx))
+    assert np.allclose(
+        y_torch_np, y_mlx_np, atol=1e-5
+    ), "Outputs of PyTorch MLP and MLX MLP should match"
+
+    print("Test passed for both PyTorch and MLX MLP!")
+
+
+
+def test_pytorch_mlx_self_attention():
+    """
+    Test for feature parity between PyTorch and MLX implementations of SelfAttention.
+
+    Runs models built with ``cond_dim`` set (called with ``cond``/``cond_mask``) and
+    models built with ``cond_dim=None`` (called on the input alone) on one shared
+    random batch; outputs must keep the input shape and match across backends.
+    """
+    channels = 64
+    batch_size = 2
+    spatial_size = 8
+    cond_dim = 32
+    num_heads = 8
+
+    # ===== 1. Test WITH CONDITIONAL INPUT =====
+    pytorch_attn_with_cond = SelfAttention(
+        channels=channels,
+        num_heads=num_heads,
+        cond_dim=cond_dim,  # Enable conditioning
+        use_attention_ffn=True,
+    )
+    mlx_attn_with_cond = SelfAttention_MLX(
+        channels=channels,
+        num_heads=num_heads,
+        cond_dim=cond_dim,
+        use_attention_ffn=True,
+    )
+    pytorch_attn_with_cond.eval()  # evaluation mode, as in this module's other parity tests
+    mlx_attn_with_cond.eval()
+
+    # Create conditional inputs
+    cond_seq_len = 4
+    pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim)
+    pytorch_cond_mask = torch.ones(batch_size, cond_seq_len)
+    mlx_cond = mx.array(pytorch_cond.numpy())
+    mlx_cond_mask = mx.array(pytorch_cond_mask.numpy())
+
+    # Image batch shared by both scenarios and both backends
+    pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size)
+    mlx_input = mx.array(pytorch_input.numpy())
+
+    pytorch_output_with_cond = pytorch_attn_with_cond(
+        pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask
+    )
+    mlx_output_with_cond = mlx_attn_with_cond.forward(
+        mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask
+    )
+
+    # ===== 2. Test WITHOUT CONDITIONAL INPUT =====
+    pytorch_attn_no_cond = SelfAttention(
+        channels=channels,
+        num_heads=num_heads,
+        cond_dim=None,
+        use_attention_ffn=True,
+    )
+    mlx_attn_no_cond = SelfAttention_MLX(
+        channels=channels,
+        num_heads=num_heads,
+        cond_dim=None,
+        use_attention_ffn=True,
+    )
+    pytorch_attn_no_cond.eval()
+    mlx_attn_no_cond.eval()
+
+    pytorch_output_no_cond = pytorch_attn_no_cond(pytorch_input)
+    mlx_output_no_cond = mlx_attn_no_cond.forward(mlx_input)
+
+    # Check conditional outputs
+    assert pytorch_output_with_cond.shape == pytorch_input.shape
+    assert mlx_output_with_cond.shape == mlx_input.shape
+    assert np.allclose(
+        pytorch_output_with_cond.detach().numpy(),
+        np.array(mlx_output_with_cond),
+        atol=1e-5, rtol=1e-5
+    ), "Outputs of PyTorch and MLX attention should match"
+
+    # Check non-conditional outputs
+    assert pytorch_output_no_cond.shape == pytorch_input.shape
+    assert mlx_output_no_cond.shape == mlx_input.shape
+    assert np.allclose(
+        pytorch_output_no_cond.detach().numpy(),
+        np.array(mlx_output_no_cond),
+        atol=1e-5, rtol=1e-5
+    ), "Outputs without conditioning should match"
+
+    print("Self-attention test passed for both PyTorch and MLX!")
+
+def test_self_attention_1d():
+    """
+    Compare SelfAttention1D (PyTorch) with SelfAttention1D_MLX on one random batch.
+
+    Both modules are called with ``mask=None``; their outputs must have equal
+    shapes and match element-wise with ``atol=1e-5``.
+    """
+    n_channels, n_heads = 8, 2
+    seq_len, n_batch = 16, 2
+
+    torch_model = SelfAttention1D(channels=n_channels, num_heads=n_heads)
+    mlx_model = SelfAttention1D_MLX(channels=n_channels, num_heads=n_heads)
+
+    # Evaluation mode on both backends
+    torch_model.eval()
+    mlx_model.eval()
+
+    # PyTorch: unmasked forward pass on a random (n_batch, seq_len, n_channels) tensor
+    x_torch = torch.randn(n_batch, seq_len, n_channels)
+    y_torch = torch_model(x_torch, mask=None)
+
+    # MLX: same values, same call
+    x_mlx = mx.array(x_torch.numpy())
+    y_mlx = mlx_model.forward(x_mlx, mask=None)
+
+    assert y_torch.shape == y_mlx.shape, "Output shape mismatch"
+
+    y_torch_np = y_torch.detach().numpy()
+    y_mlx_np = np.array(mx.stop_gradient(y_mlx))
+    assert np.allclose(
+        y_torch_np, y_mlx_np, atol=1e-5
+    ), "Outputs of PyTorch and MLX SelfAttention1D should match"
+
+    print("Test passed for both PyTorch and MLX SelfAttention1D!")
+
+
+def test_pytorch_mlx_temporal_attention_block():
+    """
+    Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock.
+
+    Both blocks are built with ``down=True``, switched to evaluation mode and fed
+    the same float32 input and time embedding. The output shapes are checked
+    first; the values are then compared with ``rtol=atol=1e-1``.
+    """
+    channels = 8
+    num_heads = 2
+    batch_size = 2
+    time_steps = 4
+    height = 16
+    width = 16
+
+    # Create model instances
+    pytorch_block = TemporalAttentionBlock(
+        channels=channels, num_heads=num_heads, down=True
+    )
+    mlx_block = TemporalAttentionBlock_MLX(
+        channels=channels, num_heads=num_heads, down=True
+    )
+
+    # Set models to evaluation mode
+    pytorch_block.eval()
+    mlx_block.eval()
+
+    # Draw the inputs once with NumPy so both backends see identical float32 values
+    arr_input = np.random.normal(
+        0, 1, (batch_size * time_steps, channels, height, width)
+    ).astype(np.float32)
+    arr_temb = np.random.normal(0, 1, (batch_size, channels)).astype(np.float32)
+
+    pytorch_output = pytorch_block(torch.from_numpy(arr_input), torch.from_numpy(arr_temb))
+    mlx_output = mlx_block.forward(mx.array(arr_input), mx.array(arr_temb))
+
+    # Check shapes before any arithmetic: subtracting mismatched arrays would
+    # either broadcast silently or raise before the assertion message is shown.
+    assert pytorch_output.shape == tuple(mlx_output.shape), f"Output shape mismatch: {pytorch_output.shape} vs {mlx_output.shape}"
+
+    # Convert each output to NumPy a single time and reuse it below
+    torch_np = pytorch_output.detach().numpy()
+    mlx_np = np.array(mx.stop_gradient(mlx_output))
+
+    # Scalar diagnostics only; whole tensors are no longer dumped to stdout
+    abs_diff = np.abs(torch_np - mlx_np)
+    mse = float(np.mean(np.square(abs_diff)))
+    psnr = float("inf") if mse == 0 else 10 * np.log10(np.max(torch_np) ** 2 / mse)  # guard: identical outputs
+    print("mean difference: ", abs_diff.mean())
+    print("psnr: ", psnr)
+
+    # NOTE(review): 1e-1 is a loose tolerance; earlier inline notes recorded mean diff ~0.35 and PSNR ~19.2 dB - confirm parity.
+    assert np.allclose(
+        torch_np, mlx_np,
+        rtol=1e-1,
+        atol=1e-1,
+    ), f"Outputs of PyTorch and MLX TemporalAttentionBlock should match (max abs diff: {abs_diff.max():.3g})"
+
+    print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")
diff --git a/ml_mdm/__init__.py b/ml-mdm/ml_mdm/__about__.py
similarity index 100%
rename from ml_mdm/__init__.py
rename to ml-mdm/ml_mdm/__about__.py
diff --git a/ml-mdm/ml_mdm/core.py b/ml-mdm/ml_mdm/core.py
new file mode 100644
index 0000000..2d6152d
--- /dev/null
+++ b/ml-mdm/ml_mdm/core.py
@@ -0,0 +1,35 @@
+
+
+from dataclasses import dataclass, is_dataclass
+from simple_parsing.helpers import Serializable
+from simple_parsing import parse
+from simple_parsing.utils import DataclassT
+
+from typing import TypeVar
+
+C = TypeVar('C')
+
+@dataclass
+class MDMConfig(Serializable):  # empty base config dataclass; Serializable is imported from simple_parsing.helpers
+ pass
+
+class ConfigPrinter:  # default `class_to_call` of CLIBuilder: constructing it prints the config
+ def __init__(self, config : MDMConfig) -> None:
+ print(config)  # the only effect of construction; nothing is stored on the instance
+
+@dataclass
+class CLIBuilder:
+    """Parse a `config_class` instance with simple_parsing and hand it to `class_to_call`."""
+    class_to_call: type[C] = ConfigPrinter  # called by run() with the parsed config as sole argument
+    config_class: type = MDMConfig  # must be a dataclass; checked in build_config()
+    default_config: DataclassT = None  # forwarded to parse() as `default`
+
+    def build_config(self, args: str = None) -> DataclassT:
+        """Return `config_class` parsed from `args`; None is forwarded to parse() (presumably sys.argv - confirm)."""
+        assert is_dataclass(self.config_class), f"{self.config_class!r} is not a dataclass"
+        return parse(config_class=self.config_class, add_config_path_arg="config-file",
+                     default=self.default_config, args=args)
+
+    def run(self, args: str = None) -> C:
+        """Build the config from `args` (default None, same as the former no-argument run()) and call `class_to_call` on it."""
+        return self.class_to_call(self.build_config(args))
diff --git a/ml-mdm/pyproject.toml b/ml-mdm/pyproject.toml
new file mode 100644
index 0000000..81d18d2
--- /dev/null
+++ b/ml-mdm/pyproject.toml
@@ -0,0 +1,12 @@
+[build-system]
+build-backend = "setuptools.build_meta"
+requires = ["setuptools"]
+
+[project]
+dependencies = []
+name = "ml-mdm"
+version = "0.1.0"
+
+[tool.setuptools.packages.find]
+namespaces = true
+where = ["."]
diff --git a/ml-mdm/tests/__init__.py b/ml-mdm/tests/__init__.py
new file mode 100644
index 0000000..5c8f054
--- /dev/null
+++ b/ml-mdm/tests/__init__.py
@@ -0,0 +1,2 @@
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All rights reserved.
\ No newline at end of file
diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py
deleted file mode 100644
index 0084ccc..0000000
--- a/ml_mdm/models/unet_mlx.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# For licensing see accompanying LICENSE file.
-# Copyright (C) 2024 Apple Inc. All rights reserved.
-
-import mlx.core as mx
-import mlx.nn as nn
-
-
-def zero_module_mlx(module):
- """
- Zero out the parameters of an MLX module and return it.
- """
- # Create a new parameter dictionary with all parameters replaced by zeros
- zeroed_params = {
- name: mx.zeros(param.shape, dtype=param.dtype)
- for name, param in module.parameters().items()
- }
- # Update the module's parameters with the zeroed parameters
- module.update(zeroed_params)
- return module
-
-
-class MLP_MLX(nn.Module): # mlx based nn.Module
- def __init__(self, channels, multiplier=4):
- super().__init__()
- ### use mlx layers
- self.main = nn.Sequential(
- nn.LayerNorm(channels),
- nn.Linear(channels, multiplier * channels),
- nn.GELU(),
- zero_module_mlx(nn.Linear(multiplier * channels, channels)),
- )
-
- def forward(self, x):
- return x + self.main(x)
diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py
deleted file mode 100644
index a58a60e..0000000
--- a/tests/test_mlx_unet.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# For licensing see accompanying LICENSE file.
-# Copyright (C) 2024 Apple Inc. All rights reserved.
-
-import mlx.core as mx
-import numpy as np
-import torch
-
-from ml_mdm.models.unet import MLP
-from ml_mdm.models.unet_mlx import MLP_MLX
-
-
-def test_pytorch_mlp():
- """
- Simple test for our MLP implementations
- """
- # Define parameters
- channels = 8 # Number of channels
- multiplier = 4 # Multiplier for hidden dimensions
-
- # Create a model instance
- pytorch_mlp = MLP(channels=channels, multiplier=multiplier)
- mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier)
-
- ## Start by testing pytorch version
-
- # Set model to evaluation mode
- pytorch_mlp.eval()
-
- # Create a dummy pytorch input tensor (batch size = 2, channels = 8)
- input_tensor = torch.randn(2, channels)
-
- # Pass the input through the model
- output = pytorch_mlp(input_tensor)
-
- # Assertions to validate the output shape and properties
- assert output.shape == input_tensor.shape, "Output shape mismatch"
- assert torch.allclose(
- output, input_tensor, atol=1e-5
- ), "Output should be close to input as the final layer is zero-initialized"
-
- ## now test mlx version
-
- # Convert the same input to MLX tensor
- mlx_tensor = mx.array(input_tensor.numpy())
-
- mlx_mlp.eval()
-
- mlx_output = mlx_mlp.forward(mlx_tensor)
-
- assert isinstance(mlx_output, mx.array)
- assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch"
-
- # Validate numerical equivalence using numpy
- assert np.allclose(
- output.detach().numpy(), np.array(mlx_output), atol=1e-5
- ), "Outputs of PyTorch MLP and MLX MLP should match"
-
- print("Test passed for both PyTorch and MLX MLP!")