Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
642faf3
test_configs.py docstrings
bdeanhardt Nov 11, 2024
1a82a4b
added docstrings to remaining test cases
bdeanhardt Nov 12, 2024
5cad650
removed duplicate line
bdeanhardt Nov 12, 2024
e024eb0
Update distributed.py
levinkhho Nov 13, 2024
367131c
Update trainer.py
levinkhho Nov 13, 2024
571a6b9
Update train_parallel.py
levinkhho Nov 13, 2024
3124abc
Update train_parallel.py
levinkhho Nov 13, 2024
274ee5f
Update train_parallel.py
levinkhho Nov 13, 2024
534bea9
Merge pull request #28 from bdeanhardt/patch-1
luke-carlson Nov 13, 2024
ac560df
Type hints for train_batch
emmagarr Nov 14, 2024
7f08370
Update torch requirement pyproject.toml
levinkhho Nov 14, 2024
1e1b286
Type hinted one function
ethanernst11 Nov 14, 2024
23e008e
Start of tokenizer tests.
pbpcruz Nov 14, 2024
7147d66
update test_tokenizer and fixed import in factory.py
pedroborgescruz Nov 14, 2024
76f5502
Finished test_tokenizer.py - need to check if we should assert the co…
pedroborgescruz Nov 14, 2024
9802f2a
Update test_tokenizer.py
pedroborgescruz Nov 15, 2024
195fae4
Update test_tokenizer.py
pedroborgescruz Nov 16, 2024
e62c9de
changed token_file to vocab_file
Nov 17, 2024
8438c07
Merge branch 'apple:main' into main
aaliyahnl Nov 17, 2024
3b1eca2
Update pyproject.toml
levinkhho Nov 19, 2024
fde2865
Update pyproject.toml
levinkhho Nov 19, 2024
7619a12
Update generate_sample.py (matplotlib use Agg backend)
levinkhho Nov 20, 2024
bf2fad3
small edit.
pedroborgescruz Nov 21, 2024
c048fd8
Merge branch 'main' of https://github.com/pedroborgescruz/ml-mdm
pedroborgescruz Nov 21, 2024
81fca8c
Update README.md
pedroborgescruz Nov 21, 2024
875837a
README changes
pedroborgescruz Nov 21, 2024
32c0347
READ me update.
pedroborgescruz Nov 21, 2024
13a4168
List test.
pedroborgescruz Nov 21, 2024
c198434
Updated README with file descriptions.
pedroborgescruz Nov 21, 2024
d2fd65f
Update README.md
pedroborgescruz Nov 21, 2024
fd89a1a
Update to README
pedroborgescruz Nov 21, 2024
268603b
Merge branch 'main' of https://github.com/pedroborgescruz/ml-mdm
pedroborgescruz Nov 21, 2024
a6f45df
Update README.md
pedroborgescruz Nov 21, 2024
0ad0bb5
Merge pull request #42 from levinkhho/patch-2
luke-carlson Nov 22, 2024
33f9544
Merge pull request #35 from levinkhho/patch-1
luke-carlson Nov 22, 2024
9035f13
Merge pull request #43 from aaliyahnl/main
luke-carlson Nov 22, 2024
c34b191
Merge pull request #33 from emmagarr/main
luke-carlson Nov 22, 2024
aa40a70
switch the default pytest call to run just CPU test cases
bdeanhardt Nov 23, 2024
1a433f5
updated readme to reflect new pytest functionalities
bdeanhardt Nov 23, 2024
4d29bbf
Added some typehinting to diffusion.py and generate_sample.py
ethanernst11 Nov 23, 2024
feed873
Merge pull request #48 from bdeanhardt/bella/39
luke-carlson Nov 25, 2024
40e0bea
Merge pull request #44 from pedroborgescruz/readme-updates
luke-carlson Nov 25, 2024
0b82947
Separate python dependencies into optional groups.
pedroborgescruz Nov 25, 2024
d9df79d
Merge branch 'main' of https://github.com/pedroborgescruz/ml-mdm
pedroborgescruz Nov 25, 2024
88abb66
Merge branch 'main' of https://github.com/pedroborgescruz/ml-mdm
pedroborgescruz Nov 25, 2024
ce43e05
Separate python dependencies into optional groups
pedroborgescruz Nov 25, 2024
50c4e73
Update files to use simple parsing
bregwin Nov 28, 2024
19d616a
Fix argument issues readme
bregwin Nov 28, 2024
10d7167
added config description
Dec 3, 2024
c3e66cc
Update README.md
aaliyahnl Dec 3, 2024
7509888
Update README.md
aaliyahnl Dec 3, 2024
f0e6d03
Merge pull request #29 from levinkhho/main
luke-carlson Dec 4, 2024
70fc872
Merge pull request #51 from aaliyahnl/main
luke-carlson Dec 4, 2024
dacd66d
Merge pull request #50 from brokoli777/convert-argparse
luke-carlson Dec 4, 2024
8238342
typehinting for generate and dividable
bdeanhardt Dec 4, 2024
3dc26b8
Merge branch 'main' into main
ethanernst11 Dec 4, 2024
9656803
Took out breakpoints
ethanernst11 Dec 5, 2024
5bf935b
gpu and cpu dependencies
emmagarr Dec 5, 2024
8891aa8
Merge pull request #54 from bdeanhardt/bella/typehinting
luke-carlson Dec 5, 2024
faec175
Merge pull request #55 from emmagarr/main
luke-carlson Dec 5, 2024
0e78b01
Fixing missing modules in pyproject.toml
gabrielfnayres Dec 6, 2024
5b793c0
added sigmoid beta scheduler
gabrielfnayres Dec 6, 2024
ab5851c
Added some typing in s3_helpers.py and trainer.py
gabrielfnayres Dec 6, 2024
8914e19
Merge branch 'main' into main
ethanernst11 Dec 6, 2024
2042cfa
imports for type hinting
ethanernst11 Dec 9, 2024
0f29322
Merge pull request #52 from ethanernst11/main
luke-carlson Dec 9, 2024
9a5632c
Merge pull request #56 from gabrielfnayres/main
luke-carlson Dec 9, 2024
d034077
Merge branch 'main' of github.com:ethanernst11/ml-mdm
ethanernst11 Dec 10, 2024
26af98a
Added type hinting to some functions in samplers.py
ethanernst11 Dec 10, 2024
559f10e
Added more type hinting in samplers.py
ethanernst11 Dec 10, 2024
e0c1c8c
Took out breakpoints and added more typehinting
ethanernst11 Dec 10, 2024
5d57c51
Changes to trainer.py to match remote repository
ethanernst11 Dec 10, 2024
b9d7d32
Added one import for type hinting
ethanernst11 Dec 10, 2024
42dd762
Merge branch 'main' into ethan-typeHint
ethanernst11 Dec 10, 2024
4ea6554
Changed comments to TODO and added two type hints
ethanernst11 Dec 15, 2024
04e35b0
Merge branch 'ethan-typeHint' of github.com:ethanernst11/ml-mdm into …
ethanernst11 Dec 15, 2024
49fc8e3
Merge pull request #59 from ethanernst11/ethan-typeHint
luke-carlson Dec 17, 2024
ffbdfa7
typing NestedSampler, modelEMA, nestedUNET, UNET and Diffusion
gabrielfnayres Jan 5, 2025
0a06989
removing tuple import
gabrielfnayres Jan 5, 2025
23d2614
Merge pull request #62 from gabrielfnayres/main
luke-carlson Jan 17, 2025
c667183
Feature/namespace package config (#5)
luke-carlson Feb 11, 2025
256f576
Include test_tokenizer fix
luke-carlson Feb 12, 2025
de88fbc
fixed self attention implementation and separated the self attention …
gabrielfnayres Feb 22, 2025
df6d533
improved tests flag comments
gabrielfnayres Feb 22, 2025
753feef
debug?
gabrielfnayres Feb 25, 2025
a3eac3f
selfAttention1D passes
bdeanhardt Mar 3, 2025
14d877c
working through parity issues
bdeanhardt Mar 4, 2025
3a8052f
debugging temp attention and self atten tests
gabrielfnayres Mar 13, 2025
63d2884
temporal attention with psnr comparison: 20dB
gabrielfnayres Mar 20, 2025
517fd1f
fixing file name
gabrielfnayres Mar 23, 2025
68db322
Merge pull request #3 from gabrielfnayres/temporal_atten_mlx
bdeanhardt Mar 24, 2025
432acf4
Merge pull request #68 from bdeanhardt/bella/TemporalAttentionBlock_MLX
luke-carlson Mar 27, 2025
fab0d14
fixing dirs structure
gabrielfnayres Mar 27, 2025
dfa4718
Merge pull request #71 from gabrielfnayres/main
luke-carlson Mar 27, 2025
904bea1
fix: fixing conflicts with main branch
gabrielfnayres Apr 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ Developers should set up `pre-commit` as well with `pre-commit install`.
### Running Test Cases

```
> pytest # will run all test cases - including ones that require a gpu
> pytest -m "not gpu" # run test cases that can work with just cpu
> pytest # run test cases that can work with just cpu
> pytest -m '' # will run all test cases - including ones that require a gpu
> pytest -m gpu # run only gpu test cases
```


Expand Down Expand Up @@ -99,6 +100,30 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port

## Codebase


### 1. /configs

| module | description |
| - | - |
| `configs.dataset_creation` | Configuration file for dataset splitting into train-eval-val pipeline |
| `configs.datasets` | Datasets for training and evaluation phases of the model |
| `configs.models` | Configuration files for different resolution models |


### 2. /data

| module | description |
| - | - |
| `data` | <ul><li><b>bert.vocab:</b> BERT-trained dictionary containing tokens and their associated vector representations</li><li><b>c4_wpm.vocab:</b> C4-trained dictionary containing tokens and their associated vector representations</li><li><b>cifar10.vocab:</b> CIFAR10-trained dictionary containing tokens and their associated vector representations</li><li><b>imagenet.vocab:</b> Prompts associated with Imagenet dataset</li><li><b>prompts_cc12m-64x64.tsv:</b> Prompts associated with cc12m dataset for the 64x64 res. model</li><li><b>prompts_cc12m-256x256.tsv:</b> Prompts associated with cc12m dataset for the 256x256 res. model</li><li><b>prompts_cifar10-32x32.tsv:</b> Prompts associated with cifar10 dataset for the 32x32 res. model </li><li><b>prompts_cifar10-64x64.tsv:</b> Prompts associated with cifar10 dataset for the 64x64 res. model </li><li><b>prompts_demo.tsv:</b> Extra demo prompts </li><li><b>prompts_imagenet-64px.tsv:</b> Prompts associated with imagenet dataset for the 64x64 res. model </li><li><b>prompts_WebImage-ALIGN-64px.tsv:</b> Prompts associated with WebImage-ALIGN dataset for the 64x64 res. model </li><li><b>t5.vocab:</b> t5-trained dictionary containing tokens and their associated vector representations </li><li><b>tokenizer_spm_32000_50m.vocab:</b> SPM-trained dictionary containing tokens and their associated vector representations </li></ul> |

### 3. /docs

| module | description |
| - | - |
| `docs` | <ul><li><b>web_demo.png:</b> Screenshot of the web demo of the model</li></ul> |

### 4. /ml_mdm

| module | description |
| - | - |
| `ml_mdm.models` | The core model implementations |
Expand All @@ -107,7 +132,11 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port
| `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` |
| `tests/` | Unit tests and sample training files |

### 5. /tests

| module | description |
| - | - |
| `tests.test_files` | Sample files for testing |

# Concepts

Expand All @@ -125,6 +154,22 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of:
> In essence, `simple_parsing` will convert all passed cli arguments and yaml files into clean configuration classes like `ml_mdm.reader.ReaderConfig`, `ml_mdm.diffusion.DiffusionConfig`.


`ml_mdm.config` stores a global mapping of names to classes in `MODEL_REGISTRY`, `MODEL_CONFIG_REGISTRY`, `PIPELINE_REGISTRY`, and `PIPELINE_CONFIG_REGISTRY`.

`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example:

> *_CONFIG_REGISTRY[architecture name]["model"] = model name

> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class

`MODEL_CONFIG_REGISTRY` and `PIPELINE_CONFIG_REGISTRY` store information as shown in the following example:
> *_CONFIG_REGISTRY[architecture name]["model"] = model name

> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class


The architecture name and model name are passed into `ml_mdm.config` through the function parameter `*names`, where `*names` receives "architecture name" followed by "model name".



# Tutorials
Expand Down Expand Up @@ -263,11 +308,11 @@ reader_config:
Then you can use our dataset download helper:
```console
python -m ml_mdm.clis.download_tar_from_index \
--dataset-config-file configs/datasets/cc12m.yaml \
--dataset_config_file configs/datasets/cc12m.yaml \
--subset train --download_tar

python -m ml_mdm.clis.download_tar_from_index \
--dataset-config-file configs/datasets/cc12m.yaml \
--dataset_config_file configs/datasets/cc12m.yaml \
--subset eval --download_tar
```

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
nodes this data will be distributed over.
"""

import argparse
import simple_parsing
import csv
import logging
import os
Expand All @@ -33,7 +33,30 @@
import mlx.data

from ml_mdm import helpers, s3_helpers

from dataclasses import dataclass, field
from typing import Optional

# NOTE(review): the argparse version of this CLI restricted --subset to
# "train"/"eval" via `choices`; confirm that simple_parsing honors the
# "choices" metadata key used on `subset` (its `choice()` helper may be needed).
@dataclass
class DownloadConfig:
    """Options for downloading the tar files referred to in a dataset index file."""

    dataset_config_file: str = field(
        default="",
        metadata={"help": "yaml file with dataset names"},
    )
    worker_id: int = field(
        default=0,
        metadata={"help": "current worker in [0, num-downloaders -1]"},
    )
    num_downloaders: int = field(
        default=1,
        metadata={"help": "number of parallel downloaders"},
    )
    no_bandwidth: bool = field(default=False)
    download_tar: bool = field(
        default=False,
        metadata={"help": "whether or not to download tar files also"},
    )
    pretrained_text_embeddings: Optional[str] = field(default=None)
    endpoint_url: str = field(
        default="",
        metadata={"help": "end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise"},
    )
    subset: str = field(
        default="train",
        metadata={
            "choices": ["train", "eval"],
            "help": "subset to download [train|eval]",
        },
    )

def get_parser():
    """Build the command-line parser for this script.

    The ``DownloadConfig`` fields are registered as arguments, and the parsed
    config is stored on the resulting namespace under the ``options`` attribute.
    """
    cli = simple_parsing.ArgumentParser(
        description="Download tar files referred to in index file from mlx"
    )
    cli.add_arguments(DownloadConfig, dest="options")
    return cli

def read_tsv(filename):
# Open the TSV file for reading
Expand Down Expand Up @@ -331,44 +354,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Download tar files referred to in index file from mlx"
)
parser.add_argument(
"--dataset-config-file",
type=str,
default="",
help="yaml file with dataset names",
)
parser.add_argument(
"--worker-id",
type=int,
default=0,
help="current worker in [0, num-downloaders -1]",
)
parser.add_argument(
"--num-downloaders", type=int, default=1, help="number of parallel downloaders"
)
parser.add_argument("--no_bandwidth", action="store_true")
parser.add_argument(
"--download_tar",
action="store_true",
help="whether or not to download tar files also",
)
parser.add_argument("--pretrained-text-embeddings", type=str, default=None)
parser.add_argument(
"--endpoint-url",
type=str,
default="",
help="end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise",
)
parser.add_argument(
"--subset",
type=str,
default="train",
choices=["train", "eval"],
help="subset to download [train|eval]",
)
parser = get_parser()
args = parser.parse_args()
logging.basicConfig(
level="INFO",
Expand All @@ -377,5 +363,5 @@ def main(args):
),
datefmt="%H:%M:%S",
)
helpers.print_args(args)
main(args)
helpers.print_args(args.options)
main(args.options)
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
import argparse
import logging
import os
import shlex
import time
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import gradio as gr
import simple_parsing
Expand All @@ -16,6 +17,8 @@
import torch
from torchvision.utils import make_grid

import ml_mdm.language_models.factory
import ml_mdm.language_models.tokenizer
from ml_mdm import helpers, reader
from ml_mdm.config import get_arguments, get_model, get_pipeline
from ml_mdm.language_models import factory
Expand All @@ -36,22 +39,28 @@
)


def dividable(n: int) -> Tuple[int, int]:
    """Split ``n`` into the factor pair closest to its square root.

    Searches downward from ``int(sqrt(n))`` for the largest divisor ``i`` of
    ``n`` and returns ``(i, n // i)``, so the first factor is never larger
    than the second.

    Args:
        n: positive integer to factor.

    Returns:
        The pair ``(i, n // i)``.

    Raises:
        ValueError: if ``n`` is smaller than 1. Without this check ``n == 0``
            leaves the loop variable unbound because the search range is empty.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # For n >= 1 the range always reaches 1, which divides n, so `i` is bound.
    for i in range(int(np.sqrt(n)), 0, -1):
        if n % i == 0:
            break
    return i, n // i


def generate_lm_outputs(
    device: torch.device,
    sample: dict,
    tokenizer: ml_mdm.language_models.tokenizer.Tokenizer,
    language_model: ml_mdm.language_models.factory.LanguageModel,
    args: argparse.Namespace,
) -> dict:
    """Run the language model on ``sample`` and attach its outputs in place.

    The model is called as ``language_model(sample, tokenizer)`` with gradient
    tracking disabled. Its two results are stored under the ``"lm_outputs"``
    and ``"lm_mask"`` keys of ``sample``, and that same ``sample`` object is
    returned. ``device`` and ``args`` are accepted but not used here.
    """
    with torch.no_grad():
        outputs_and_mask = language_model(sample, tokenizer)
        sample["lm_outputs"], sample["lm_mask"] = outputs_and_mask
    return sample


def setup_models(args, device):
def setup_models(args: argparse.Namespace, device: torch.device):
input_channels = 3

# load the language model
Expand All @@ -68,7 +77,10 @@ def setup_models(args, device):
return tokenizer, language_model, diffusion_model


def plot_logsnr(logsnrs, total_steps):

def plot_logsnr(logsnrs: list, total_steps: int) -> np.ndarray:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

x = 1 - np.arange(len(logsnrs)) / (total_steps - 1)
Expand Down Expand Up @@ -103,39 +115,40 @@ class GLOBAL_DATA:
global_config = GLOBAL_DATA()


def stop_run() -> tuple:
    """Return the pair of UI updates applied when a run is stopped.

    Returns:
        A 2-tuple of ``gr.update(...)`` results: the first sets a component's
        value to "Run" with the "primary" variant and makes it visible, the
        second hides a component. Which components receive them is decided by
        the caller that wires this function up.
    """
    return (
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    )


def get_model_type(config_file: str) -> str:
    """Read the model type named in a YAML config file.

    Returns the value of the ``"model"`` key if present, otherwise the value
    of ``"vision_model"``, otherwise ``"unet"``.
    """
    with open(config_file, "r") as stream:
        parsed = yaml.safe_load(stream)
    fallback = parsed.get("vision_model", "unet")
    return parsed.get("model", fallback)


def generate(
config_file="cc12m_64x64.yaml",
ckpt_name="vis_model_64x64.pth",
prompt="a chair",
input_template="",
negative_prompt="",
negative_template="",
batch_size=20,
guidance_scale=7.5,
threshold_function="clip",
num_inference_steps=250,
eta=0,
save_diffusion_path=False,
show_diffusion_path=False,
show_xt=False,
reader_config="",
seed=10,
comment="",
override_args="",
output_inner=False,
config_file: str = "cc12m_64x64.yaml",
ckpt_name: str = "vis_model_64x64.pth",
prompt: str = "a chair",
input_template: str = "",
negative_prompt: str = "",
negative_template: str = "",
batch_size: int = 20,
guidance_scale: float = 7.5,
threshold_function: str = "clip",
num_inference_steps: int = 250,
eta: int = 0,
save_diffusion_path: bool = False,
show_diffusion_path: bool = False,
show_xt: bool = False,
reader_config: str = "",
seed: int = 10,
comment: str = "",
override_args: str = "",
output_inner: bool = False,
):
np.random.seed(seed)
torch.random.manual_seed(seed)
Expand Down Expand Up @@ -292,7 +305,7 @@ def generate(
)


def main(args):
def main(args: argparse.Namespace):
# get the language model outputs
example_texts = open("data/prompts_demo.tsv").readlines()

Expand Down
Loading