From 642faf311fdb6dae01a8f3ae5b0ec316d7676d29 Mon Sep 17 00:00:00 2001 From: Isabella Deanhardt <139717054+bdeanhardt@users.noreply.github.com> Date: Sun, 10 Nov 2024 19:04:24 -0500 Subject: [PATCH 01/64] test_configs.py docstrings --- tests/test_configs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_configs.py b/tests/test_configs.py index 2d3fe83..299d0b3 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -7,16 +7,20 @@ def test_unet_in_registry(): + """Check that 'nested_unet' and 'unet' models are correctly registered in the Model Registry.""" assert config.get_model("nested_unet") is not None assert config.get_model("unet") is not None def test_unet_in_pipeline(): + """Check that 'nested_unet' and 'unet' models have corresponding pipelines defined.""" assert config.get_pipeline("unet") is not None assert config.get_pipeline("nested_unet") is not None def test_config_cc12m_64x64(): + """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo).""" + f = "configs/models/cc12m_64x64.yaml" f = "configs/models/cc12m_64x64.yaml" args = config.get_arguments( mode="trainer", @@ -44,6 +48,7 @@ def test_config_cc12m_64x64(): def test_config_cc12m_256x256(): + """Check that the 'cc12m_256x256' configuration loads with 'nested_unet' as model in all modes (trainer, sampler, demo).""" f = "configs/models/cc12m_256x256.yaml" args = config.get_arguments( args=["--model=nested_unet"], @@ -75,6 +80,7 @@ def test_config_cc12m_256x256(): def test_config_cc12m_1024x1024(): + """Check that the 'cc12m_1024x1024' configuration loads with 'nested2_unet' model in all modes (trainer, sampler, demo).""" f = "configs/models/cc12m_1024x1024.yaml" args = config.get_arguments( args=["--model=nested2_unet"], From 1a82a4bf32876a465c38f378d20dc68ab6a553e9 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Tue, 12 Nov 2024 09:25:09 -0500 Subject: [PATCH 02/64] added docstrings to remaining test cases --- tests/test_generate_batch.py | 10 ++++++++++ tests/test_generate_sample.py | 4 ++++ tests/test_imports.py | 4 ++++ tests/test_models.py | 3 +++ tests/test_reader.py | 3 +++ tests/test_train.py | 1 + 6 files changed, 25 insertions(+) diff --git a/tests/test_generate_batch.py b/tests/test_generate_batch.py index 2334ac2..4edb759 100644 --- a/tests/test_generate_batch.py +++ b/tests/test_generate_batch.py @@ -8,6 +8,10 @@ def test_small_batch(): + """ + Test small batch generation with T5 model. + Check that basic data generation pipeline works with minimal settings. + """ args = Namespace( batch_size=10, test_file_list="tests/test_files/sample_training_0.tsv", @@ -33,6 +37,12 @@ def test_small_batch(): def test_generate_batch(): + """ + Test batch generation with default config settings. + + Note: This test currently only sets up the configuration but doesn't execute + the generation (ends with pass statement). + """ args = config.get_arguments(mode="sampler") args.batch_size = 10 args.test_file_list = "tests/test_files/sample_training_0.tsv" diff --git a/tests/test_generate_sample.py b/tests/test_generate_sample.py index d36c889..02c70e5 100644 --- a/tests/test_generate_sample.py +++ b/tests/test_generate_sample.py @@ -4,6 +4,10 @@ def test_load_flick_config(): + """ + Test loading of cc12m_64x64.yaml config file. + Checks image dimensions are correctly loaded in reader config. + """ args = config.get_arguments( "", mode="demo", diff --git a/tests/test_imports.py b/tests/test_imports.py index 19138c2..04474a8 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,6 +1,7 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. def test_top_level_imports_work(): + """Checks that all top-level ml_mdm module imports are accessible.""" from ml_mdm import ( config, diffusion, @@ -16,6 +17,7 @@ def test_top_level_imports_work(): def test_cli_imports_work(): + """Checks that all CLI module imports are accessible.""" from ml_mdm.clis import ( download_tar_from_index, generate_batch, @@ -25,8 +27,10 @@ def test_cli_imports_work(): def test_model_imports_work(): + """Checks that all model module imports are accessible.""" from ml_mdm.models import model_ema, nested_unet, unet def test_lm_imports_work(): + """Checks that all language model module imports are accessible.""" from ml_mdm.language_models import factory, tokenizer diff --git a/tests/test_models.py b/tests/test_models.py index b945fb5..2a96dad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,7 @@ def test_initialize_unet(): + """Test UNet model and EMA initialization with default configs.""" unet_config = models.unet.UNetConfig() diffusion_config = diffusion.DiffusionConfig( use_vdm_loss_weights=True, model_output_scale=0.1 @@ -30,6 +31,7 @@ def test_initialize_unet(): def test_all_registered_models(): + """Test instantiation of all models in the registry with default configs.""" for config_name, additional_info in config.MODEL_CONFIG_REGISTRY.items(): model_name = additional_info["model"] config_cls = additional_info["config"] @@ -44,6 +46,7 @@ def test_all_registered_models(): @pytest.mark.gpu def test_initialize_pretrained(): + """Test loading pretrained 64x64 model on GPU if available.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args = config.get_arguments( diff --git a/tests/test_reader.py b/tests/test_reader.py index bcc3ef7..dd6d983 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -10,6 +10,7 @@ def test_get_dataset(): + """Test dataset loading and verify sample format and dimensions.""" tokenizer = factory.create_tokenizer("data/t5.vocab") dataset = reader.get_dataset( tokenizer=tokenizer, @@ -31,6 +32,7 @@ def test_get_dataset(): def test_get_dataset_partition(): + """Test dataset partitioning and iteration.""" tokenizer = factory.create_tokenizer("data/t5.vocab") train_loader = reader.get_dataset_partition( partition_num=0, @@ -46,6 +48,7 @@ def test_get_dataset_partition(): def test_process_text(): + """Test text tokenization with default reader config.""" line = "A bicycle on top of a boat." tokenizer = factory.create_tokenizer("data/t5.vocab") tokens = reader.process_text( diff --git a/tests/test_train.py b/tests/test_train.py index c1f2f52..073fcb4 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -23,6 +23,7 @@ reason="more effective to test this with torchrun, just here for documentation" ) def test_small(): + """Test minimal training run with single process setup.""" os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["LOCAL_RANK"] = "0" From 5cad65072b2680f2cc581cd75e6deebfa1f4e6d5 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Tue, 12 Nov 2024 09:34:51 -0500 Subject: [PATCH 03/64] removed duplicate line --- tests/test_configs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_configs.py b/tests/test_configs.py index 299d0b3..ed622ae 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -21,7 +21,6 @@ def test_unet_in_pipeline(): def test_config_cc12m_64x64(): """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo).""" f = "configs/models/cc12m_64x64.yaml" - f = "configs/models/cc12m_64x64.yaml" args = config.get_arguments( mode="trainer", additional_config_paths=[f], From e024eb0917990043303a6c35fa2d565c2cee1c6f Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:37:00 -0500 Subject: [PATCH 04/64] Update distributed.py --- ml_mdm/distributed.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ml_mdm/distributed.py b/ml_mdm/distributed.py index ebd2aa5..99997fc 100644 --- a/ml_mdm/distributed.py +++ b/ml_mdm/distributed.py @@ -32,7 +32,8 @@ def init_distributed_singlenode(timeout=0): rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) - if not "MASTER_ADDR" in os.environ: + + if not "MASTER_ADDR" in os.environ or world_size == 1: return local_rank, rank, world_size if timeout == 0: From 367131c18419edbf524dd01455c2ded46cf09a21 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:39:47 -0500 Subject: [PATCH 05/64] Update trainer.py --- ml_mdm/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index 1c2b93d..4135a82 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -50,7 +50,7 @@ def train_batch( grad_scaler.step(optimizer) grad_scaler.update() if ema_model is not None: - ema_model.update(model.model.module.vision_model) + ema_model.update(getattr(model.model, "module", model.model).vision_model) else: losses, times, x_t, means, targets, weights = model.get_loss(sample) if weights is None: @@ -74,7 +74,7 @@ def train_batch( ).item() optimizer.step() if ema_model is not None: - ema_model.update(model.model.module.vision_model) + ema_model.update(getattr(model.model, "module", model.model).vision_model) if logger is not None and not accumulate_gradient: logger.add_scalar("train/Loss", loss_val) From 571a6b951e17ff1375e46a5d0115d31a55bc6054 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:43:57 -0500 Subject: [PATCH 06/64] Update train_parallel.py --- ml_mdm/clis/train_parallel.py | 44 ++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/ml_mdm/clis/train_parallel.py b/ml_mdm/clis/train_parallel.py index 6a4f0af..808a105 100644 --- a/ml_mdm/clis/train_parallel.py +++ b/ml_mdm/clis/train_parallel.py @@ -7,6 +7,7 @@ import logging import os import time +from contextlib import nullcontext import numpy as np import torch @@ -53,7 +54,11 @@ def main(args): local_rank, global_rank, world_size = init_distributed_singlenode(timeout=36000) input_channels = 3 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" tokenizer, language_model = factory.create_lm(args, device=device) language_model_dim = language_model.embed_dim @@ -71,7 +76,8 @@ def main(args): os.makedirs(args.output_dir) if "MASTER_ADDR" in os.environ: - dist.barrier() + if dist.is_available() and dist.is_initialized(): + dist.barrier() other_items = None if ( @@ -109,7 +115,8 @@ def main(args): else: grad_scaler = None - dist.barrier() + if dist.is_available() and dist.is_initialized(): + dist.barrier() max_lr = args.lr # Should eps be 1e-4 like for LMs in fp16 ? if args.use_adamw: @@ -137,13 +144,18 @@ def main(args): CLIP = 3 # intialize the model - model = nn.parallel.DistributedDataParallel( - diffusion_model.model, - device_ids=[local_rank], - ) + if int(os.environ.get("WORLD_SIZE", "1")) > 1: + model = nn.parallel.DistributedDataParallel( + diffusion_model.model, + device_ids=[local_rank], + ) + else: + model = diffusion_model.model diffusion_model.model = model - dist.barrier() - ema_model = ModelEma(diffusion_model.model.module.vision_model) + if dist.is_available() and dist.is_initialized(): + dist.barrier() + # Check if the model is wrapped in DistributedDataParallel + ema_model = ModelEma(getattr(diffusion_model.model, "module", diffusion_model.model).vision_model) # get the dataloader if args.multinode: @@ -187,7 +199,8 @@ def main(args): sample["images"] = images if accumulate_gradient: - with diffusion_model.model.no_sync(): + no_sync_context = diffusion_model.model.no_sync() if hasattr(diffusion_model.model, "no_sync") else nullcontext() + with no_sync_context: loss_val, losses, times, x_t, means, targets = trainer.train_batch( diffusion_model, sample, @@ -220,7 +233,6 @@ def main(args): num_time_counts += 1 if np.isnan(loss_val): continue - # accumulate loss if batch_num != 1: # E[(x-E[x])^2] = E[x^2] - E[x]^2 @@ -239,6 +251,8 @@ def main(args): exp_avg_loss = loss_val exp_avg_loss_var = loss_val**2 total_loss_val += loss_val + print(f"Allocated memory: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB", end='') + print(f"Val loss: {loss_val}") if (not accumulate_gradient) and (global_rank == 0): metrics = { @@ -274,12 +288,15 @@ def main(args): "args": args, } # save full config. ema_model.save(vision_model_file, other_items=other_items) - diffusion_model.model.module.vision_model.save( + getattr(diffusion_model.model, "module", diffusion_model.model).vision_model.save( vision_model_noema_file, other_items=other_items ) + torch.cuda.empty_cache() + torch.mps.empty_cache() if (batch_num % args.save_freq == 0) or (batch_num == args.num_training_steps): - dist.barrier() + if dist.is_available() and dist.is_initialized(): + dist.barrier() if batch_num == args.num_training_steps: break @@ -302,5 +319,6 @@ def main(args): np.random.seed(seed) torch.random.manual_seed(seed) torch.cuda.empty_cache() + torch.mps.empty_cache() helpers.print_args(args) main(args) From 3124abcb208f54cfa2f1e9a73393a62309cbf490 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 12 Nov 2024 21:59:24 -0500 Subject: [PATCH 07/64] Update train_parallel.py --- ml_mdm/clis/train_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml_mdm/clis/train_parallel.py b/ml_mdm/clis/train_parallel.py index 808a105..8e0a174 100644 --- a/ml_mdm/clis/train_parallel.py +++ b/ml_mdm/clis/train_parallel.py @@ -251,8 +251,8 @@ def main(args): exp_avg_loss = loss_val exp_avg_loss_var = loss_val**2 total_loss_val += loss_val - print(f"Allocated memory: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB", end='') - print(f"Val loss: {loss_val}") + # print(f"Allocated memory: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB", end='') + # print(f"Val loss: {loss_val}") if (not accumulate_gradient) and (global_rank == 0): metrics = { From 274ee5f7ae98acfa25244189ac535703fb452b82 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:09:28 -0500 Subject: [PATCH 08/64] Update train_parallel.py --- ml_mdm/clis/train_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml_mdm/clis/train_parallel.py b/ml_mdm/clis/train_parallel.py index 8e0a174..07bd4c7 100644 --- a/ml_mdm/clis/train_parallel.py +++ b/ml_mdm/clis/train_parallel.py @@ -54,11 +54,11 @@ def main(args): local_rank, global_rank, world_size = init_distributed_singlenode(timeout=36000) input_channels = 3 - device = "cpu" + device = torch.device("cpu") if torch.cuda.is_available(): - device = "cuda" + device = torch.device("cuda") elif torch.backends.mps.is_available(): - device = "mps" + device = torch.device("mps") tokenizer, language_model = factory.create_lm(args, device=device) language_model_dim = language_model.embed_dim From ac560df24772734c0bcaa61ba6713445280fb44c Mon Sep 17 00:00:00 2001 From: Emma Garrett Date: Wed, 13 Nov 2024 23:36:18 -0500 Subject: [PATCH 09/64] Type hints for train_batch --- ml_mdm/trainer.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index 1c2b93d..f018fd8 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -6,17 +6,17 @@ def train_batch( - model, - sample, - optimizer, - scheduler, - logger, - args, - grad_scaler=None, - accumulate_gradient=False, - num_grad_accumulations=1, - ema_model=None, - loss_factor=1, + model: torch.nn.Module, + sample: dict, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LRScheduler, + logger: Optional[torch.utils.tensorboard.SummaryWriter], + args: Namespace, + grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, + accumulate_gradient: bool = False, + num_grad_accumulations: int =1, + ema_model: Optional[nn.Module] = None, + loss_factor: float = 1.0, ): model.train() lr = scheduler.get_last_lr()[0] From 7f083701b1f00e6baa5ecf960d8047cd54f7f5c4 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:24:45 -0500 Subject: [PATCH 10/64] Update torch requirement pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 43aaeeb..45a008a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "transformers", "sentencepiece", "boto3", - "torch==2.2.2", + "torch==2.5.1", "pytest", "pytest-cov", "pre-commit" From 1e1b28608e73792956d06e41852288ef60f6afc6 Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:37:56 -0500 Subject: [PATCH 11/64] Type hinted one function --- ml_mdm/samplers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index c72caac..c5ececa 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -112,9 +112,7 @@ class SamplerConfig: ) schedule_shifted_power: float = field( default=1, - metadata={ - "help": "noise shifted ratio, by default using 1." - }, + metadata={"help": "noise shifted ratio, by default using 1."}, ) @@ -146,12 +144,12 @@ def schedule_ddpm_defults( return gammas -def squaredcos_cap_v2(timesteps: int): +def squaredcos_cap_v2(timesteps: int) -> np.ndarray: """ https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L147 """ - def alpha_bar(time_step): + def alpha_bar(time_step: float) -> float: return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 betas = [0] @@ -249,7 +247,7 @@ def get_image_rescaled(self, images, scale_factor=None): def get_schedule_shifted(self, gammas, scale_factor=None): if (scale_factor is not None) and (scale_factor > 1): # rescale noise schecule p = self._config.schedule_shifted_power - scale_factor = scale_factor ** p + scale_factor = scale_factor**p snr = gammas / (1 - gammas) scaled_snr = snr / scale_factor gammas = 1 / (1 + 1 / scaled_snr) From 23e008ef2bbf66acf23ab652792d501b4b620cba Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 14 Nov 2024 14:52:56 -0500 Subject: [PATCH 12/64] Start of tokenizer tests. --- tests/test_tokenizer.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_tokenizer.py diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000..a4c44a1 --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,32 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. +from ml_mdm.language_models import tokenizer + +# Vocab files are in /ml-mdm/data/ +# Difference between .vocab file to a token file? + +# Tokenizer in factory.py create_tokenizer(vocab_file): you pass a vocab file, +# function returns something (dictionary? a tokenizer?) +# Tokenizer.py: each function is a different mode of tokenizing a token_file (?) +# (is this the same as a vocab file or is it a dictionary of tokens?) and they +# return a prefix tree structure. + +# CASE 1: We compare factory.py's create_tokenizer() to tokenizer.py's output. +# Ex.: assert create_tokenizer() == read_dictionary_bert() + +# CASE 2: We run the output of factory, and then pass its output as the token +# file to tokenizer.py's functions. Then, we need something else to assert +# against. +# Ex.: read_dictionary_bert(create_tokenizer(.vocab)) == ? + +def test_tokenizer_bert(): + pass + +def test_tokenizer_t5(): + pass + +# Is this a general vocab file that's not either bert or t5? +def test_tokenizer(): + pass + + From 7147d66785cf557ff61e7204d5f3cdbc304cb73e Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 14 Nov 2024 15:59:47 -0500 Subject: [PATCH 13/64] update test_tokenizer and fixed import in factory.py --- tests/test_tokenizer.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index a4c44a1..02cd51b 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -1,32 +1,34 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. -from ml_mdm.language_models import tokenizer + +import logging + +# Tokenizer class from tokenizer.py +from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py # Vocab files are in /ml-mdm/data/ -# Difference between .vocab file to a token file? -# Tokenizer in factory.py create_tokenizer(vocab_file): you pass a vocab file, -# function returns something (dictionary? a tokenizer?) # Tokenizer.py: each function is a different mode of tokenizing a token_file (?) # (is this the same as a vocab file or is it a dictionary of tokens?) and they # return a prefix tree structure. -# CASE 1: We compare factory.py's create_tokenizer() to tokenizer.py's output. -# Ex.: assert create_tokenizer() == read_dictionary_bert() - -# CASE 2: We run the output of factory, and then pass its output as the token -# file to tokenizer.py's functions. Then, we need something else to assert -# against. -# Ex.: read_dictionary_bert(create_tokenizer(.vocab)) == ? def test_tokenizer_bert(): - pass + f = "../data/bert.vocab" + assert Tokenizer(f, mode="bert") + #Q: should we assert the contents of tokenizer? def test_tokenizer_t5(): - pass + f = "../data/t5.vocab" + assert Tokenizer(f, mode="tf") + #Q: should we assert the contents of tokenizer? -# Is this a general vocab file that's not either bert or t5? +# any vocab file that's not either bert or t5 def test_tokenizer(): - pass - + f = "../data/imagenet.vocab" + assert Tokenizer(f) + #Q: should we assert the contents of tokenizer? +test_tokenizer_bert() +test_tokenizer_t5() +test_tokenizer() \ No newline at end of file From 76f5502f4d7be27efb2da13a2d8f7c51ac79170b Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 14 Nov 2024 16:10:50 -0500 Subject: [PATCH 14/64] Finished test_tokenizer.py - need to check if we should assert the contents of tokenizer. --- tests/test_configs.py | 2 +- tests/test_reader.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_configs.py b/tests/test_configs.py index 2d3fe83..84ecfb4 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -102,4 +102,4 @@ def test_config_cc12m_1024x1024(): mode="demo", additional_config_paths=[f], ) - assert args + assert args \ No newline at end of file diff --git a/tests/test_reader.py b/tests/test_reader.py index bcc3ef7..baea019 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -53,3 +53,6 @@ def test_process_text(): ) assert len(tokens) > 0 assert len(tokens[0]) > 0 + + +test_get_dataset() \ No newline at end of file From 9802f2a21bf67b658be1fa134d7c61592a00f950 Mon Sep 17 00:00:00 2001 From: Pedro Borges Date: Fri, 15 Nov 2024 19:57:05 +0100 Subject: [PATCH 15/64] Update test_tokenizer.py --- tests/test_tokenizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 02cd51b..f2b0bb7 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -4,6 +4,7 @@ import logging # Tokenizer class from tokenizer.py +from pathlib import Path from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py # Vocab files are in /ml-mdm/data/ @@ -14,7 +15,8 @@ def test_tokenizer_bert(): - f = "../data/bert.vocab" + # f = "../data/bert.vocab" + f = Path(__file__).parent/"data/bert.vocab" # To solve from relative to absolute import assert Tokenizer(f, mode="bert") #Q: should we assert the contents of tokenizer? @@ -31,4 +33,4 @@ def test_tokenizer(): test_tokenizer_bert() test_tokenizer_t5() -test_tokenizer() \ No newline at end of file +test_tokenizer() From 195fae4c5eb9f293cc6d7066b7c0271d92c5d36b Mon Sep 17 00:00:00 2001 From: Pedro Borges Date: Sat, 16 Nov 2024 03:58:24 +0100 Subject: [PATCH 16/64] Update test_tokenizer.py --- tests/test_tokenizer.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index f2b0bb7..49ea485 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -3,33 +3,20 @@ import logging -# Tokenizer class from tokenizer.py from pathlib import Path from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py -# Vocab files are in /ml-mdm/data/ - -# Tokenizer.py: each function is a different mode of tokenizing a token_file (?) -# (is this the same as a vocab file or is it a dictionary of tokens?) and they -# return a prefix tree structure. - - def test_tokenizer_bert(): - # f = "../data/bert.vocab" f = Path(__file__).parent/"data/bert.vocab" # To solve from relative to absolute import assert Tokenizer(f, mode="bert") - #Q: should we assert the contents of tokenizer? def test_tokenizer_t5(): - f = "../data/t5.vocab" + f = Path(__file__).parent/"data/t5.vocab" assert Tokenizer(f, mode="tf") - #Q: should we assert the contents of tokenizer? - -# any vocab file that's not either bert or t5 + def test_tokenizer(): - f = "../data/imagenet.vocab" + f = Path(__file__).parent/"data/imagenet.vocab" assert Tokenizer(f) - #Q: should we assert the contents of tokenizer? test_tokenizer_bert() test_tokenizer_t5() From e62c9de56ae3606317d21712c954852cdbc1deba Mon Sep 17 00:00:00 2001 From: Aaliyah Bullen Date: Sun, 17 Nov 2024 15:10:37 -0500 Subject: [PATCH 17/64] changed token_file to vocab_file --- ml_mdm/language_models/tokenizer.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ml_mdm/language_models/tokenizer.py b/ml_mdm/language_models/tokenizer.py index 0fb08dd..b3af8a8 100644 --- a/ml_mdm/language_models/tokenizer.py +++ b/ml_mdm/language_models/tokenizer.py @@ -5,11 +5,11 @@ from mlx.data.core import CharTrie -def read_dictionary_bert(token_file): +def read_dictionary_bert(vocab_file): trie_key_scores = [] trie = CharTrie() - f = open(token_file, "rb") + f = open(vocab_file, "rb") sep = "\u2581".encode() max_score = 0 @@ -42,11 +42,11 @@ def read_dictionary_bert(token_file): return trie, trie_key_scores, eos, bos, pad -def read_dictionary_t5(token_file): +def read_dictionary_t5(vocab_file): trie_key_scores = [] trie = CharTrie() - f = open(token_file, "rb") + f = open(vocab_file, "rb") sep = "\u2581".encode() max_score = 0 @@ -75,7 +75,7 @@ def read_dictionary_t5(token_file): return trie, trie_key_scores, eos, bos, pad -def read_dictionary(token_file): +def read_dictionary(vocab_file): trie_key_scores = [] trie = CharTrie() @@ -85,7 +85,7 @@ def read_dictionary(token_file): trie.insert(token) trie_key_scores.append(0.0) - f = open(token_file, "rb") + f = open(vocab_file, "rb") sep = "\u2581".encode() max_score = 0 @@ -130,7 +130,7 @@ def read_dictionary(token_file): class Tokenizer: - def __init__(self, token_file, mode=None): + def __init__(self, vocab_file, mode=None): if mode == "t5": ( self._trie, @@ -138,7 +138,7 @@ def __init__(self, token_file, mode=None): self.eos, self.bos, self.pad, - ) = read_dictionary_t5(token_file) + ) = read_dictionary_t5(vocab_file) elif mode == "bert": ( self._trie, @@ -146,7 +146,7 @@ def __init__(self, token_file, mode=None): self.eos, self.bos, self.pad, - ) = read_dictionary_bert(token_file) + ) = read_dictionary_bert(vocab_file) else: ( self._trie, @@ -154,7 +154,7 @@ def __init__(self, token_file, mode=None): self.eos, self.bos, self.pad, - ) = read_dictionary(token_file) + ) = read_dictionary(vocab_file) self.vocab_size = self._trie.num_keys() @property From 3b1eca28f936dd4b8f4483645b2745bd2a0e41de Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:08:45 -0500 Subject: [PATCH 18/64] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 43aaeeb..bf00bbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "transformers", "sentencepiece", "boto3", - "torch==2.2.2", + "torch>=2.5.1", "pytest", "pytest-cov", "pre-commit" From fde2865a0480adf8f0a2cee88cfdff0d5a5e4818 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Tue, 19 Nov 2024 12:12:39 -0500 Subject: [PATCH 19/64] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45a008a..849f31d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,11 @@ dependencies = [ "tensorboard==2.16.2", "torchinfo", "torchmetrics[image]", - "torchvision", + "torchvision>=0.20.1", "transformers", "sentencepiece", "boto3", - "torch==2.5.1", + "torch>=2.5.1", "pytest", "pytest-cov", "pre-commit" From 7619a12c734e2877e1700d554a2196b48809e6aa Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Wed, 20 Nov 2024 18:13:57 -0500 Subject: [PATCH 20/64] Update generate_sample.py (matplotlib use Agg backend) --- ml_mdm/clis/generate_sample.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml_mdm/clis/generate_sample.py b/ml_mdm/clis/generate_sample.py index 975ab31..9708b5c 100644 --- a/ml_mdm/clis/generate_sample.py +++ b/ml_mdm/clis/generate_sample.py @@ -69,6 +69,8 @@ def setup_models(args, device): def plot_logsnr(logsnrs, total_steps): + import matplotlib + matplotlib.use('Agg') import matplotlib.pyplot as plt x = 1 - np.arange(len(logsnrs)) / (total_steps - 1) From bf2fad3f1383a832e52f985ef15cd7abf7c5a85d Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 21 Nov 2024 14:59:35 -0500 Subject: [PATCH 21/64] small edit. --- README.md | 1 + ml_mdm/language_models/factory.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1841881..9d6f2ce 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port ## Codebase +1. ml_mdm | module | description | | - | - | | `ml_mdm.models` | The core model implementations | diff --git a/ml_mdm/language_models/factory.py b/ml_mdm/language_models/factory.py index 180d406..df8f838 100644 --- a/ml_mdm/language_models/factory.py +++ b/ml_mdm/language_models/factory.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F -from .tokenizer import Tokenizer +from ml_mdm.language_models.tokenizer import Tokenizer class T5Encoder(T5ForConditionalGeneration): From 81fca8ccab2aa3bfc104f52622ce5beed8f65292 Mon Sep 17 00:00:00 2001 From: Pedro Borges Date: Thu, 21 Nov 2024 15:01:53 -0500 Subject: [PATCH 22/64] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9d6f2ce..ebf87fe 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,8 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port ## Codebase -1. ml_mdm +1. ml_mdm + | module | description | | - | - | | `ml_mdm.models` | The core model implementations | From 875837a34039fa6a5350e8a6caf271de1f991743 Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 21 Nov 2024 15:17:15 -0500 Subject: [PATCH 23/64] README changes --- README.md | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9d6f2ce..fc9f475 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,28 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port ## Codebase -1. ml_mdm + +1. configs + +| module | description | +| - | - | +| `configs.dataset_creation` | Configuration file for dataset splitting into train-eval-val pipeline | +| `configs.datasets` | Datasets for training and evaluation phases of the model | +| `configs.models` | Configuration files for different resolution models | + + +2. data + +| module | description | +| - | - | +| | | +| `data` | - bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations +- jdhjwjsdjiejej +- idjdjkdkkd | + +3. docs +4. ml_mdm + | module | description | | - | - | | `ml_mdm.models` | The core model implementations | @@ -108,6 +129,8 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` | | `tests/` | Unit tests and sample training files | +5. tests + # Concepts From 13a4168de41174fcf594b0089083446124ed2810 Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 21 Nov 2024 15:21:36 -0500 Subject: [PATCH 24/64] List test. --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7e435ae..3dac191 100644 --- a/README.md +++ b/README.md @@ -113,10 +113,8 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | module | description | | - | - | -| | | -| `data` | - bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations -- jdhjwjsdjiejej -- idjdjkdkkd | +| `data` |
  • bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations
  • bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations
| + 3. docs From c198434a82c5a1c0271276f28c0e7146fb5fab3f Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 21 Nov 2024 15:50:56 -0500 Subject: [PATCH 25/64] Updated README with file descriptions. --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3dac191..21d267e 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,14 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | module | description | | - | - | -| `data` |
  • bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations
  • bert.vocab: BERT-sourced dictionary of tokens containing text data and their associated vector representations
| - +| `data` |
  • bert.vocab: BERT-trained dictionary containing tokens and their associated vector representations
  • c4_wpm.vocab: C4-trained dictionary containing tokens and their associated vector representations
  • cifar10.vocab: CIFAR10-trained dictionary containing tokens and their associated vector representations
  • imagenet.vocab:
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-256x256.tsv:
  • prompts_cifar10-32x32.tsv
  • prompts_cifar10-64x64.tsv
  • prompts_demo.tsv
  • prompts_imagenet-64px.tsv
  • prompts_WebImage-ALIGN-64px.tsv
  • t5.vocab
  • tokenizer_spm_32000_50m.vocab
| 3. docs +| module | description | +| - | - | +| `docs` |
  • web_demo.png: Screenshot of the web demo of the model
| + 4. ml_mdm | module | description | @@ -130,6 +133,10 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port 5. tests +| module | description | +| - | - | +| `tests.test_files` | Sample files for testing | + # Concepts From d2fd65f91b925a7eccb715c0aff66e369df71489 Mon Sep 17 00:00:00 2001 From: Pedro Borges Date: Thu, 21 Nov 2024 15:53:28 -0500 Subject: [PATCH 26/64] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 21d267e..64aeae9 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port ## Codebase -1. configs +### 1. /configs | module | description | | - | - | @@ -109,19 +109,19 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | `configs.models` | Configuration files for different resolution models | -2. data +### 2. /data | module | description | | - | - | | `data` |
  • bert.vocab: BERT-trained dictionary containing tokens and their associated vector representations
  • c4_wpm.vocab: C4-trained dictionary containing tokens and their associated vector representations
  • cifar10.vocab: CIFAR10-trained dictionary containing tokens and their associated vector representations
  • imagenet.vocab:
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-256x256.tsv:
  • prompts_cifar10-32x32.tsv
  • prompts_cifar10-64x64.tsv
  • prompts_demo.tsv
  • prompts_imagenet-64px.tsv
  • prompts_WebImage-ALIGN-64px.tsv
  • t5.vocab
  • tokenizer_spm_32000_50m.vocab
| -3. docs +### 3. /docs | module | description | | - | - | | `docs` |
  • web_demo.png: Screenshot of the web demo of the model
| -4. ml_mdm +### 4. /ml_mdm | module | description | | - | - | @@ -131,7 +131,7 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` | | `tests/` | Unit tests and sample training files | -5. tests +### 5. /tests | module | description | | - | - | From fd89a1abea2a34d28081056667f5afddb9498ddf Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Thu, 21 Nov 2024 16:01:23 -0500 Subject: [PATCH 27/64] Update to README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 21d267e..562f8f1 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | module | description | | - | - | -| `data` |
  • bert.vocab: BERT-trained dictionary containing tokens and their associated vector representations
  • c4_wpm.vocab: C4-trained dictionary containing tokens and their associated vector representations
  • cifar10.vocab: CIFAR10-trained dictionary containing tokens and their associated vector representations
  • imagenet.vocab:
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-64x64.tsv: Prompts for the 64x64 model
  • prompts_cc12m-256x256.tsv:
  • prompts_cifar10-32x32.tsv
  • prompts_cifar10-64x64.tsv
  • prompts_demo.tsv
  • prompts_imagenet-64px.tsv
  • prompts_WebImage-ALIGN-64px.tsv
  • t5.vocab
  • tokenizer_spm_32000_50m.vocab
| +| `data` |
  • bert.vocab: BERT-trained dictionary containing tokens and their associated vector representations
  • c4_wpm.vocab: C4-trained dictionary containing tokens and their associated vector representations
  • cifar10.vocab: CIFAR10-trained dictionary containing tokens and their associated vector representations
  • imagenet.vocab: Prompts associated with Imagenet dataset
  • prompts_cc12m-64x64.tsv: Prompts associated with cc12m dataset for the 64x64 res. model
  • prompts_cc12m-256x256.tsv: Prompts associated with cc12m dataset for the 256x256 res. model
  • prompts_cifar10-32x32.tsv: Prompts associated with cifar10 dataset for the 32x32 res. model
  • prompts_cifar10-64x64.tsv: Prompts associated with cifar10 dataset for the 64x64 res. model
  • prompts_demo.tsv: Extra demo prompts
  • prompts_imagenet-64px.tsv: Prompts associated with imagenet dataset for the 64x64 res. model
  • prompts_WebImage-ALIGN-64px.tsv: Prompts associated with WebImage-ALIGN dataset for the 64x64 res. model
  • t5.vocab: t5-trained dictionary containing tokens and their associated vector representations
  • tokenizer_spm_32000_50m.vocab: SPM-trained dictionary containing tokens and their associated vector representations
| 3. docs From a6f45dff269899c1802ff566d175168aa3cbd047 Mon Sep 17 00:00:00 2001 From: Pedro Borges Date: Thu, 21 Nov 2024 16:04:08 -0500 Subject: [PATCH 28/64] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b082880..b0fe761 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port | module | description | | - | - | -| `docs` |
  • web_demo.png: Screenshot of the web demo of the model
| +| `docs` |
  • web_demo.png: Screenshot of the web demo of the model
| ### 4. /ml_mdm From aa40a702172177d78c6954d514079b0c2f6d36f7 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Sat, 23 Nov 2024 13:16:56 -0500 Subject: [PATCH 29/64] switch the default pytest call to run just CPU test cases --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 43aaeeb..ca85b12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ known_numeric = ["torch", "torchvision", "numpy", "jax", "flax", "mlx"] [tool.pytest.ini_options] -addopts = "--cov=ml_mdm" +addopts = "--cov=ml_mdm -m 'not gpu'" markers = [ "gpu" # tests that require a gpu ] From 1a433f5c79507947cb120a9fb65769968781cd78 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Sat, 23 Nov 2024 13:20:29 -0500 Subject: [PATCH 30/64] updated readme to reflect new pytest functionalities --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1841881..a81261b 100644 --- a/README.md +++ b/README.md @@ -64,8 +64,9 @@ Developers should set up `pre-commit` as well with `pre-commit install`. ### Running Test Cases ``` -> pytest # will run all test cases - including ones that require a gpu -> pytest -m "not gpu" # run test cases that can work with just cpu +> pytest # run test cases that can work with just cpu +> pytest -m '' # will run all test cases - including ones that require a gpu +> pytest -m gpu # run only gpu test cases ``` From 4d29bbf4a5a31939732167d9dcb220881f324c4d Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Sat, 23 Nov 2024 14:33:21 -0500 Subject: [PATCH 31/64] Added some typehinting to diffusion.py and generate_sample.py --- ml_mdm/clis/generate_sample.py | 19 ++++++++++---- ml_mdm/diffusion.py | 46 ++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/ml_mdm/clis/generate_sample.py b/ml_mdm/clis/generate_sample.py index 975ab31..ca3cae0 100644 --- a/ml_mdm/clis/generate_sample.py +++ b/ml_mdm/clis/generate_sample.py @@ -1,5 +1,6 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. +import argparse import logging import os import shlex @@ -16,6 +17,8 @@ import torch from torchvision.utils import make_grid +import ml_mdm.language_models.factory +import ml_mdm.language_models.tokenizer from ml_mdm import helpers, reader from ml_mdm.config import get_arguments, get_model, get_pipeline from ml_mdm.language_models import factory @@ -43,7 +46,13 @@ def dividable(n): return i, n // i -def generate_lm_outputs(device, sample, tokenizer, language_model, args): +def generate_lm_outputs( + device: torch.device, + sample: dict, + tokenizer: ml_mdm.language_models.tokenizer.Tokenizer, + language_model: ml_mdm.language_models.factory.LanguageModel, + args: argparse.Namespace, +) -> dict: with torch.no_grad(): lm_outputs, lm_mask = language_model(sample, tokenizer) sample["lm_outputs"] = lm_outputs @@ -68,7 +77,7 @@ def setup_models(args, device): return tokenizer, language_model, diffusion_model -def plot_logsnr(logsnrs, total_steps): +def plot_logsnr(logsnrs: list, total_steps: int) -> np.ndarray: import matplotlib.pyplot as plt x = 1 - np.arange(len(logsnrs)) / (total_steps - 1) @@ -103,14 +112,14 @@ class GLOBAL_DATA: global_config = GLOBAL_DATA() -def stop_run(): +def stop_run() -> gr.component: return ( gr.update(value="Run", variant="primary", visible=True), gr.update(visible=False), ) -def get_model_type(config_file): +def get_model_type(config_file: str) -> str: with open(config_file, "r") as f: d = yaml.safe_load(f) return d.get("model", d.get("vision_model", "unet")) @@ -292,7 +301,7 @@ def generate( ) -def main(args): +def main(args: argparse.Namespace): # get the language model outputs example_texts = open("data/prompts_demo.tsv").readlines() diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index c11b034..a2d05b5 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -13,10 +13,12 @@ import torch.nn.functional as F from torchvision.utils import save_image +import ml_mdm.samplers from ml_mdm import config, samplers def sv(x, f): + breakpoint() # 1 save_image(x, f, value_range=(-1, 1), normalize=True) @@ -29,7 +31,8 @@ def sv(x, f): @dataclass class DiffusionConfig: sampler_config: samplers.SamplerConfig = field( - default_factory=samplers.SamplerConfig, metadata={"help": "Sampler configuration"} + default_factory=samplers.SamplerConfig, + metadata={"help": "Sampler configuration"}, ) model_output_scale: float = field( default=0, @@ -58,13 +61,14 @@ def __init__( self.vision_model = vision_model self.sampler = None - def set_sampler(self, sampler): + def set_sampler(self, sampler: ml_mdm.samplers.Sampler): self.sampler = sampler - def load(self, vision_file): + def load(self, vision_file: str) -> dict: return self.vision_model.load(vision_file) def save(self, vision_file, other_items=None): + breakpoint() # 4 self.vision_model.save(vision_file, other_items=other_items) @property @@ -102,7 +106,7 @@ def get_model(self): return self.model.module return self.model - def to(self, device): + def to(self, device: torch.device): self.model = self.model.to(device) self.sampler = self.sampler.to(device) return self @@ -116,13 +120,16 @@ def eval(self): def get_xt_minus_1(self, t, x_t, lm_outputs, lm_mask): self.eval() + breakpoint() # 8 return self.sampler.get_xt_minus_1(t, x_t, lm_outputs, lm_mask) def get_pred_for_training(self, x_t, pred, g): + breakpoint() # 9 if ( self._config.sampler_config.loss_target_type == self._config.sampler_config.prediction_type ): + breakpoint() # 10 return pred else: x0, _ = self.sampler.get_x0_eps_from_pred( @@ -131,15 +138,17 @@ def get_pred_for_training(self, x_t, pred, g): pred = self.sampler.get_pred_from_x0_xt( x_t, x0, g, self._config.sampler_config.loss_target_type ) + breakpoint() # 11 return pred - def get_micro_conditioning(self, sample): + def get_micro_conditioning(self, sample: dict) -> dict: micros, conditions = {}, self.get_model().vision_model.conditions if conditions is not None: micros = {key: sample[key] for key in conditions if key in sample} return micros - def get_loss(self, sample): + def get_loss(self, sample: dict): + breakpoint() # 13 images, lm_outputs, lm_mask = ( sample["images"], sample["lm_outputs"], @@ -163,14 +172,28 @@ def get_loss(self, sample): ) pred = self.get_pred_for_training(x_t, means, g) loss = self.loss_fn(pred, tgt).mean(axis=(1, 2, 3)) + breakpoint() # 14 return loss, time, x_t, means, tgt, weights - def get_noise(self, num_examples, input_channels, image_side, device): + def get_noise( + self, + num_examples: int, + input_channels: int, + image_side: int, + device: torch.device, + ) -> torch.Tensor: return torch.randn(num_examples, input_channels, image_side, image_side).to( device ) - def sample(self, num_examples, sample, image_side, device, **kwargs): + def sample( + self, + num_examples: int, + sample: dict, + image_side: int, + device: torch.device, + **kwargs: dict, + ): self.eval() noise = self.get_noise( num_examples, self.get_model().input_channels, image_side, device @@ -184,8 +207,10 @@ def sample(self, num_examples, sample, image_side, device, **kwargs): def partial_diffusion( self, images, t, lm_outputs, lm_mask, device, return_sequence=False ): + breakpoint() # 17 self.eval() (_, x_t, _, _) = self.sampler.get_noisy_samples_for_training(images, t) + breakpoint() # 18 return self.sampler.sample( x_t, lm_outputs, lm_mask, return_sequence=return_sequence, t=t ) @@ -297,7 +322,8 @@ def __init__(self, denoising_model, diffusion_config: DiffusionConfig): ) self.mixed_ratio = self.mixed_ratio / self.mixed_ratio[-1] - def get_loss(self, sample): + def get_loss(self, sample: dict): + breakpoint() # 19 images, lm_outputs, lm_mask = ( sample["images"], sample["lm_outputs"], @@ -369,5 +395,5 @@ def get_loss(self, sample): loss_ = pred[i].mean() * 0.0 loss_ = loss_ * w[i] loss = loss + loss_ - + breakpoint() # 20 return loss, time, x_t[0], pred[0], tgt[0], weights From 0b82947593c1098314fbd9f1b8687e9e6bdfcf0a Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Mon, 25 Nov 2024 15:44:13 -0500 Subject: [PATCH 32/64] Separate python dependencies into optional groups. --- pyproject.toml | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 43aaeeb..de7fc47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,33 +17,39 @@ description = "A python package to simplify the creation of text conditioned ima dependencies = [ "dataclass-wizard", "einops", - "fastapi>=0.109.1", # Required due to CVE-2024-24762 - "gradio>=4.14", # Required due to CVE-2023-6572 "httpx==0.24.1", - "imageio[ffmpeg]", - "matplotlib", "mlx-data", "numpy<2", - "pytorch-model-summary", - "rotary-embedding-torch", "simple-parsing==0.1.5", - "tensorboardX==2.6.2.2", - "tensorboard==2.16.2", - "torchinfo", - "torchmetrics[image]", "torchvision", "transformers", "sentencepiece", - "boto3", "torch==2.2.2", - "pytest", - "pytest-cov", - "pre-commit" ] [project.optional-dependencies] data_prep = [ - "img2dataset" + "img2dataset", + "boto3", +] +web_demo = [ + "fastapi>=0.109.1", # Required due to CVE-2024-24762 + "gradio>=4.14", # Required due to CVE-2023-6572 + "matplotlib", + "imageio[ffmpeg]", +] +training = [ + "tensorboard==2.16.2", + "tensorboardX==2.6.2.2", + "torchmetrics[image]", + "rotary-embedding-torch", + "pytorch-model-summary", + "torchinfo", +] +dev = [ + "pytest", + "pytest-cov", + "pre-commit", ] [tool.isort] From ce43e0549e4fdadfc1487c0ca1ba05f38d764c5f Mon Sep 17 00:00:00 2001 From: Pedro Cruz Date: Mon, 25 Nov 2024 15:49:12 -0500 Subject: [PATCH 33/64] Separate python dependencies into optional groups --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8303d7a..645d6c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ profile = "black" sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "FIRSTPARTY", "LOCALFOLDER"] known_numeric = ["torch", "torchvision", "numpy", "jax", "flax", "mlx"] - [tool.pytest.ini_options] addopts = "--cov=ml_mdm -m 'not gpu'" markers = [ From 50c4e73be43e32a76fd4a8fe91eff1c8f25402aa Mon Sep 17 00:00:00 2001 From: Bregwin Jogi Date: Wed, 27 Nov 2024 22:57:49 -0500 Subject: [PATCH 34/64] Update files to use simple parsing --- ml_mdm/clis/download_tar_from_index.py | 70 +++++++++------------ ml_mdm/clis/run_torchmetrics.py | 86 ++++++++++++-------------- 2 files changed, 66 insertions(+), 90 deletions(-) diff --git a/ml_mdm/clis/download_tar_from_index.py b/ml_mdm/clis/download_tar_from_index.py index faba594..08bb793 100644 --- a/ml_mdm/clis/download_tar_from_index.py +++ b/ml_mdm/clis/download_tar_from_index.py @@ -17,7 +17,7 @@ nodes this data will be distributed over. """ -import argparse +import simple_parsing import csv import logging import os @@ -33,7 +33,30 @@ import mlx.data from ml_mdm import helpers, s3_helpers - +from dataclasses import dataclass, field + +@dataclass +class DownloadConfig: + dataset_config_file: str = field(default="", + metadata={"help": "yaml file with dataset names"}) + worker_id: int = field(default=0, + metadata={"help": "current worker in [0, num-downloaders -1]"}) + num_downloaders: int = field(default=1, + metadata={"help": "number of parallel downloaders"}) + no_bandwidth: bool = field(default=False) + download_tar: bool = field(default=False, + metadata={"help": "whether or not to download tar files also"}) + pretrained_text_embeddings: str = field(default=None) + endpoint_url: str = field(default="", + metadata={"help": "end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise"}) + subset: str = field(default="train", + metadata={"choices": ["train", "eval"], + "help": "subset to download [train|eval]"}) + +def get_parser(): + parser = simple_parsing.ArgumentParser(description="Download tar files referred to in index file from mlx") + parser.add_arguments(DownloadConfig, dest="options") + return parser def read_tsv(filename): # Open the TSV file for reading @@ -331,44 +354,7 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Download tar files referred to in index file from mlx" - ) - parser.add_argument( - "--dataset-config-file", - type=str, - default="", - help="yaml file with dataset names", - ) - parser.add_argument( - "--worker-id", - type=int, - default=0, - help="current worker in [0, num-downloaders -1]", - ) - parser.add_argument( - "--num-downloaders", type=int, default=1, help="number of parallel downloaders" - ) - parser.add_argument("--no_bandwidth", action="store_true") - parser.add_argument( - "--download_tar", - action="store_true", - help="whether or not to download tar files also", - ) - parser.add_argument("--pretrained-text-embeddings", type=str, default=None) - parser.add_argument( - "--endpoint-url", - type=str, - default="", - help="end point for the s3 bucket — uses environment variable AWS_ENDPOINT_URL otherwise", - ) - parser.add_argument( - "--subset", - type=str, - default="train", - choices=["train", "eval"], - help="subset to download [train|eval]", - ) + parser = get_parser() args = parser.parse_args() logging.basicConfig( level="INFO", @@ -377,5 +363,5 @@ def main(args): ), datefmt="%H:%M:%S", ) - helpers.print_args(args) - main(args) + helpers.print_args(args.options) + main(args.options) \ No newline at end of file diff --git a/ml_mdm/clis/run_torchmetrics.py b/ml_mdm/clis/run_torchmetrics.py index ec5a502..9331370 100644 --- a/ml_mdm/clis/run_torchmetrics.py +++ b/ml_mdm/clis/run_torchmetrics.py @@ -1,6 +1,7 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. -import argparse + +import simple_parsing import json import logging import os @@ -14,7 +15,38 @@ import torch from ml_mdm import helpers - +from dataclasses import dataclass, field + +@dataclass +class MetricsConfig: + loglevel: str = field(default="INFO", + metadata={"help": "Logging level"}) + sample_dir: str = field(default="", + metadata={"help": "directory with samples"}) + metrics: str = field(default="clip,fid", + metadata={"help": "Metrics to compute(comma separated)"}) + reference_dir: str = field(default="", + metadata={"help": "directory with reference images"}) + num_samplers: int = field(default=1, + metadata={"help": "Number of jobs generating samples"}) + num_training_steps: int = field(default=850000, + metadata={"help": "# of training steps to train for"}) + max_caption_length: int = field(default=77, + metadata={"help": "Maximum length of caption"}) + eval_freq: int = field(default=1000, + metadata={"help": "Minimum Evaluation interval"}) + clip_model: str = field(default="openai/clip-vit-base-patch16", + metadata={"help": "Model to use for clip scores"}) + inception_layer_fid: int = field(default=2048, + metadata={ + "choices": [64, 192, 768, 2048], + "help": "Which layer of inception to use for fid" + }) + +def get_parser(): + parser = simple_parsing.ArgumentParser(description="Compute metrics on samples from diffusion model") + parser.add_arguments(MetricsConfig, dest="options") + return parser def load_captions_and_images(dir_name, args, override_path=None): map_files = [] @@ -140,54 +172,12 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Compute metrics on samples from diffusion model" - ) - parser.add_argument("--loglevel", type=str, default="INFO", help="Logging level") - parser.add_argument( - "--sample-dir", type=str, default="", help="directory with samples" - ) - parser.add_argument( - "--metrics", - type=str, - default="clip,fid", - help="Metrics to compute(comma separated)", - ) - parser.add_argument( - "--reference-dir", type=str, default="", help="directory with reference images" - ) - parser.add_argument( - "--num-samplers", type=int, default=1, help="Number of jobs generating samples" - ) - parser.add_argument( - "--num-training-steps", - type=int, - default=850000, - help="# of training steps to train for", - ) - parser.add_argument( - "--max-caption-length", type=int, default=77, help="Maximum length of caption" - ) - parser.add_argument( - "--eval-freq", type=int, default=1000, help="Minimum Evaluation interval" - ) - parser.add_argument( - "--clip-model", - type=str, - default="openai/clip-vit-base-patch16", - help="Model to use for clip scores", - ) - parser.add_argument( - "--inception-layer-fid", - type=int, - default=2048, - choices=[64, 192, 768, 2048], - help="Which layer of inception to use for fid", - ) + parser = get_parser() args = parser.parse_args() logging.basicConfig( - level=getattr(logging, args.loglevel.upper(), None), + level=getattr(logging, args.options.loglevel.upper(), None), format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s", datefmt="%H:%M:%S", ) - main(args) + helpers.print_args(args.options) + main(args.options) From 19d616a6c82809584397a3014b2fcaea6f9b0721 Mon Sep 17 00:00:00 2001 From: Bregwin Jogi Date: Thu, 28 Nov 2024 12:21:49 -0500 Subject: [PATCH 35/64] Fix argument issues readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 37ff804..c9a98ac 100644 --- a/README.md +++ b/README.md @@ -292,11 +292,11 @@ reader_config: Then you can use our dataset download helper: ```console python -m ml_mdm.clis.download_tar_from_index \ - --dataset-config-file configs/datasets/cc12m.yaml \ + --dataset_config_file configs/datasets/cc12m.yaml \ --subset train --download_tar python -m ml_mdm.clis.download_tar_from_index \ - --dataset-config-file configs/datasets/cc12m.yaml \ + --dataset_config_file configs/datasets/cc12m.yaml \ --subset eval --download_tar ``` From 10d71672ab12f974c3e96309f94c3e2a6dedec22 Mon Sep 17 00:00:00 2001 From: Aaliyah Bullen Date: Mon, 2 Dec 2024 21:22:25 -0500 Subject: [PATCH 36/64] added config description --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index 37ff804..0a8fdaf 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,20 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of: > In essence, `simple_parsing` will convert all passed cli arguments and yaml files into clean configuration classes like `ml_mdm.reader.ReaderConfig`, `ml_mdm.diffusion.DiffusionConfig`. +`ml_mdm.config` stores a global mapping of names to classes in `MODEL_REGISTRY`, `MODEL_CONFIG_REGISTRY`, `PIPELINE_REGISTRY`, and `PIPELINE_CONFIG_REGISTRY`. + +`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example + +> *_CONFIG_REGISTRY[architecture name]["model"] = model name +> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class + +MODEL_CONFIG_REGISTRY and PIPELINE_CONFIG_REGISTRY stores information as shown in the following example: +> *_CONFIG_REGISTRY[architecture name]["model"] = model name +> *_CONFIG_REGISTRY[architecture name]["config"] = configuration class + + +architecture name and model name are passed into ml_mdm.config through the function parameter *names. where *names points to "architecture name", "model name" + # Tutorials From c3e66cc7fb200b4681e1de90fbcbb765af5311b2 Mon Sep 17 00:00:00 2001 From: aaliyahnl <106929014+aaliyahnl@users.noreply.github.com> Date: Mon, 2 Dec 2024 21:36:36 -0500 Subject: [PATCH 37/64] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 0a8fdaf..0a4e395 100644 --- a/README.md +++ b/README.md @@ -159,10 +159,12 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of: `MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example > *_CONFIG_REGISTRY[architecture name]["model"] = model name + > *_CONFIG_REGISTRY[architecture name]["config"] = configuration class MODEL_CONFIG_REGISTRY and PIPELINE_CONFIG_REGISTRY stores information as shown in the following example: > *_CONFIG_REGISTRY[architecture name]["model"] = model name + > *_CONFIG_REGISTRY[architecture name]["config"] = configuration class From 7509888f2613bc7d75156d1244eb55ede8731519 Mon Sep 17 00:00:00 2001 From: aaliyahnl <106929014+aaliyahnl@users.noreply.github.com> Date: Mon, 2 Dec 2024 21:38:40 -0500 Subject: [PATCH 38/64] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0a4e395..2c53d0e 100644 --- a/README.md +++ b/README.md @@ -156,13 +156,13 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of: `ml_mdm.config` stores a global mapping of names to classes in `MODEL_REGISTRY`, `MODEL_CONFIG_REGISTRY`, `PIPELINE_REGISTRY`, and `PIPELINE_CONFIG_REGISTRY`. -`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example +`MODEL_REGISTRY` and `PIPELINE_REGISTRY` store information as shown in the following example: > *_CONFIG_REGISTRY[architecture name]["model"] = model name > *_CONFIG_REGISTRY[architecture name]["config"] = configuration class -MODEL_CONFIG_REGISTRY and PIPELINE_CONFIG_REGISTRY stores information as shown in the following example: +MODEL_CONFIG_REGISTRY and PIPELINE_CONFIG_REGISTRY store information as shown in the following example: > *_CONFIG_REGISTRY[architecture name]["model"] = model name > *_CONFIG_REGISTRY[architecture name]["config"] = configuration class From 823834230b3fc19f1c923947a69e77ccf619e3e0 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Wed, 4 Dec 2024 10:26:24 -0500 Subject: [PATCH 39/64] typehinting for generate and dividable --- ml_mdm/clis/generate_sample.py | 47 +++++++++++++++++----------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/ml_mdm/clis/generate_sample.py b/ml_mdm/clis/generate_sample.py index 975ab31..acf5496 100644 --- a/ml_mdm/clis/generate_sample.py +++ b/ml_mdm/clis/generate_sample.py @@ -1,11 +1,12 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. +import argparse import logging import os import shlex import time from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple import gradio as gr import simple_parsing @@ -36,7 +37,7 @@ ) -def dividable(n): +def dividable(n: int) -> Tuple[int, int]: for i in range(int(np.sqrt(n)), 0, -1): if n % i == 0: break @@ -51,7 +52,7 @@ def generate_lm_outputs(device, sample, tokenizer, language_model, args): return sample -def setup_models(args, device): +def setup_models(args: argparse.Namespace, device: torch.device): input_channels = 3 # load the language model @@ -110,32 +111,32 @@ def stop_run(): ) -def get_model_type(config_file): +def get_model_type(config_file: str): with open(config_file, "r") as f: d = yaml.safe_load(f) return d.get("model", d.get("vision_model", "unet")) def generate( - config_file="cc12m_64x64.yaml", - ckpt_name="vis_model_64x64.pth", - prompt="a chair", - input_template="", - negative_prompt="", - negative_template="", - batch_size=20, - guidance_scale=7.5, - threshold_function="clip", - num_inference_steps=250, - eta=0, - save_diffusion_path=False, - show_diffusion_path=False, - show_xt=False, - reader_config="", - seed=10, - comment="", - override_args="", - output_inner=False, + config_file: str = "cc12m_64x64.yaml", + ckpt_name: str = "vis_model_64x64.pth", + prompt: str = "a chair", + input_template: str = "", + negative_prompt: str = "", + negative_template: str = "", + batch_size: int = 20, + guidance_scale: float = 7.5, + threshold_function: str = "clip", + num_inference_steps: int = 250, + eta: int = 0, + save_diffusion_path: bool = False, + show_diffusion_path: bool = False, + show_xt: bool = False, + reader_config: str = "", + seed: int = 10, + comment: str = "", + override_args: str = "", + output_inner: bool = False, ): np.random.seed(seed) torch.random.manual_seed(seed) From 96568032156ab21ab01c1aa24db1982a8b29da8e Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:36:48 -0500 Subject: [PATCH 40/64] Took out breakpoints --- ml_mdm/diffusion.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index a2d05b5..f44ae71 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -18,7 +18,6 @@ def sv(x, f): - breakpoint() # 1 save_image(x, f, value_range=(-1, 1), normalize=True) @@ -68,7 +67,6 @@ def load(self, vision_file: str) -> dict: return self.vision_model.load(vision_file) def save(self, vision_file, other_items=None): - breakpoint() # 4 self.vision_model.save(vision_file, other_items=other_items) @property @@ -120,16 +118,13 @@ def eval(self): def get_xt_minus_1(self, t, x_t, lm_outputs, lm_mask): self.eval() - breakpoint() # 8 return self.sampler.get_xt_minus_1(t, x_t, lm_outputs, lm_mask) def get_pred_for_training(self, x_t, pred, g): - breakpoint() # 9 if ( self._config.sampler_config.loss_target_type == self._config.sampler_config.prediction_type ): - breakpoint() # 10 return pred else: x0, _ = self.sampler.get_x0_eps_from_pred( @@ -138,7 +133,6 @@ def get_pred_for_training(self, x_t, pred, g): pred = self.sampler.get_pred_from_x0_xt( x_t, x0, g, self._config.sampler_config.loss_target_type ) - breakpoint() # 11 return pred def get_micro_conditioning(self, sample: dict) -> dict: @@ -148,7 +142,6 @@ def get_micro_conditioning(self, sample: dict) -> dict: return micros def get_loss(self, sample: dict): - breakpoint() # 13 images, lm_outputs, lm_mask = ( sample["images"], sample["lm_outputs"], @@ -172,7 +165,6 @@ def get_loss(self, sample: dict): ) pred = self.get_pred_for_training(x_t, means, g) loss = self.loss_fn(pred, tgt).mean(axis=(1, 2, 3)) - breakpoint() # 14 return loss, time, x_t, means, tgt, weights def get_noise( @@ -207,10 +199,8 @@ def sample( def partial_diffusion( self, images, t, lm_outputs, lm_mask, device, return_sequence=False ): - breakpoint() # 17 self.eval() (_, x_t, _, _) = self.sampler.get_noisy_samples_for_training(images, t) - breakpoint() # 18 return self.sampler.sample( x_t, lm_outputs, lm_mask, return_sequence=return_sequence, t=t ) @@ -323,7 +313,6 @@ def __init__(self, denoising_model, diffusion_config: DiffusionConfig): self.mixed_ratio = self.mixed_ratio / self.mixed_ratio[-1] def get_loss(self, sample: dict): - breakpoint() # 19 images, lm_outputs, lm_mask = ( sample["images"], sample["lm_outputs"], @@ -395,5 +384,4 @@ def get_loss(self, sample: dict): loss_ = pred[i].mean() * 0.0 loss_ = loss_ * w[i] loss = loss + loss_ - breakpoint() # 20 return loss, time, x_t[0], pred[0], tgt[0], weights From 5bf935b15430ddc2d04d6bf7ae19ddf826971976 Mon Sep 17 00:00:00 2001 From: Emma Garrett Date: Thu, 5 Dec 2024 14:08:47 -0500 Subject: [PATCH 41/64] gpu and cpu dependencies --- pyproject.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 645d6c6..e304f49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,14 @@ dependencies = [ ] [project.optional-dependencies] +cpu = [ + "torch==2.2.2+cpu", + "tensorflow==2.5.0", +] +gpu = [ + "torch==2.2.2+cu111", + "tensorflow-gpu==2.5.0", +] data_prep = [ "img2dataset", "boto3", From 0e78b014450e965f4a495a08e4b71f5838de5a15 Mon Sep 17 00:00:00 2001 From: gabrielfnayres Date: Thu, 5 Dec 2024 22:48:48 -0300 Subject: [PATCH 42/64] Fixing missing modules in pyproject.toml --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e304f49..fc6fcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,12 @@ dependencies = [ "transformers", "sentencepiece", "torch==2.2.2", + "matplotlib", + "gradio", + "boto3", + "torchmetrics", + "img2dataset", + "torchinfo" ] [project.optional-dependencies] From 5b793c0840c6f835a7f85cda3c31869138cd0aa5 Mon Sep 17 00:00:00 2001 From: gabrielfnayres Date: Thu, 5 Dec 2024 23:01:54 -0300 Subject: [PATCH 43/64] added sigmoid beta scheduler --- ml_mdm/samplers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index c72caac..fe4a14c 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -26,6 +26,7 @@ class ScheduleType(Type): COSINE = 0 DDPM = 2 DEEPFLOYD = 3 + SIGMOID = 4 @staticmethod def argparse(s): @@ -164,6 +165,10 @@ def alpha_bar(time_step): gammas = np.exp(np.cumsum(log_alphas)) return gammas +def schedule_sigmoid(timesteps: int, beta_start: float, beta_end: float) -> np.ndarray: + """https://arxiv.org/pdf/2301.10972""" + betas = np.linspace((-6,6), timesteps) + return (1/(np.exp(betas) + 1)) * (beta_end - beta_start) + beta_start ########################################################################################## # Sampler Class # From ab5851c688235f35d4ab3c4d717466f5cd6f9f50 Mon Sep 17 00:00:00 2001 From: gabrielfnayres Date: Fri, 6 Dec 2024 16:40:13 -0300 Subject: [PATCH 44/64] Added some typing in s3_helpers.py and trainer.py --- ml_mdm/s3_helpers.py | 22 +++++++++++----------- ml_mdm/trainer.py | 3 ++- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ml_mdm/s3_helpers.py b/ml_mdm/s3_helpers.py index e70479e..a57af49 100644 --- a/ml_mdm/s3_helpers.py +++ b/ml_mdm/s3_helpers.py @@ -12,10 +12,10 @@ def download_object( - bucket_name, - file_name, - download_path=None, - endpoint_url=ENDPOINT_URL, + bucket_name: str, + file_name: str, + download_path: str =None, + endpoint_url: str = ENDPOINT_URL, max_bandwidth=None, ): """Downloads an object from S3 to local.""" @@ -37,7 +37,7 @@ def download_object( return download_path -def download_object_from_full_path(path, download_path=None, endpoint_url=ENDPOINT_URL): +def download_object_from_full_path(path: str, download_path: str =None, endpoint_url: str = ENDPOINT_URL): bucket_name, parent_path, basename = _parse_path(path) file_name = os.path.join(parent_path, basename) return download_object( @@ -46,10 +46,10 @@ def download_object_from_full_path(path, download_path=None, endpoint_url=ENDPOI def upload_object( - bucket_name, - file_name, - upload_path, - endpoint_url=ENDPOINT_URL, + bucket_name: str, + file_name: str, + upload_path: str, + endpoint_url: str = ENDPOINT_URL, ): """Uload an object from S3 to local.""" @@ -70,7 +70,7 @@ def _parse_path(tsv_pattern): return bucket, "/".join(parts[3:-1]), pattern -def get_file_list(tsv_pattern, endpoint_url=ENDPOINT_URL): +def get_file_list(tsv_pattern: str, endpoint_url: str = ENDPOINT_URL): bucket_name, parent_path, pattern = _parse_path(tsv_pattern) resource = boto3.resource("s3", endpoint_url=endpoint_url) bucket = resource.Bucket(bucket_name) @@ -84,7 +84,7 @@ def get_file_list(tsv_pattern, endpoint_url=ENDPOINT_URL): return fnames -def download_parallel(files, endpoint_url=ENDPOINT_URL): +def download_parallel(files: str, endpoint_url: str=ENDPOINT_URL): logging.info("Doing parallel download") with ProcessPoolExecutor() as executor: logging.info(f"Submitting {files}") diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index fbafb6a..fe1e0d1 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -3,7 +3,8 @@ import numpy as np import torch import torch.nn as nn - +from typing import Optional, Tuple +from argparse import Namespace def train_batch( model: torch.nn.Module, From 2042cfaa5f20610303de8bb22e2d777865a4cb57 Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:42:28 -0500 Subject: [PATCH 45/64] imports for type hinting --- ml_mdm/trainer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index fbafb6a..ddc7efe 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -1,8 +1,12 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. +from argparse import Namespace +from typing import Optional + import numpy as np import torch import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter def train_batch( @@ -14,10 +18,11 @@ def train_batch( args: Namespace, grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, accumulate_gradient: bool = False, - num_grad_accumulations: int =1, + num_grad_accumulations: int = 1, ema_model: Optional[nn.Module] = None, loss_factor: float = 1.0, ): + breakpoint() model.train() lr = scheduler.get_last_lr()[0] # Updates the scale for next iteration @@ -50,7 +55,9 @@ def train_batch( grad_scaler.step(optimizer) grad_scaler.update() if ema_model is not None: - ema_model.update(getattr(model.model, "module", model.model).vision_model) + ema_model.update( + getattr(model.model, "module", model.model).vision_model + ) else: losses, times, x_t, means, targets, weights = model.get_loss(sample) if weights is None: @@ -74,7 +81,9 @@ def train_batch( ).item() optimizer.step() if ema_model is not None: - ema_model.update(getattr(model.model, "module", model.model).vision_model) + ema_model.update( + getattr(model.model, "module", model.model).vision_model + ) if logger is not None and not accumulate_gradient: logger.add_scalar("train/Loss", loss_val) From 26af98a28ee5465ef0b0cbf9d79ae09c782889f1 Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Mon, 9 Dec 2024 21:53:32 -0500 Subject: [PATCH 46/64] Added type hinting to some functions in samplers.py --- ml_mdm/samplers.py | 91 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index c5ececa..1351b35 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -4,8 +4,10 @@ import math from dataclasses import dataclass, field from enum import Enum +from typing import Tuple from einops import repeat +from reader import method from tqdm import tqdm import numpy as np @@ -13,6 +15,8 @@ import torch.nn as nn import torch.nn.functional as F +import ml_mdm.diffusion + class Type(Enum): def __str__(self): @@ -187,12 +191,14 @@ def __init__(self, sampler_config: SamplerConfig): if self._config.loss_target_type is None: self._config.loss_target_type = self._config.prediction_type - def read_gamma(self, time, image): + def read_gamma(self, time: torch.Tensor, image: torch.Tensor) -> torch.Tensor: B, C, H, W = image.size() time = repeat(time, "b -> b c h w", c=C, h=H, w=W) return self.gammas[time] - def get_noise_schedule(self, schedule_type, n_steps, sampler_config): + def get_noise_schedule( + self, schedule_type: ScheduleType, n_steps: int, sampler_config: SamplerConfig + ): # pre-defined noise schedule functions if schedule_type == ScheduleType.COSINE: _gammas = schedule_cosine(n_steps) @@ -223,6 +229,7 @@ def get_noise_schedule(self, schedule_type, n_steps, sampler_config): self.register_buffer("vdm_loss_weights", weights) def get_eps_time(self, images, time=None): + breakpoint() # Breakpoint 1 batch_size = images.shape[0] if time is None: time = torch.randint(0, self.n_steps, (batch_size,), device=images.device) @@ -231,20 +238,27 @@ def get_eps_time(self, images, time=None): g, g_last = self.read_gamma(time + 1, images), self.read_gamma(time, images) weights = self.vdm_loss_weights[time + 1] eps = torch.randn_like(images) + breakpoint() # Breakpoint 2 return eps, g, g_last, weights, time def get_xt(self, images, eps, g): + breakpoint() # Breakpoint 3 x_t = g.sqrt() * images + (1 - g).sqrt() * eps + breakpoint() # Breakpoint 4 return x_t def get_image_rescaled(self, images, scale_factor=None): + breakpoint() # Breakpoint 5 if scale_factor is None: scale_factor = self._config.rescale_signal if scale_factor: # divide the signal images = images / scale_factor + breakpoint() # Breakpoint 6 return images - def get_schedule_shifted(self, gammas, scale_factor=None): + def get_schedule_shifted( + self, gammas: torch.Tensor, scale_factor: float = None + ) -> torch.Tensor: if (scale_factor is not None) and (scale_factor > 1): # rescale noise schecule p = self._config.schedule_shifted_power scale_factor = scale_factor**p @@ -254,6 +268,7 @@ def get_schedule_shifted(self, gammas, scale_factor=None): return gammas def get_prediction_targets(self, images, eps, g, g_last, prediction_type=None): + breakpoint() # Breakpoint 7 if prediction_type is None: prediction_type = self._config.loss_target_type @@ -266,21 +281,22 @@ def get_prediction_targets(self, images, eps, g, g_last, prediction_type=None): pred = g.sqrt() * eps - (1 - g).sqrt() * images else: raise Exception("Unsupported type") + breakpoint() # Breakpoint 8 return pred def get_prediction_xt_last( self, - x_t, - pred, - g, - g_last, - prediction_type=None, - clip_fn=None, - need_noise=False, - ddim_eta=None, + x_t: torch.Tensor, + pred: torch.Tensor, + g: torch.Tensor, + g_last: torch.Tensor, + prediction_type: PredictionType = None, + clip_fn: method = None, + need_noise: torch.Tensor = False, + ddim_eta: int = None, input_noise=None, image_scale=None, - ): + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ x_t: noisy image pred: model prediction (can be x0, eps, v, etc) @@ -291,7 +307,6 @@ def get_prediction_xt_last( need_noise: use noise or not ddim_eta: if None, then not using DDIM, otherwise, use DDIM implementation (1==DDPM) """ - if prediction_type is None: prediction_type = self._config.prediction_type @@ -336,8 +351,15 @@ def get_prediction_xt_last( return x0, x_t_last, eps def get_x0_eps_from_pred( - self, x_t, pred, g, prediction_type=None, clip_fn=None, return_eps=True + self, + x_t: torch.Tensor, + pred: torch.Tensor, + g: torch.Tensor, + prediction_typ: PredictionType = None, + clip_fn=None, + return_eps: bool = True, ): + breakpoint() # Breakpoint 11 batch_size = x_t.size(0) if prediction_type is None: prediction_type = self._config.prediction_type @@ -357,9 +379,11 @@ def get_x0_eps_from_pred( if not return_eps: return x0 eps = (x_t - x0 * g.sqrt()) / (1 - g).sqrt() + breakpoint() # Breakpoint 12 return x0, eps def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): + breakpoint() # Breakpoint 13 batch_size = x_t.size(0) if prediction_type is None: prediction_type = self._config.prediction_type @@ -372,20 +396,21 @@ def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): pred = (g.sqrt() * x_t - x0) / (1 - g).sqrt() else: raise Exception("prediction type not set to a correct value") + breakpoint() # Breakpoint 14 return pred def get_xt_minus_1( self, - model, - time_step, - x_t, - lm_outputs, - lm_mask, - micros={}, - time_step_last=None, - guidance_scale=1, - ddim_eta=None, - return_details=False, + model: ml_mdm.diffusion.Model, + time_step: torch.Tensor, + x_t: torch.Tensor, + lm_outputs: torch.Tensor, + lm_mask: torch.Tensor, + micros: dict = {}, + time_step_last: torch.Tensor = None, + guidance_scale: float = 1, + ddim_eta: int = None, + return_details: bool = False, ): batch_size = x_t.shape[0] ones = torch.ones(batch_size, dtype=torch.long, device=self.gammas.device) @@ -418,8 +443,15 @@ def get_xt_minus_1( return x_s def forward_model( - self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1 - ): + self, + model: ml_mdm.diffusion.Model, + x_t: torch.Tensor, + t: torch.Tensor, + lm_outputs: torch.Tensor, + lm_mask: torch.Tensor, + micros: dict = {}, + guidance_scale: float = 1, + ) -> Tuple[torch.Tensor, torch.Tensor]: if guidance_scale != 1: assert x_t.shape[0] * 2 == lm_outputs.shape[0] pred, extras = model( @@ -437,8 +469,11 @@ def forward_model( return pred, extras def _threshold_sample( - self, sample, dynamic_thresholding_ratio=0.995, sample_max_value=100 - ): + self, + sample: torch.Tensor, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 100, + ) -> torch.Tensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by From 559f10e37e6a12dced1fb458c589db154d14bc47 Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:38:28 -0500 Subject: [PATCH 47/64] Added more type hinting in samplers.py --- ml_mdm/samplers.py | 82 ++++++++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 32 deletions(-) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index 1351b35..32438ca 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -7,7 +7,6 @@ from typing import Tuple from einops import repeat -from reader import method from tqdm import tqdm import numpy as np @@ -15,8 +14,6 @@ import torch.nn as nn import torch.nn.functional as F -import ml_mdm.diffusion - class Type(Enum): def __str__(self): @@ -291,7 +288,7 @@ def get_prediction_xt_last( g: torch.Tensor, g_last: torch.Tensor, prediction_type: PredictionType = None, - clip_fn: method = None, + clip_fn=None, # Said type class 'method' but could not find what that was need_noise: torch.Tensor = False, ddim_eta: int = None, input_noise=None, @@ -355,11 +352,10 @@ def get_x0_eps_from_pred( x_t: torch.Tensor, pred: torch.Tensor, g: torch.Tensor, - prediction_typ: PredictionType = None, + prediction_type: PredictionType = None, clip_fn=None, return_eps: bool = True, ): - breakpoint() # Breakpoint 11 batch_size = x_t.size(0) if prediction_type is None: prediction_type = self._config.prediction_type @@ -401,7 +397,7 @@ def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): def get_xt_minus_1( self, - model: ml_mdm.diffusion.Model, + model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import time_step: torch.Tensor, x_t: torch.Tensor, lm_outputs: torch.Tensor, @@ -444,7 +440,7 @@ def get_xt_minus_1( def forward_model( self, - model: ml_mdm.diffusion.Model, + model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import x_t: torch.Tensor, t: torch.Tensor, lm_outputs: torch.Tensor, @@ -507,7 +503,7 @@ def _threshold_sample( return sample - def clip_sample(self, pred_x0, image_scale=1): + def clip_sample(self, pred_x0: torch.Tensor, image_scale: int = 1) -> torch.Tensor: s = image_scale if self._config.threshold_function == ThresholdType.CLIP: return (pred_x0 * s).clip(-1, 1) / s @@ -517,30 +513,34 @@ def clip_sample(self, pred_x0, image_scale=1): return self._threshold_sample(pred_x0 * s, 0.95, 1.5) / s return pred_x0 - def sample(self, *args, **kwargs): + def sample( + self, *args, **kwargs + ): # Ethan - Not sure how to print type of these and not sure about generator class + # breakpoint() if not kwargs.get("yield_output", False): output = self._sample(*args, **kwargs) return next(output) + # breakpoint() return self._sample(*args, **kwargs) def _sample( self, - model, - x_t, - lm_outputs, - lm_mask, - micros, - return_sequence=False, - use_beta_tilde=False, - t=-1, - num_inference_steps=2000, - ddim_eta=None, - guidance_scale=1, - resample_steps=False, - disable_bar=True, - yield_output=False, + model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import + x_t: torch.Tensor, + lm_outputs: torch.Tensor, + lm_mask: torch.Tensor, + micros: dict, + return_sequence: bool = False, + use_beta_tilde: bool = False, + t: int = -1, + num_inference_steps: int = 2000, + ddim_eta: int = None, + guidance_scale: float = 1, + resample_steps: bool = False, + disable_bar: bool = True, + yield_output: bool = False, **post_args, - ): + ): # Ethan - Generator and **post_args """ Starting with x_t, at time step t, perform diffusion to first step. """ @@ -583,18 +583,20 @@ def _sample( if return_sequence: seq[-1] = torch.clip(seq[-1], -1, 1) + breakpoint() yield seq else: + breakpoint() yield self._postprocess(x_t, x0, extra, clip=True, **post_args) def _postprocess( self, - x_t, - x0=None, - extra=None, - yield_full=False, - clip=False, - image_scale=None, + x_t: torch.Tensor, + x0: torch.Tensor = None, + extra: tuple = None, + yield_full: bool = False, + clip: bool = False, + image_scale=None, # Ethan-Ask about NoneTypes where None is default **unused, ): if image_scale is None: @@ -606,9 +608,10 @@ def _postprocess( x_t = torch.clip(x_t, -1, 1) if yield_full: return (x0, x_t, extra) + breakpoint() return x_t - def set_timesteps(self, num_inference_steps=250): + def set_timesteps(self, num_inference_steps: int = 250) -> np.ndarray: step_ratio = (self._config.num_diffusion_steps + 1) / (num_inference_steps + 1) timesteps = ( (np.arange(0, num_inference_steps + 1) * step_ratio) @@ -621,6 +624,7 @@ def set_timesteps(self, num_inference_steps=250): class NestedSampler(Sampler): def get_gammas(self, gamma, scales, images=None): + breakpoint() if not self._config.schedule_shifted: gammas = [gamma for _ in scales] else: @@ -630,9 +634,11 @@ def get_gammas(self, gamma, scales, images=None): F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(gammas, images) ] + breakpoint() return gammas def get_xt(self, x0, eps, g, scales): + breakpoint() x_t = [] for x, s, e, gi in zip(x0, scales, eps, g): x_t += [ @@ -644,9 +650,11 @@ def get_xt(self, x0, eps, g, scales): gi, ) ] + breakpoint() return x_t def get_prediction_targets(self, x0, eps, g, g_last, scales, prediction_type=None): + breakpoint() tgt = [] for x, s, e, gi, gil in zip(x0, scales, eps, g, g_last): tgt += [ @@ -660,6 +668,7 @@ def get_prediction_targets(self, x0, eps, g, g_last, scales, prediction_type=Non prediction_type, ) ] + breakpoint() return tgt def get_xt_minus_1( @@ -675,6 +684,7 @@ def get_xt_minus_1( ddim_eta=None, return_details=False, ): + breakpoint() scales = model.vision_model.nest_ratio + [1] if isinstance(x_t, torch.Tensor): out = [x_t] @@ -732,6 +742,7 @@ def _postprocess( output_inner=False, **unused, ): + breakpoint() scales = [ x_t[i].size(-1) / x_t[-1].size(-1) if not self._config.schedule_shifted @@ -752,6 +763,7 @@ def cat(x, size): # nx = x.new_ones(x.size(0), 3, size, size) # nx[..., :x.size(-2), :x.size(-1)] = x nx = F.interpolate(x, size, mode="bilinear") + breakpoint() return nx # output inner loop results. @@ -779,13 +791,18 @@ def cat(x, size): x0 = torch.cat([cat(xi, size=x0[0].size(-1)) for xi in x0[::-1]], -1) xt = torch.cat([cat(xi, size=xt[0].size(-1)) for xi in xt[::-1]], -1) out = (x0, xt, extra[-1]) + breakpoint() return out def forward_model( self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1 ): + breakpoint() + def cfg(pred): + breakpoint() pred_uncond, pred = pred.chunk(2) + breakpoint() return pred_uncond + guidance_scale * (pred - pred_uncond) if guidance_scale != 1: @@ -800,4 +817,5 @@ def cfg(pred): p_t = [cfg(p) for p in p_t] else: p_t = model(x_t, t, lm_outputs, lm_mask, micros) + breakpoint() return p_t From e0c1c8c6027d93c7d762cc45bf22aeb145ef409a Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:22:49 -0500 Subject: [PATCH 48/64] Took out breakpoints and added more typehinting --- ml_mdm/samplers.py | 39 +++------------------------------------ 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index 32438ca..9946cd5 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -226,7 +226,6 @@ def get_noise_schedule( self.register_buffer("vdm_loss_weights", weights) def get_eps_time(self, images, time=None): - breakpoint() # Breakpoint 1 batch_size = images.shape[0] if time is None: time = torch.randint(0, self.n_steps, (batch_size,), device=images.device) @@ -235,22 +234,17 @@ def get_eps_time(self, images, time=None): g, g_last = self.read_gamma(time + 1, images), self.read_gamma(time, images) weights = self.vdm_loss_weights[time + 1] eps = torch.randn_like(images) - breakpoint() # Breakpoint 2 return eps, g, g_last, weights, time def get_xt(self, images, eps, g): - breakpoint() # Breakpoint 3 x_t = g.sqrt() * images + (1 - g).sqrt() * eps - breakpoint() # Breakpoint 4 return x_t def get_image_rescaled(self, images, scale_factor=None): - breakpoint() # Breakpoint 5 if scale_factor is None: scale_factor = self._config.rescale_signal if scale_factor: # divide the signal images = images / scale_factor - breakpoint() # Breakpoint 6 return images def get_schedule_shifted( @@ -265,7 +259,6 @@ def get_schedule_shifted( return gammas def get_prediction_targets(self, images, eps, g, g_last, prediction_type=None): - breakpoint() # Breakpoint 7 if prediction_type is None: prediction_type = self._config.loss_target_type @@ -278,7 +271,6 @@ def get_prediction_targets(self, images, eps, g, g_last, prediction_type=None): pred = g.sqrt() * eps - (1 - g).sqrt() * images else: raise Exception("Unsupported type") - breakpoint() # Breakpoint 8 return pred def get_prediction_xt_last( @@ -375,11 +367,9 @@ def get_x0_eps_from_pred( if not return_eps: return x0 eps = (x_t - x0 * g.sqrt()) / (1 - g).sqrt() - breakpoint() # Breakpoint 12 return x0, eps def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): - breakpoint() # Breakpoint 13 batch_size = x_t.size(0) if prediction_type is None: prediction_type = self._config.prediction_type @@ -392,12 +382,11 @@ def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): pred = (g.sqrt() * x_t - x0) / (1 - g).sqrt() else: raise Exception("prediction type not set to a correct value") - breakpoint() # Breakpoint 14 return pred def get_xt_minus_1( self, - model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import + model, # Ethan-This is ml_mdm.diffusion.Model but importing diffusion is a circular import time_step: torch.Tensor, x_t: torch.Tensor, lm_outputs: torch.Tensor, @@ -513,14 +502,10 @@ def clip_sample(self, pred_x0: torch.Tensor, image_scale: int = 1) -> torch.Tens return self._threshold_sample(pred_x0 * s, 0.95, 1.5) / s return pred_x0 - def sample( - self, *args, **kwargs - ): # Ethan - Not sure how to print type of these and not sure about generator class - # breakpoint() + def sample(self, *args, **kwargs): if not kwargs.get("yield_output", False): output = self._sample(*args, **kwargs) return next(output) - # breakpoint() return self._sample(*args, **kwargs) def _sample( @@ -583,10 +568,8 @@ def _sample( if return_sequence: seq[-1] = torch.clip(seq[-1], -1, 1) - breakpoint() yield seq else: - breakpoint() yield self._postprocess(x_t, x0, extra, clip=True, **post_args) def _postprocess( @@ -596,7 +579,7 @@ def _postprocess( extra: tuple = None, yield_full: bool = False, clip: bool = False, - image_scale=None, # Ethan-Ask about NoneTypes where None is default + image_scale=None, # Ethan-NoneType **unused, ): if image_scale is None: @@ -608,7 +591,6 @@ def _postprocess( x_t = torch.clip(x_t, -1, 1) if yield_full: return (x0, x_t, extra) - breakpoint() return x_t def set_timesteps(self, num_inference_steps: int = 250) -> np.ndarray: @@ -624,7 +606,6 @@ def set_timesteps(self, num_inference_steps: int = 250) -> np.ndarray: class NestedSampler(Sampler): def get_gammas(self, gamma, scales, images=None): - breakpoint() if not self._config.schedule_shifted: gammas = [gamma for _ in scales] else: @@ -634,11 +615,9 @@ def get_gammas(self, gamma, scales, images=None): F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(gammas, images) ] - breakpoint() return gammas def get_xt(self, x0, eps, g, scales): - breakpoint() x_t = [] for x, s, e, gi in zip(x0, scales, eps, g): x_t += [ @@ -650,11 +629,9 @@ def get_xt(self, x0, eps, g, scales): gi, ) ] - breakpoint() return x_t def get_prediction_targets(self, x0, eps, g, g_last, scales, prediction_type=None): - breakpoint() tgt = [] for x, s, e, gi, gil in zip(x0, scales, eps, g, g_last): tgt += [ @@ -668,7 +645,6 @@ def get_prediction_targets(self, x0, eps, g, g_last, scales, prediction_type=Non prediction_type, ) ] - breakpoint() return tgt def get_xt_minus_1( @@ -684,7 +660,6 @@ def get_xt_minus_1( ddim_eta=None, return_details=False, ): - breakpoint() scales = model.vision_model.nest_ratio + [1] if isinstance(x_t, torch.Tensor): out = [x_t] @@ -742,7 +717,6 @@ def _postprocess( output_inner=False, **unused, ): - breakpoint() scales = [ x_t[i].size(-1) / x_t[-1].size(-1) if not self._config.schedule_shifted @@ -763,7 +737,6 @@ def cat(x, size): # nx = x.new_ones(x.size(0), 3, size, size) # nx[..., :x.size(-2), :x.size(-1)] = x nx = F.interpolate(x, size, mode="bilinear") - breakpoint() return nx # output inner loop results. @@ -791,18 +764,13 @@ def cat(x, size): x0 = torch.cat([cat(xi, size=x0[0].size(-1)) for xi in x0[::-1]], -1) xt = torch.cat([cat(xi, size=xt[0].size(-1)) for xi in xt[::-1]], -1) out = (x0, xt, extra[-1]) - breakpoint() return out def forward_model( self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1 ): - breakpoint() - def cfg(pred): - breakpoint() pred_uncond, pred = pred.chunk(2) - breakpoint() return pred_uncond + guidance_scale * (pred - pred_uncond) if guidance_scale != 1: @@ -817,5 +785,4 @@ def cfg(pred): p_t = [cfg(p) for p in p_t] else: p_t = model(x_t, t, lm_outputs, lm_mask, micros) - breakpoint() return p_t From 5d57c51ea08e0aab6bba46eb6bb879182b906fbc Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:30:06 -0500 Subject: [PATCH 49/64] Changes to trainer.py to match remote repository --- ml_mdm/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index ddc7efe..de17d4a 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -1,12 +1,11 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. from argparse import Namespace -from typing import Optional +from typing import Optional, Tuple import numpy as np import torch import torch.nn as nn -from torch.utils.tensorboard import SummaryWriter def train_batch( @@ -22,7 +21,6 @@ def train_batch( ema_model: Optional[nn.Module] = None, loss_factor: float = 1.0, ): - breakpoint() model.train() lr = scheduler.get_last_lr()[0] # Updates the scale for next iteration From b9d7d32740d6137c14097a65665320bd46834f00 Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Tue, 10 Dec 2024 13:33:23 -0500 Subject: [PATCH 50/64] Added one import for type hinting --- ml_mdm/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ml_mdm/trainer.py b/ml_mdm/trainer.py index de17d4a..c454847 100644 --- a/ml_mdm/trainer.py +++ b/ml_mdm/trainer.py @@ -1,11 +1,14 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. + + from argparse import Namespace -from typing import Optional, Tuple +from typing import Optional import numpy as np import torch import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter def train_batch( From 4ea65548633aaf65779254fb13a3d65345927dcc Mon Sep 17 00:00:00 2001 From: ethanernst11 <146121019+ethanernst11@users.noreply.github.com> Date: Sun, 15 Dec 2024 16:28:42 -0500 Subject: [PATCH 51/64] Changed comments to TODO and added two type hints --- ml_mdm/samplers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index 9946cd5..034f632 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -4,7 +4,7 @@ import math from dataclasses import dataclass, field from enum import Enum -from typing import Tuple +from typing import Callable, Tuple from einops import repeat from tqdm import tqdm @@ -280,7 +280,7 @@ def get_prediction_xt_last( g: torch.Tensor, g_last: torch.Tensor, prediction_type: PredictionType = None, - clip_fn=None, # Said type class 'method' but could not find what that was + clip_fn: Callable = None, need_noise: torch.Tensor = False, ddim_eta: int = None, input_noise=None, @@ -386,7 +386,7 @@ def get_pred_from_x0_xt(self, x_t, x0, g, prediction_type=None): def get_xt_minus_1( self, - model, # Ethan-This is ml_mdm.diffusion.Model but importing diffusion is a circular import + model, # TODO - This is ml_mdm.diffusion.Model but importing diffusion is a circular import time_step: torch.Tensor, x_t: torch.Tensor, lm_outputs: torch.Tensor, @@ -429,7 +429,7 @@ def get_xt_minus_1( def forward_model( self, - model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import + model, # TODO - This is ml_mdm.diffusion.Model but to import diffusion it is a circular import x_t: torch.Tensor, t: torch.Tensor, lm_outputs: torch.Tensor, @@ -510,7 +510,7 @@ def sample(self, *args, **kwargs): def _sample( self, - model, # Ethan-This is ml_mdm.diffusion.Model but to import diffusion it is a circular import + model, # TODO - This is ml_mdm.diffusion.Model but to import diffusion it is a circular import x_t: torch.Tensor, lm_outputs: torch.Tensor, lm_mask: torch.Tensor, @@ -525,7 +525,7 @@ def _sample( disable_bar: bool = True, yield_output: bool = False, **post_args, - ): # Ethan - Generator and **post_args + ): """ Starting with x_t, at time step t, perform diffusion to first step. """ @@ -579,7 +579,7 @@ def _postprocess( extra: tuple = None, yield_full: bool = False, clip: bool = False, - image_scale=None, # Ethan-NoneType + image_scale: float = None, **unused, ): if image_scale is None: From ffbdfa7e9371c7dd44d58940d2665e3db02d094d Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Sun, 5 Jan 2025 15:41:18 -0300 Subject: [PATCH 52/64] typing NestedSampler, modelEMA, nestedUNET, UNET and Diffusion --- ml_mdm/diffusion.py | 6 +++--- ml_mdm/models/model_ema.py | 6 +++--- ml_mdm/models/nested_unet.py | 2 +- ml_mdm/models/unet.py | 8 ++++---- ml_mdm/samplers.py | 12 ++++++------ 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index f44ae71..7d2f906 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -3,7 +3,7 @@ """ Basic UNet-DDPM pipeline. """ import logging from dataclasses import dataclass, field -from typing import List +from typing import List, Tuple from einops import rearrange @@ -80,7 +80,7 @@ def forward( lm_outputs: torch.Tensor, lm_mask: torch.Tensor, micros: {}, - ) -> (torch.Tensor, torch.Tensor): + ) -> Tuple(torch.Tensor, torch.Tensor): outputs = self.vision_model(x_t, times, lm_outputs, lm_mask, micros) if self._output_scale != 0: outputs = torch.tanh(outputs / self._output_scale) * self._output_scale @@ -116,7 +116,7 @@ def eval(self): self.model.eval() self.sampler.eval() - def get_xt_minus_1(self, t, x_t, lm_outputs, lm_mask): + def get_xt_minus_1(self, t, x_t, lm_outputs: torch.Tensor, lm_mask: torch.Tensor): self.eval() return self.sampler.get_xt_minus_1(t, x_t, lm_outputs, lm_mask) diff --git a/ml_mdm/models/model_ema.py b/ml_mdm/models/model_ema.py index b2c0f83..c916633 100644 --- a/ml_mdm/models/model_ema.py +++ b/ml_mdm/models/model_ema.py @@ -10,7 +10,7 @@ class ModelEma(nn.Module): - def __init__(self, model, decay=0.9999, warmup_steps=0, device=None): + def __init__(self, model, decay: float=0.9999, warmup_steps: int = 0, device: torch.device =None): super(ModelEma, self).__init__() # make a copy of the model for accumulating moving average of weights self.module = deepcopy(model) @@ -33,7 +33,7 @@ def update(self, model): model_v = model_v.to(device=self.device) ema_v.mul_(decay).add_(model_v, alpha=(1.0 - decay)) - def save(self, fname, other_items=None): + def save(self, fname: str, other_items=None): logging.info(f"Saving EMA model file: {fname}") checkpoint = {"state_dict": self.module.state_dict()} if other_items is not None: @@ -41,7 +41,7 @@ def save(self, fname, other_items=None): checkpoint[k] = v torch.save(checkpoint, fname) - def load(self, fname): + def load(self, fname: str): logging.info(f"Loading EMA model file: {fname}") fix_old_checkpoints.mimic_old_modules() checkpoint = torch.load(fname, map_location=lambda storage, loc: storage) diff --git a/ml_mdm/models/nested_unet.py b/ml_mdm/models/nested_unet.py index b87c20c..3b170fa 100644 --- a/ml_mdm/models/nested_unet.py +++ b/ml_mdm/models/nested_unet.py @@ -75,7 +75,7 @@ class Nested4UNetConfig(Nested3UNetConfig): ) -def download(vision_model_path): +def download(vision_model_path: str): import os from distributed import get_local_rank diff --git a/ml_mdm/models/unet.py b/ml_mdm/models/unet.py index 43a8506..2d5ffd1 100644 --- a/ml_mdm/models/unet.py +++ b/ml_mdm/models/unet.py @@ -578,7 +578,7 @@ def forward( @config.register_model("unet") class UNet(nn.Module): - def __init__(self, input_channels, output_channels, config: UNetConfig): + def __init__(self, input_channels: int, output_channels: int, config: UNetConfig): super().__init__() self.down_blocks = [] self.config = config @@ -776,7 +776,7 @@ def __init__(self, input_channels, output_channels, config: UNetConfig): def model_type(self): return "unet" - def print_size(self, target_image_size=64): + def print_size(self, target_image_size: int =64): summary( self, [ @@ -791,7 +791,7 @@ def print_size(self, target_image_size=64): depth=4, ) - def save(self, fname, other_items=None): + def save(self, fname: str, other_items=None): logging.info(f"Saving model file: {fname}") checkpoint = {"state_dict": self.state_dict()} if other_items is not None: @@ -799,7 +799,7 @@ def save(self, fname, other_items=None): checkpoint[k] = v torch.save(checkpoint, fname) - def load(self, fname): + def load(self, fname: str): logging.info(f"Loading model file: {fname}") fix_old_checkpoints.mimic_old_modules() # first load to cpu or we will run out of memory. diff --git a/ml_mdm/samplers.py b/ml_mdm/samplers.py index 47ac945..1870814 100644 --- a/ml_mdm/samplers.py +++ b/ml_mdm/samplers.py @@ -657,8 +657,8 @@ def get_xt_minus_1( model, time_step, x_t, - lm_outputs, - lm_mask, + lm_outputs: torch.Tensor, + lm_mask: torch.Tensor, micros={}, time_step_last=None, guidance_scale=1, @@ -717,9 +717,9 @@ def _postprocess( x_t, x0=None, extra=None, - yield_full=False, - clip=False, - output_inner=False, + yield_full: bool=False, + clip: bool=False, + output_inner: bool =False, **unused, ): scales = [ @@ -772,7 +772,7 @@ def cat(x, size): return out def forward_model( - self, model, x_t, t, lm_outputs, lm_mask, micros={}, guidance_scale=1 + self, model, x_t, t, lm_outputs: torch.Tensor, lm_mask: torch.Tensor, micros={}, guidance_scale=1 ): def cfg(pred): pred_uncond, pred = pred.chunk(2) From 0a0698919308bb6080b9c17e7faef9dba8ad38cd Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Sun, 5 Jan 2025 15:41:35 -0300 Subject: [PATCH 53/64] removing tuple import --- ml_mdm/diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index 7d2f906..9a2766d 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -3,7 +3,7 @@ """ Basic UNet-DDPM pipeline. """ import logging from dataclasses import dataclass, field -from typing import List, Tuple +from typing import List from einops import rearrange @@ -80,7 +80,7 @@ def forward( lm_outputs: torch.Tensor, lm_mask: torch.Tensor, micros: {}, - ) -> Tuple(torch.Tensor, torch.Tensor): + ) -> (torch.Tensor, torch.Tensor): outputs = self.vision_model(x_t, times, lm_outputs, lm_mask, micros) if self._output_scale != 0: outputs = torch.tanh(outputs / self._output_scale) * self._output_scale From c6671831e2502fdc3dc12c8db5c4917f37ff9945 Mon Sep 17 00:00:00 2001 From: Luke Carlson Date: Tue, 11 Feb 2025 10:44:15 -0500 Subject: [PATCH 54/64] Feature/namespace package config (#5) * ml_mdm -> ml-mdm-matryoshka * created top level namespace package * update license file * turn off default coverage * rm extra README * cli builder * config logic --- .../dataset_creation/sample_cc12m.yaml | 0 .../configs}/datasets/cc12m.yaml | 0 .../configs}/models/cc12m_1024x1024.yaml | 0 .../configs}/models/cc12m_256x256.yaml | 0 .../configs}/models/cc12m_64x64.yaml | 0 {data => ml-mdm-matryoshka/data}/bert.vocab | 0 {data => ml-mdm-matryoshka/data}/c4_wpm.vocab | 0 .../data}/cifar10.vocab | 0 .../data}/imagenet.vocab | 0 .../data}/prompts_WebImage-ALIGN-64px.tsv | 0 .../data}/prompts_cc12m-256x256.tsv | 0 .../data}/prompts_cc12m-64x64.tsv | 0 .../data}/prompts_cifar10-32x32.tsv | 0 .../data}/prompts_cifar10-64x64.tsv | 0 .../data}/prompts_demo.tsv | 0 .../data}/prompts_imagenet-64px.tsv | 0 {data => ml-mdm-matryoshka/data}/t5.vocab | 0 .../data}/tokenizer_spm_32000_50m.vocab | 0 .../ml_mdm}/clis/__init__.py | 0 .../ml_mdm}/clis/download_tar_from_index.py | 0 .../ml_mdm}/clis/generate_batch.py | 0 .../ml_mdm}/clis/generate_sample.py | 0 .../ml_mdm}/clis/run_torchmetrics.py | 0 .../ml_mdm}/clis/scrape_cc12m.py | 0 .../ml_mdm}/clis/train_parallel.py | 0 .../ml_mdm}/config.py | 0 .../ml_mdm}/diffusion.py | 0 .../ml_mdm}/distributed.py | 0 .../ml_mdm}/generate_html.py | 0 .../ml_mdm}/helpers.py | 0 .../ml_mdm}/language_models/__init__.py | 0 .../ml_mdm}/language_models/factory.py | 0 .../ml_mdm}/language_models/self_attention.py | 0 .../ml_mdm}/language_models/tokenizer.py | 0 .../ml_mdm}/language_models/transformer.py | 0 .../ml_mdm}/lr_scaler.py | 0 .../ml_mdm}/models/__init__.py | 0 .../ml_mdm}/models/model_ema.py | 0 .../ml_mdm}/models/nested_unet.py | 0 .../ml_mdm}/models/unet.py | 0 .../ml_mdm}/reader.py | 0 .../ml_mdm}/s3_helpers.py | 0 .../ml_mdm}/samplers.py | 0 .../ml_mdm}/trainer.py | 0 .../ml_mdm}/utils/__init__.py | 0 .../ml_mdm}/utils/fix_old_checkpoints.py | 0 .../ml_mdm}/utils/simple_logger.py | 0 .../pyproject.toml | 5 ++- .../tests}/test_configs.py | 0 .../tests}/test_files/c12m_10samples.tsv | 0 .../tests}/test_files/images_00000.tar | Bin .../tests}/test_files/images_00000.tsv | 0 .../tests}/test_files/sample_training_0.tsv | 0 .../tests}/test_generate_batch.py | 0 .../tests}/test_generate_sample.py | 0 .../tests}/test_imports.py | 0 .../tests}/test_models.py | 0 .../tests}/test_reader.py | 0 .../tests}/test_tokenizer.py | 0 .../tests}/test_train.py | 0 .../__init__.py => ml-mdm/ml_mdm/__about__.py | 0 ml-mdm/ml_mdm/core.py | 35 ++++++++++++++++++ ml-mdm/pyproject.toml | 12 ++++++ ml-mdm/tests/__init__.py | 2 + 64 files changed, 52 insertions(+), 2 deletions(-) rename {configs => ml-mdm-matryoshka/configs}/dataset_creation/sample_cc12m.yaml (100%) rename {configs => ml-mdm-matryoshka/configs}/datasets/cc12m.yaml (100%) rename {configs => ml-mdm-matryoshka/configs}/models/cc12m_1024x1024.yaml (100%) rename {configs => ml-mdm-matryoshka/configs}/models/cc12m_256x256.yaml (100%) rename {configs => ml-mdm-matryoshka/configs}/models/cc12m_64x64.yaml (100%) rename {data => ml-mdm-matryoshka/data}/bert.vocab (100%) rename {data => ml-mdm-matryoshka/data}/c4_wpm.vocab (100%) rename {data => ml-mdm-matryoshka/data}/cifar10.vocab (100%) rename {data => ml-mdm-matryoshka/data}/imagenet.vocab (100%) rename {data => ml-mdm-matryoshka/data}/prompts_WebImage-ALIGN-64px.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_cc12m-256x256.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_cc12m-64x64.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_cifar10-32x32.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_cifar10-64x64.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_demo.tsv (100%) rename {data => ml-mdm-matryoshka/data}/prompts_imagenet-64px.tsv (100%) rename {data => ml-mdm-matryoshka/data}/t5.vocab (100%) rename {data => ml-mdm-matryoshka/data}/tokenizer_spm_32000_50m.vocab (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/__init__.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/download_tar_from_index.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/generate_batch.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/generate_sample.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/run_torchmetrics.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/scrape_cc12m.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/clis/train_parallel.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/config.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/diffusion.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/distributed.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/generate_html.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/helpers.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/language_models/__init__.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/language_models/factory.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/language_models/self_attention.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/language_models/tokenizer.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/language_models/transformer.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/lr_scaler.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/models/__init__.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/models/model_ema.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/models/nested_unet.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/models/unet.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/reader.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/s3_helpers.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/samplers.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/trainer.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/utils/__init__.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/utils/fix_old_checkpoints.py (100%) rename {ml_mdm => ml-mdm-matryoshka/ml_mdm}/utils/simple_logger.py (100%) rename pyproject.toml => ml-mdm-matryoshka/pyproject.toml (95%) rename {tests => ml-mdm-matryoshka/tests}/test_configs.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_files/c12m_10samples.tsv (100%) rename {tests => ml-mdm-matryoshka/tests}/test_files/images_00000.tar (100%) rename {tests => ml-mdm-matryoshka/tests}/test_files/images_00000.tsv (100%) rename {tests => ml-mdm-matryoshka/tests}/test_files/sample_training_0.tsv (100%) rename {tests => ml-mdm-matryoshka/tests}/test_generate_batch.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_generate_sample.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_imports.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_models.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_reader.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_tokenizer.py (100%) rename {tests => ml-mdm-matryoshka/tests}/test_train.py (100%) rename ml_mdm/__init__.py => ml-mdm/ml_mdm/__about__.py (100%) create mode 100644 ml-mdm/ml_mdm/core.py create mode 100644 ml-mdm/pyproject.toml create mode 100644 ml-mdm/tests/__init__.py diff --git a/configs/dataset_creation/sample_cc12m.yaml b/ml-mdm-matryoshka/configs/dataset_creation/sample_cc12m.yaml similarity index 100% rename from configs/dataset_creation/sample_cc12m.yaml rename to ml-mdm-matryoshka/configs/dataset_creation/sample_cc12m.yaml diff --git a/configs/datasets/cc12m.yaml b/ml-mdm-matryoshka/configs/datasets/cc12m.yaml similarity index 100% rename from configs/datasets/cc12m.yaml rename to ml-mdm-matryoshka/configs/datasets/cc12m.yaml diff --git a/configs/models/cc12m_1024x1024.yaml b/ml-mdm-matryoshka/configs/models/cc12m_1024x1024.yaml similarity index 100% rename from configs/models/cc12m_1024x1024.yaml rename to ml-mdm-matryoshka/configs/models/cc12m_1024x1024.yaml diff --git a/configs/models/cc12m_256x256.yaml b/ml-mdm-matryoshka/configs/models/cc12m_256x256.yaml similarity index 100% rename from configs/models/cc12m_256x256.yaml rename to ml-mdm-matryoshka/configs/models/cc12m_256x256.yaml diff --git a/configs/models/cc12m_64x64.yaml b/ml-mdm-matryoshka/configs/models/cc12m_64x64.yaml similarity index 100% rename from configs/models/cc12m_64x64.yaml rename to ml-mdm-matryoshka/configs/models/cc12m_64x64.yaml diff --git a/data/bert.vocab b/ml-mdm-matryoshka/data/bert.vocab similarity index 100% rename from data/bert.vocab rename to ml-mdm-matryoshka/data/bert.vocab diff --git a/data/c4_wpm.vocab b/ml-mdm-matryoshka/data/c4_wpm.vocab similarity index 100% rename from data/c4_wpm.vocab rename to ml-mdm-matryoshka/data/c4_wpm.vocab diff --git a/data/cifar10.vocab b/ml-mdm-matryoshka/data/cifar10.vocab similarity index 100% rename from data/cifar10.vocab rename to ml-mdm-matryoshka/data/cifar10.vocab diff --git a/data/imagenet.vocab b/ml-mdm-matryoshka/data/imagenet.vocab similarity index 100% rename from data/imagenet.vocab rename to ml-mdm-matryoshka/data/imagenet.vocab diff --git a/data/prompts_WebImage-ALIGN-64px.tsv b/ml-mdm-matryoshka/data/prompts_WebImage-ALIGN-64px.tsv similarity index 100% rename from data/prompts_WebImage-ALIGN-64px.tsv rename to ml-mdm-matryoshka/data/prompts_WebImage-ALIGN-64px.tsv diff --git a/data/prompts_cc12m-256x256.tsv b/ml-mdm-matryoshka/data/prompts_cc12m-256x256.tsv similarity index 100% rename from data/prompts_cc12m-256x256.tsv rename to ml-mdm-matryoshka/data/prompts_cc12m-256x256.tsv diff --git a/data/prompts_cc12m-64x64.tsv b/ml-mdm-matryoshka/data/prompts_cc12m-64x64.tsv similarity index 100% rename from data/prompts_cc12m-64x64.tsv rename to ml-mdm-matryoshka/data/prompts_cc12m-64x64.tsv diff --git a/data/prompts_cifar10-32x32.tsv b/ml-mdm-matryoshka/data/prompts_cifar10-32x32.tsv similarity index 100% rename from data/prompts_cifar10-32x32.tsv rename to ml-mdm-matryoshka/data/prompts_cifar10-32x32.tsv diff --git a/data/prompts_cifar10-64x64.tsv b/ml-mdm-matryoshka/data/prompts_cifar10-64x64.tsv similarity index 100% rename from data/prompts_cifar10-64x64.tsv rename to ml-mdm-matryoshka/data/prompts_cifar10-64x64.tsv diff --git a/data/prompts_demo.tsv b/ml-mdm-matryoshka/data/prompts_demo.tsv similarity index 100% rename from data/prompts_demo.tsv rename to ml-mdm-matryoshka/data/prompts_demo.tsv diff --git a/data/prompts_imagenet-64px.tsv b/ml-mdm-matryoshka/data/prompts_imagenet-64px.tsv similarity index 100% rename from data/prompts_imagenet-64px.tsv rename to ml-mdm-matryoshka/data/prompts_imagenet-64px.tsv diff --git a/data/t5.vocab b/ml-mdm-matryoshka/data/t5.vocab similarity index 100% rename from data/t5.vocab rename to ml-mdm-matryoshka/data/t5.vocab diff --git a/data/tokenizer_spm_32000_50m.vocab b/ml-mdm-matryoshka/data/tokenizer_spm_32000_50m.vocab similarity index 100% rename from data/tokenizer_spm_32000_50m.vocab rename to ml-mdm-matryoshka/data/tokenizer_spm_32000_50m.vocab diff --git a/ml_mdm/clis/__init__.py b/ml-mdm-matryoshka/ml_mdm/clis/__init__.py similarity index 100% rename from ml_mdm/clis/__init__.py rename to ml-mdm-matryoshka/ml_mdm/clis/__init__.py diff --git a/ml_mdm/clis/download_tar_from_index.py b/ml-mdm-matryoshka/ml_mdm/clis/download_tar_from_index.py similarity index 100% rename from ml_mdm/clis/download_tar_from_index.py rename to ml-mdm-matryoshka/ml_mdm/clis/download_tar_from_index.py diff --git a/ml_mdm/clis/generate_batch.py b/ml-mdm-matryoshka/ml_mdm/clis/generate_batch.py similarity index 100% rename from ml_mdm/clis/generate_batch.py rename to ml-mdm-matryoshka/ml_mdm/clis/generate_batch.py diff --git a/ml_mdm/clis/generate_sample.py b/ml-mdm-matryoshka/ml_mdm/clis/generate_sample.py similarity index 100% rename from ml_mdm/clis/generate_sample.py rename to ml-mdm-matryoshka/ml_mdm/clis/generate_sample.py diff --git a/ml_mdm/clis/run_torchmetrics.py b/ml-mdm-matryoshka/ml_mdm/clis/run_torchmetrics.py similarity index 100% rename from ml_mdm/clis/run_torchmetrics.py rename to ml-mdm-matryoshka/ml_mdm/clis/run_torchmetrics.py diff --git a/ml_mdm/clis/scrape_cc12m.py b/ml-mdm-matryoshka/ml_mdm/clis/scrape_cc12m.py similarity index 100% rename from ml_mdm/clis/scrape_cc12m.py rename to ml-mdm-matryoshka/ml_mdm/clis/scrape_cc12m.py diff --git a/ml_mdm/clis/train_parallel.py b/ml-mdm-matryoshka/ml_mdm/clis/train_parallel.py similarity index 100% rename from ml_mdm/clis/train_parallel.py rename to ml-mdm-matryoshka/ml_mdm/clis/train_parallel.py diff --git a/ml_mdm/config.py b/ml-mdm-matryoshka/ml_mdm/config.py similarity index 100% rename from ml_mdm/config.py rename to ml-mdm-matryoshka/ml_mdm/config.py diff --git a/ml_mdm/diffusion.py b/ml-mdm-matryoshka/ml_mdm/diffusion.py similarity index 100% rename from ml_mdm/diffusion.py rename to ml-mdm-matryoshka/ml_mdm/diffusion.py diff --git a/ml_mdm/distributed.py b/ml-mdm-matryoshka/ml_mdm/distributed.py similarity index 100% rename from ml_mdm/distributed.py rename to ml-mdm-matryoshka/ml_mdm/distributed.py diff --git a/ml_mdm/generate_html.py b/ml-mdm-matryoshka/ml_mdm/generate_html.py similarity index 100% rename from ml_mdm/generate_html.py rename to ml-mdm-matryoshka/ml_mdm/generate_html.py diff --git a/ml_mdm/helpers.py b/ml-mdm-matryoshka/ml_mdm/helpers.py similarity index 100% rename from ml_mdm/helpers.py rename to ml-mdm-matryoshka/ml_mdm/helpers.py diff --git a/ml_mdm/language_models/__init__.py b/ml-mdm-matryoshka/ml_mdm/language_models/__init__.py similarity index 100% rename from ml_mdm/language_models/__init__.py rename to ml-mdm-matryoshka/ml_mdm/language_models/__init__.py diff --git a/ml_mdm/language_models/factory.py b/ml-mdm-matryoshka/ml_mdm/language_models/factory.py similarity index 100% rename from ml_mdm/language_models/factory.py rename to ml-mdm-matryoshka/ml_mdm/language_models/factory.py diff --git a/ml_mdm/language_models/self_attention.py b/ml-mdm-matryoshka/ml_mdm/language_models/self_attention.py similarity index 100% rename from ml_mdm/language_models/self_attention.py rename to ml-mdm-matryoshka/ml_mdm/language_models/self_attention.py diff --git a/ml_mdm/language_models/tokenizer.py b/ml-mdm-matryoshka/ml_mdm/language_models/tokenizer.py similarity index 100% rename from ml_mdm/language_models/tokenizer.py rename to ml-mdm-matryoshka/ml_mdm/language_models/tokenizer.py diff --git a/ml_mdm/language_models/transformer.py b/ml-mdm-matryoshka/ml_mdm/language_models/transformer.py similarity index 100% rename from ml_mdm/language_models/transformer.py rename to ml-mdm-matryoshka/ml_mdm/language_models/transformer.py diff --git a/ml_mdm/lr_scaler.py b/ml-mdm-matryoshka/ml_mdm/lr_scaler.py similarity index 100% rename from ml_mdm/lr_scaler.py rename to ml-mdm-matryoshka/ml_mdm/lr_scaler.py diff --git a/ml_mdm/models/__init__.py b/ml-mdm-matryoshka/ml_mdm/models/__init__.py similarity index 100% rename from ml_mdm/models/__init__.py rename to ml-mdm-matryoshka/ml_mdm/models/__init__.py diff --git a/ml_mdm/models/model_ema.py b/ml-mdm-matryoshka/ml_mdm/models/model_ema.py similarity index 100% rename from ml_mdm/models/model_ema.py rename to ml-mdm-matryoshka/ml_mdm/models/model_ema.py diff --git a/ml_mdm/models/nested_unet.py b/ml-mdm-matryoshka/ml_mdm/models/nested_unet.py similarity index 100% rename from ml_mdm/models/nested_unet.py rename to ml-mdm-matryoshka/ml_mdm/models/nested_unet.py diff --git a/ml_mdm/models/unet.py b/ml-mdm-matryoshka/ml_mdm/models/unet.py similarity index 100% rename from ml_mdm/models/unet.py rename to ml-mdm-matryoshka/ml_mdm/models/unet.py diff --git a/ml_mdm/reader.py b/ml-mdm-matryoshka/ml_mdm/reader.py similarity index 100% rename from ml_mdm/reader.py rename to ml-mdm-matryoshka/ml_mdm/reader.py diff --git a/ml_mdm/s3_helpers.py b/ml-mdm-matryoshka/ml_mdm/s3_helpers.py similarity index 100% rename from ml_mdm/s3_helpers.py rename to ml-mdm-matryoshka/ml_mdm/s3_helpers.py diff --git a/ml_mdm/samplers.py b/ml-mdm-matryoshka/ml_mdm/samplers.py similarity index 100% rename from ml_mdm/samplers.py rename to ml-mdm-matryoshka/ml_mdm/samplers.py diff --git a/ml_mdm/trainer.py b/ml-mdm-matryoshka/ml_mdm/trainer.py similarity index 100% rename from ml_mdm/trainer.py rename to ml-mdm-matryoshka/ml_mdm/trainer.py diff --git a/ml_mdm/utils/__init__.py b/ml-mdm-matryoshka/ml_mdm/utils/__init__.py similarity index 100% rename from ml_mdm/utils/__init__.py rename to ml-mdm-matryoshka/ml_mdm/utils/__init__.py diff --git a/ml_mdm/utils/fix_old_checkpoints.py b/ml-mdm-matryoshka/ml_mdm/utils/fix_old_checkpoints.py similarity index 100% rename from ml_mdm/utils/fix_old_checkpoints.py rename to ml-mdm-matryoshka/ml_mdm/utils/fix_old_checkpoints.py diff --git a/ml_mdm/utils/simple_logger.py b/ml-mdm-matryoshka/ml_mdm/utils/simple_logger.py similarity index 100% rename from ml_mdm/utils/simple_logger.py rename to ml-mdm-matryoshka/ml_mdm/utils/simple_logger.py diff --git a/pyproject.toml b/ml-mdm-matryoshka/pyproject.toml similarity index 95% rename from pyproject.toml rename to ml-mdm-matryoshka/pyproject.toml index fc6fcf0..d3d16d2 100644 --- a/pyproject.toml +++ b/ml-mdm-matryoshka/pyproject.toml @@ -5,9 +5,10 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = ["."] exclude = ["tests*", "*clis*"] +namespaces = true [project] -name = "ml_mdm" +name = "ml-mdm-matryoshka" authors = [{name = "Apple"}] readme = "README.md" version = "1.0" @@ -72,7 +73,7 @@ sections = ["FUTURE", "STDLIB", "THIRDPARTY", "NUMERIC", "FIRSTPARTY", "LOCALFOL known_numeric = ["torch", "torchvision", "numpy", "jax", "flax", "mlx"] [tool.pytest.ini_options] -addopts = "--cov=ml_mdm -m 'not gpu'" +addopts = " -m 'not gpu'" markers = [ "gpu" # tests that require a gpu ] diff --git a/tests/test_configs.py b/ml-mdm-matryoshka/tests/test_configs.py similarity index 100% rename from tests/test_configs.py rename to ml-mdm-matryoshka/tests/test_configs.py diff --git a/tests/test_files/c12m_10samples.tsv b/ml-mdm-matryoshka/tests/test_files/c12m_10samples.tsv similarity index 100% rename from tests/test_files/c12m_10samples.tsv rename to ml-mdm-matryoshka/tests/test_files/c12m_10samples.tsv diff --git a/tests/test_files/images_00000.tar b/ml-mdm-matryoshka/tests/test_files/images_00000.tar similarity index 100% rename from tests/test_files/images_00000.tar rename to ml-mdm-matryoshka/tests/test_files/images_00000.tar diff --git a/tests/test_files/images_00000.tsv b/ml-mdm-matryoshka/tests/test_files/images_00000.tsv similarity index 100% rename from tests/test_files/images_00000.tsv rename to ml-mdm-matryoshka/tests/test_files/images_00000.tsv diff --git a/tests/test_files/sample_training_0.tsv b/ml-mdm-matryoshka/tests/test_files/sample_training_0.tsv similarity index 100% rename from tests/test_files/sample_training_0.tsv rename to ml-mdm-matryoshka/tests/test_files/sample_training_0.tsv diff --git a/tests/test_generate_batch.py b/ml-mdm-matryoshka/tests/test_generate_batch.py similarity index 100% rename from tests/test_generate_batch.py rename to ml-mdm-matryoshka/tests/test_generate_batch.py diff --git a/tests/test_generate_sample.py b/ml-mdm-matryoshka/tests/test_generate_sample.py similarity index 100% rename from tests/test_generate_sample.py rename to ml-mdm-matryoshka/tests/test_generate_sample.py diff --git a/tests/test_imports.py b/ml-mdm-matryoshka/tests/test_imports.py similarity index 100% rename from tests/test_imports.py rename to ml-mdm-matryoshka/tests/test_imports.py diff --git a/tests/test_models.py b/ml-mdm-matryoshka/tests/test_models.py similarity index 100% rename from tests/test_models.py rename to ml-mdm-matryoshka/tests/test_models.py diff --git a/tests/test_reader.py b/ml-mdm-matryoshka/tests/test_reader.py similarity index 100% rename from tests/test_reader.py rename to ml-mdm-matryoshka/tests/test_reader.py diff --git a/tests/test_tokenizer.py b/ml-mdm-matryoshka/tests/test_tokenizer.py similarity index 100% rename from tests/test_tokenizer.py rename to ml-mdm-matryoshka/tests/test_tokenizer.py diff --git a/tests/test_train.py b/ml-mdm-matryoshka/tests/test_train.py similarity index 100% rename from tests/test_train.py rename to ml-mdm-matryoshka/tests/test_train.py diff --git a/ml_mdm/__init__.py b/ml-mdm/ml_mdm/__about__.py similarity index 100% rename from ml_mdm/__init__.py rename to ml-mdm/ml_mdm/__about__.py diff --git a/ml-mdm/ml_mdm/core.py b/ml-mdm/ml_mdm/core.py new file mode 100644 index 0000000..2d6152d --- /dev/null +++ b/ml-mdm/ml_mdm/core.py @@ -0,0 +1,35 @@ + + +from dataclasses import dataclass, is_dataclass +from simple_parsing.helpers import Serializable +from simple_parsing import parse +from simple_parsing.utils import DataclassT + +from typing import TypeVar + +C = TypeVar('C') + +@dataclass +class MDMConfig(Serializable): + pass + +class ConfigPrinter: + def __init__(self, config : MDMConfig) -> None: + print(config) + +@dataclass +class CLIBuilder(): + class_to_call: type[C] = ConfigPrinter + config_class: type = MDMConfig + default_config : DataclassT = None + + def build_config(self, args: str = None) -> DataclassT: + assert is_dataclass(self.config_class) + cfg: DataclassT = parse( + config_class=self.config_class, add_config_path_arg="config-file", default=self.default_config, args=args + ) + return cfg + + def run(self)-> C: + cfg: DataclassT = self.build_config() + return self.class_to_call(cfg) diff --git a/ml-mdm/pyproject.toml b/ml-mdm/pyproject.toml new file mode 100644 index 0000000..81d18d2 --- /dev/null +++ b/ml-mdm/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = ["setuptools"] + +[project] +dependencies = [] +name = "ml-mdm" +version = "0.1.0" + +[tool.setuptools.packages.find] +namespaces = true +where = ["."] diff --git a/ml-mdm/tests/__init__.py b/ml-mdm/tests/__init__.py new file mode 100644 index 0000000..5c8f054 --- /dev/null +++ b/ml-mdm/tests/__init__.py @@ -0,0 +1,2 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. \ No newline at end of file From 256f576ccf8b4b21c06e9efcc34e39e18e3ff0a5 Mon Sep 17 00:00:00 2001 From: Luke Carlson Date: Wed, 12 Feb 2025 15:00:18 -0500 Subject: [PATCH 55/64] Include test_tokenizer fix --- ml-mdm-matryoshka/tests/test_tokenizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml-mdm-matryoshka/tests/test_tokenizer.py b/ml-mdm-matryoshka/tests/test_tokenizer.py index 49ea485..30b090c 100644 --- a/ml-mdm-matryoshka/tests/test_tokenizer.py +++ b/ml-mdm-matryoshka/tests/test_tokenizer.py @@ -7,15 +7,15 @@ from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py def test_tokenizer_bert(): - f = Path(__file__).parent/"data/bert.vocab" # To solve from relative to absolute import + f = Path(__file__).parent.parent/"data/bert.vocab" # To solve from relative to absolute import assert Tokenizer(f, mode="bert") def test_tokenizer_t5(): - f = Path(__file__).parent/"data/t5.vocab" + f = Path(__file__).parent.parent/"data/t5.vocab" assert Tokenizer(f, mode="tf") def test_tokenizer(): - f = Path(__file__).parent/"data/imagenet.vocab" + f = Path(__file__).parent.parent/"data/imagenet.vocab" assert Tokenizer(f) test_tokenizer_bert() From de88fbce5d18cf251db5dda3c48888ad13853096 Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Sat, 22 Feb 2025 06:16:06 -0300 Subject: [PATCH 56/64] fixed self attention implementation and separated the self attention test into cond. and no-cond.. Also, the coded tests passed --- ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py | 132 ++++++++++++++++++ ml-mdm-matryoshka/tests/test_unet_mlx.py | 145 ++++++++++++++++++++ 2 files changed, 277 insertions(+) create mode 100644 ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py create mode 100644 ml-mdm-matryoshka/tests/test_unet_mlx.py diff --git a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py new file mode 100644 index 0000000..32da433 --- /dev/null +++ b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py @@ -0,0 +1,132 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import math + +import einops.array_api + +import mlx.core as mx +import mlx.nn as nn + + +def zero_module_mlx(module): + """ + Zero out the parameters of an MLX module and return it. + """ + # Create a new parameter dictionary with all parameters replaced by zeros + zeroed_params = { + name: mx.zeros(param.shape, dtype=param.dtype) + for name, param in module.parameters().items() + } + # Update the module's parameters with the zeroed parameters + module.update(zeroed_params) + return module + + +class MLP_MLX(nn.Module): # mlx based nn.Module + def __init__(self, channels, multiplier=4): + super().__init__() + ### use mlx layers + self.main = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, multiplier * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(multiplier * channels, channels)), + ) + + def forward(self, x): + return x + self.main(x) + + +class SelfAttention_MLX(nn.Module): + def __init__( + self, + channels, + num_heads=8, + num_head_channels=-1, + cond_dim=None, + use_attention_ffn=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = nn.GroupNorm(32, channels, pytorch_compatible=True) + self.qkv = nn.Conv2d(channels, channels * 3, 1) + self.cond_dim = cond_dim + if cond_dim is not None and cond_dim > 0: + self.norm_cond = nn.LayerNorm(cond_dim) + self.kv_cond = nn.Linear(cond_dim, channels * 2) + self.proj_out = zero_module_mlx(nn.Conv2d(channels, channels, 1)) + if use_attention_ffn: + self.ffn = nn.Sequential( + nn.GroupNorm(32, channels, pytorch_compatible=True), + nn.Conv2d(channels, 4 * channels, 1), + nn.GELU(), + zero_module_mlx(nn.Conv2d(4 * channels, channels, 1)), + ) + else: + self.ffn = None + + def attention(self, q, k, v, mask=None): + bs, width, length = q.shape + ch = width // self.num_heads + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = mx.einsum( + "bct,bcs->bts", + (q * scale).reshape(bs * self.num_heads, ch, length), + (k * scale).reshape(bs * self.num_heads, ch, -1), + ) # More stable with f16 than dividing afterwards + if mask is not None: + # Reshape mask to match attention shape + # From [bs, seq_len] to [bs * num_heads, 1, seq_len] + expanded_mask = einops.array_api.repeat( + mask[:, None, :], # Add dimension for broadcasting + "b 1 s -> (b h) 1 s", + h=self.num_heads, + ) + # Apply mask + weight = mx.where(expanded_mask, weight, float("-inf")) + + weight = mx.softmax(weight, axis=-1) + + return mx.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.num_heads, ch, -1) + ).reshape(bs, width, length) + + def forward(self, x, cond=None, cond_mask=None): + + x = einops.array_api.rearrange(x, "b c h w -> b h w c") + b, h, w, c = x.shape + + qkv = self.qkv(self.norm(x)) + qkv = einops.array_api.rearrange(qkv, "b h w (three c) -> three b (h w) c", three=3) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn_output = self.attention(q, k, v) + + if self.cond_dim is not None and cond is not None: + kv_cond = self.kv_cond(self.norm_cond(cond)) + kv_cond = einops.array_api.rearrange(kv_cond, "b s (two c) -> two b s c", two=2) + k_cond, v_cond = kv_cond[0], kv_cond[1] + attn_cond = self.attention(q, k_cond, v_cond, cond_mask) + attn_output += attn_cond + + attn_output = einops.array_api.rearrange(attn_output, "b (h w) c -> b h w c", h=h, w=w) + h = self.proj_out(attn_output) + + x = einops.array_api.rearrange(x, "b h w c -> b c h w") + h = einops.array_api.rearrange(h, "b h w c -> b c h w") + x = x + h + + if self.ffn is not None: + x = einops.array_api.rearrange(x, "b c h w -> b h w c") + x = self.ffn(x) + x + x = einops.array_api.rearrange(x, "b h w c -> b c h w") + + return x diff --git a/ml-mdm-matryoshka/tests/test_unet_mlx.py b/ml-mdm-matryoshka/tests/test_unet_mlx.py new file mode 100644 index 0000000..5d3415e --- /dev/null +++ b/ml-mdm-matryoshka/tests/test_unet_mlx.py @@ -0,0 +1,145 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import numpy as np +import torch + +from ml_mdm.models.unet import MLP, SelfAttention +from ml_mdm.models.unet_mlx import MLP_MLX, SelfAttention_MLX + +def test_pytorch_mlp(): + """ + Simple test for our MLP implementations + """ + # Define parameters + channels = 8 # Number of channels + multiplier = 4 # Multiplier for hidden dimensions + + # Create a model instance + pytorch_mlp = MLP(channels=channels, multiplier=multiplier) + mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) + + ## Start by testing pytorch version + + # Set model to evaluation mode + pytorch_mlp.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 8) + input_tensor = torch.randn(2, channels) + + # Pass the input through the model + output = pytorch_mlp(input_tensor) + + # Assertions to validate the output shape and properties + assert output.shape == input_tensor.shape, "Output shape mismatch" + assert torch.allclose( + output, input_tensor, atol=1e-5 + ), "Output should be close to input as the final layer is zero-initialized" + + ## now test mlx version + + # Convert the same input to MLX tensor + mlx_tensor = mx.array(input_tensor.numpy()) + + mlx_mlp.eval() + + mlx_output = mlx_mlp.forward(mlx_tensor) + + assert isinstance(mlx_output, mx.array) + assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" + + # Validate numerical equivalence using numpy + assert np.allclose( + output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch MLP and MLX MLP should match" + + print("Test passed for both PyTorch and MLX MLP!") + + +def test_pytorch_mlx_self_attention(): + """ + Test for feature parity between PyTorch and MLX implementations of SelfAttention. + We'll test both the basic self-attention and conditional attention scenarios. + """ + # Define test parameters + channels = 64 + batch_size = 2 + spatial_size = 8 + cond_dim = 32 + num_heads = 8 + + # ===== 1. Test WITH CONDITIONAL INPUT ===== + # Create models WITH conditional support + pytorch_attn_with_cond = SelfAttention( + channels=channels, + num_heads=num_heads, + cond_dim=cond_dim, # Enable conditioning + use_attention_ffn=True, + ) + mlx_attn_with_cond = SelfAttention_MLX( + channels=channels, + num_heads=num_heads, + cond_dim=cond_dim, + use_attention_ffn=True, + ) + + # Create conditional inputs + cond_seq_len = 4 + pytorch_cond = torch.randn(batch_size, cond_seq_len, cond_dim) + pytorch_cond_mask = torch.ones(batch_size, cond_seq_len) + mlx_cond = mx.array(pytorch_cond.numpy()) + mlx_cond_mask = mx.array(pytorch_cond_mask.numpy()) + + # Run conditional tests + pytorch_input = torch.randn(batch_size, channels, spatial_size, spatial_size) + mlx_input = mx.array(pytorch_input.numpy()) + + # PyTorch conditional forward + pytorch_output_with_cond = pytorch_attn_with_cond( + pytorch_input, cond=pytorch_cond, cond_mask=pytorch_cond_mask + ) + # MLX conditional forward + mlx_output_with_cond = mlx_attn_with_cond.forward( + mlx_input, cond=mlx_cond, cond_mask=mlx_cond_mask + ) + + # ===== 2. Test WITHOUT CONDITIONAL INPUT ===== + # Create NEW models WITHOUT conditional support + pytorch_attn_no_cond = SelfAttention( + channels=channels, + num_heads=num_heads, + cond_dim=None, # Disable conditioning + use_attention_ffn=True, + ) + mlx_attn_no_cond = SelfAttention_MLX( + channels=channels, + num_heads=num_heads, + cond_dim=None, # Disable conditioning + use_attention_ffn=True, + ) + + # Run non-conditional tests + pytorch_output_no_cond = pytorch_attn_no_cond(pytorch_input) + mlx_output_no_cond = mlx_attn_no_cond.forward(mlx_input) + + # ===== Assertions ===== + # Check conditional outputs + assert pytorch_output_with_cond.shape == pytorch_input.shape + assert mlx_output_with_cond.shape == mlx_input.shape + assert np.allclose( + pytorch_output_with_cond.detach().numpy(), + np.array(mlx_output_with_cond), + atol=1e-5, rtol=1e-5 + ) + + # Check non-conditional outputs + assert pytorch_output_no_cond.shape == pytorch_input.shape + assert mlx_output_no_cond.shape == mlx_input.shape + assert np.allclose( + pytorch_output_no_cond.detach().numpy(), + np.array(mlx_output_no_cond), + atol=1e-5, rtol=1e-5 + ) + + print("All tests passed!") \ No newline at end of file From df6d53352b6b7f7b499894307f6c246afd7914b7 Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Sat, 22 Feb 2025 06:21:48 -0300 Subject: [PATCH 57/64] improved tests flag comments --- ml-mdm-matryoshka/tests/test_unet_mlx.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ml-mdm-matryoshka/tests/test_unet_mlx.py b/ml-mdm-matryoshka/tests/test_unet_mlx.py index 5d3415e..a429a95 100644 --- a/ml-mdm-matryoshka/tests/test_unet_mlx.py +++ b/ml-mdm-matryoshka/tests/test_unet_mlx.py @@ -109,13 +109,13 @@ def test_pytorch_mlx_self_attention(): pytorch_attn_no_cond = SelfAttention( channels=channels, num_heads=num_heads, - cond_dim=None, # Disable conditioning + cond_dim=None, use_attention_ffn=True, ) mlx_attn_no_cond = SelfAttention_MLX( channels=channels, num_heads=num_heads, - cond_dim=None, # Disable conditioning + cond_dim=None, use_attention_ffn=True, ) @@ -131,7 +131,7 @@ def test_pytorch_mlx_self_attention(): pytorch_output_with_cond.detach().numpy(), np.array(mlx_output_with_cond), atol=1e-5, rtol=1e-5 - ) + ), "Outputs of PyTorch and MLX attention should match" # Check non-conditional outputs assert pytorch_output_no_cond.shape == pytorch_input.shape @@ -140,6 +140,6 @@ def test_pytorch_mlx_self_attention(): pytorch_output_no_cond.detach().numpy(), np.array(mlx_output_no_cond), atol=1e-5, rtol=1e-5 - ) + ), "Outputs without conditioning should match" - print("All tests passed!") \ No newline at end of file + print("Self-attention test passed for both PyTorch and MLX!") \ No newline at end of file From 753feeff5e94ea245f14cded8db2f4729fd22992 Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Tue, 25 Feb 2025 18:46:18 -0300 Subject: [PATCH 58/64] debug? --- ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py | 3 +-- ml-mdm-matryoshka/tests/test_unet_mlx.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py index 32da433..521a305 100644 --- a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py +++ b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py @@ -8,7 +8,6 @@ import mlx.core as mx import mlx.nn as nn - def zero_module_mlx(module): """ Zero out the parameters of an MLX module and return it. @@ -129,4 +128,4 @@ def forward(self, x, cond=None, cond_mask=None): x = self.ffn(x) + x x = einops.array_api.rearrange(x, "b h w c -> b c h w") - return x + return x \ No newline at end of file diff --git a/ml-mdm-matryoshka/tests/test_unet_mlx.py b/ml-mdm-matryoshka/tests/test_unet_mlx.py index a429a95..7334e66 100644 --- a/ml-mdm-matryoshka/tests/test_unet_mlx.py +++ b/ml-mdm-matryoshka/tests/test_unet_mlx.py @@ -57,6 +57,7 @@ def test_pytorch_mlp(): print("Test passed for both PyTorch and MLX MLP!") + def test_pytorch_mlx_self_attention(): """ Test for feature parity between PyTorch and MLX implementations of SelfAttention. From a3eac3fd7cdfdb43219a87781e1d99c3296d5108 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Mon, 3 Mar 2025 10:45:02 -0500 Subject: [PATCH 59/64] selfAttention1D passes --- ml_mdm/models/unet_mlx.py | 77 +++++++++++++++++++++++++++++++++++++++ tests/test_mlx_unet.py | 40 +++++++++++++++++++- 2 files changed, 115 insertions(+), 2 deletions(-) diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py index 0084ccc..140b3e7 100644 --- a/ml_mdm/models/unet_mlx.py +++ b/ml_mdm/models/unet_mlx.py @@ -1,6 +1,10 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. +import math + +import einops.array_api + import mlx.core as mx import mlx.nn as nn @@ -19,6 +23,79 @@ def zero_module_mlx(module): return module +class SelfAttention1D_MLX(nn.Module): + def __init__( + self, + channels, + num_heads=8, + num_head_channels=-1, + use_attention_ffn=False, + pos_emb=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.norm = nn.LayerNorm(channels) + self.qkv = nn.Linear(channels, channels * 3) + self.proj_out = zero_module_mlx(nn.Linear(channels, channels)) + if use_attention_ffn: + self.ffn = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, 4 * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(4 * channels, channels)), + ) + else: + self.ffn = None + if pos_emb: + from mlx.nn import RoPE + + self.pos_emb = RoPE(dim=channels // self.num_heads) + else: + self.pos_emb = None + + def attention(self, q, k, v, mask=None): + bs, length, width = q.shape + ch = width // self.num_heads + scale = 1 / math.sqrt(math.sqrt(ch)) + q = q.reshape(bs, length, self.num_heads, ch) + k = k.reshape(bs, length, self.num_heads, ch) + if self.pos_emb is not None: + q = self.pos_emb.rotate_queries_or_keys(q.permute(0, 2, 1, 3)).permute( + 0, 2, 1, 3 + ) + k = self.pos_emb.rotate_queries_or_keys(k.permute(0, 2, 1, 3)).permute( + 0, 2, 1, 3 + ) + weight = mx.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + if mask is not None: + mask = mask.view(mask.size(0), 1, 1, mask.size(1)) + weight = weight.masked_fill(mask == 0, float("-inf")) + weight = mx.softmax(weight, axis=-1) + a = mx.einsum("bhts,bshc->bthc", weight, v.reshape(bs, -1, self.num_heads, ch)) + return a.reshape(bs, length, -1) + + def forward(self, x, mask): + # assert (self.cond_dim is not None) == (cond is not None) + qkv = self.qkv(self.norm(x)) + q, k, v = mx.split(qkv, 3, axis=-1) + h = self.attention(q, k, v, mask) + h = self.proj_out(h) + x = x + h + if self.ffn is not None: + x = x + self.ffn(x) + return x + + class MLP_MLX(nn.Module): # mlx based nn.Module def __init__(self, channels, multiplier=4): super().__init__() diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py index a58a60e..75fbf26 100644 --- a/tests/test_mlx_unet.py +++ b/tests/test_mlx_unet.py @@ -5,8 +5,8 @@ import numpy as np import torch -from ml_mdm.models.unet import MLP -from ml_mdm.models.unet_mlx import MLP_MLX +from ml_mdm.models.unet import MLP, SelfAttention1D +from ml_mdm.models.unet_mlx import MLP_MLX, SelfAttention1D_MLX def test_pytorch_mlp(): @@ -56,3 +56,39 @@ def test_pytorch_mlp(): ), "Outputs of PyTorch MLP and MLX MLP should match" print("Test passed for both PyTorch and MLX MLP!") + + +def test_self_attention_1d(): + # Define parameters + channels = 8 + num_heads = 2 + seq_length = 16 + batch_size = 2 + + # Create a model instance + pytorch_attn = SelfAttention1D(channels=channels, num_heads=num_heads) + mlx_attn = SelfAttention1D_MLX(channels=channels, num_heads=num_heads) + + # Set models to evaluation mode + pytorch_attn.eval() + mlx_attn.eval() + + # Create a dummy input tensor + input_tensor = torch.randn(batch_size, seq_length, channels) + + # Pass the input through the PyTorch model + pytorch_output = pytorch_attn(input_tensor, mask=None) + + # Convert the input to MLX format + mlx_input = mx.array(input_tensor.numpy()) + + # Pass the input through the MLX model + mlx_output = mlx_attn.forward(mlx_input, mask=None) + + # Assertions to validate the output shape and properties + assert pytorch_output.shape == mlx_output.shape, "Output shape mismatch" + assert np.allclose( + pytorch_output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch and MLX SelfAttention1D should match" + + print("Test passed for both PyTorch and MLX SelfAttention1D!") From 14d877cf79498bf849c52506e544b35e37ce687e Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Tue, 4 Mar 2025 13:11:42 -0500 Subject: [PATCH 60/64] working through parity issues --- ml_mdm/models/unet_mlx.py | 38 +++++++++++++++++++++++++ tests/test_mlx_unet.py | 60 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py index 140b3e7..578696e 100644 --- a/ml_mdm/models/unet_mlx.py +++ b/ml_mdm/models/unet_mlx.py @@ -96,6 +96,44 @@ def forward(self, x, mask): return x +class TemporalAttentionBlock_MLX(nn.Module): + def __init__( + self, channels, num_heads=8, num_head_channels=-1, down=False, pos_emb=False + ): + super().__init__() + self.attn = SelfAttention1D_MLX( + channels, num_heads, num_head_channels, pos_emb=pos_emb + ) + self.mlp = MLP_MLX(channels, multiplier=4) + self.down = down + if down: + self.down_conv = nn.Conv2d( + channels, channels, kernel_size=3, stride=2, padding=1, bias=True + ) + self.up_conv = nn.Conv2d( + channels, channels, kernel_size=3, stride=1, padding=1, bias=True + ) + + def forward(self, x, temb): + x_ = x + if self.down: + # transformation for mlx format + x = einops.array_api.rearrange(x, "b c h w -> b h w c") + x = self.down_conv(x) + + T, H, W = x.shape[0] // temb.shape[0], x.shape[2], x.shape[3] + x = einops.array_api.rearrange(x, "(b t) c h w -> (b h w) t c", t=T) + x = self.mlp.forward(self.attn.forward(x, None)) + x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) c h w", h=H, w=W) + + if self.down: + x = self.up_conv(nn.Upsample(scale_factor=2, mode="nearest")(x)) + + x = einops.array_api.rearrange(x, "b h w c -> b c h w") + x = x + x_ + return x + + class MLP_MLX(nn.Module): # mlx based nn.Module def __init__(self, channels, multiplier=4): super().__init__() diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py index 75fbf26..1966c09 100644 --- a/tests/test_mlx_unet.py +++ b/tests/test_mlx_unet.py @@ -5,8 +5,12 @@ import numpy as np import torch -from ml_mdm.models.unet import MLP, SelfAttention1D -from ml_mdm.models.unet_mlx import MLP_MLX, SelfAttention1D_MLX +from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock +from ml_mdm.models.unet_mlx import ( + MLP_MLX, + SelfAttention1D_MLX, + TemporalAttentionBlock_MLX, +) def test_pytorch_mlp(): @@ -92,3 +96,55 @@ def test_self_attention_1d(): ), "Outputs of PyTorch and MLX SelfAttention1D should match" print("Test passed for both PyTorch and MLX SelfAttention1D!") + + +def test_pytorch_mlx_temporal_attention_block(): + """ + Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock + """ + # Define parameters + channels = 8 + num_heads = 2 + batch_size = 2 + time_steps = 4 + height = 16 + width = 16 + + # Create model instances + pytorch_block = TemporalAttentionBlock( + channels=channels, num_heads=num_heads, down=True + ) + + mlx_block = TemporalAttentionBlock_MLX( + channels=channels, num_heads=num_heads, down=True + ) + + # Set models to evaluation mode + pytorch_block.eval() + mlx_block.eval() + + # Create dummy input tensors + pytorch_input = torch.randn(batch_size * time_steps, channels, height, width) + pytorch_temb = torch.randn(batch_size, channels) + + # Pass inputs through PyTorch model + pytorch_output = pytorch_block(pytorch_input, pytorch_temb) + + # Convert to MLX format + mlx_input = mx.array(pytorch_input.numpy()) + mlx_temb = mx.array(pytorch_temb.numpy()) + + # Pass inputs through MLX model + mlx_output = mlx_block.forward(mlx_input, mlx_temb) + + # print output tensors for debug + print("pytorch_output tensor: ", pytorch_output) + print("mlx_output tensor: ", mlx_output) + + # Assertions to validate the output + assert pytorch_output.shape == tuple(mlx_output.shape), "Output shape mismatch" + assert np.allclose( + pytorch_output.detach().numpy(), np.array(mlx_output), rtol=1e-1, atol=1e-1 + ), "Outputs of PyTorch and MLX TemporalAttentionBlock should match" + + print("Test passed for both PyTorch and MLX TemporalAttentionBlock!") From 3a8052f56cb73dd3f3957bb788dfde6fda8beb5c Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Thu, 13 Mar 2025 08:41:39 -0300 Subject: [PATCH 61/64] debugging temp attention and self atten tests --- ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py | 150 ++++++++++++++++++++ ml-mdm-matryoshka/tests/test_mlx_unet.py | 150 ++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py create mode 100644 ml-mdm-matryoshka/tests/test_mlx_unet.py diff --git a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py new file mode 100644 index 0000000..639c583 --- /dev/null +++ b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py @@ -0,0 +1,150 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import math + +import einops.array_api + +import mlx.core as mx +import mlx.nn as nn + + +def zero_module_mlx(module): + """ + Zero out the parameters of an MLX module and return it. + """ + # Create a new parameter dictionary with all parameters replaced by zeros + zeroed_params = { + name: mx.zeros(param.shape, dtype=param.dtype) + for name, param in module.parameters().items() + } + # Update the module's parameters with the zeroed parameters + module.update(zeroed_params) + return module + + +class SelfAttention1D_MLX(nn.Module): + def __init__( + self, + channels, + num_heads=8, + num_head_channels=-1, + use_attention_ffn=False, + pos_emb=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + self.norm = nn.LayerNorm(channels) + self.qkv = nn.Linear(channels, channels * 3) + self.proj_out = zero_module_mlx(nn.Linear(channels, channels)) + if use_attention_ffn: + self.ffn = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, 4 * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(4 * channels, channels)), + ) + else: + self.ffn = None + if pos_emb: + from mlx.nn import RoPE + + self.pos_emb = RoPE(dim=channels // self.num_heads) + else: + self.pos_emb = None + + def attention(self, q, k, v, mask=None): + bs, length, width = q.shape + ch = width // self.num_heads + scale = 1 / math.sqrt(math.sqrt(ch)) + q = q.reshape(bs, length, self.num_heads, ch) + k = k.reshape(bs, length, self.num_heads, ch) + if self.pos_emb is not None: + q = self.pos_emb.rotate_queries_or_keys(q.permute(0, 2, 1, 3)).permute( + 0, 2, 1, 3 + ) + k = self.pos_emb.rotate_queries_or_keys(k.permute(0, 2, 1, 3)).permute( + 0, 2, 1, 3 + ) + weight = mx.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + if mask is not None: + mask = mask.view(mask.size(0), 1, 1, mask.size(1)) + weight = weight.masked_fill(mask == 0, float("-inf")) + weight = mx.softmax(weight, axis=-1) + a = mx.einsum("bhts,bshc->bthc", weight, v.reshape(bs, -1, self.num_heads, ch)) + return a.reshape(bs, length, -1) + + def forward(self, x, mask): + # assert (self.cond_dim is not None) == (cond is not None) + qkv = self.qkv(self.norm(x)) + q, k, v = mx.split(qkv, 3, axis=-1) + h = self.attention(q, k, v, mask) + h = self.proj_out(h) + x = x + h + if self.ffn is not None: + x = x + self.ffn(x) + return x + + +class TemporalAttentionBlock_MLX(nn.Module): + def __init__( + self, channels, num_heads=8, num_head_channels=-1, down=False, pos_emb=False + ): + super().__init__() + self.attn = SelfAttention1D_MLX( + channels, num_heads, num_head_channels, pos_emb=pos_emb + ) + self.mlp = MLP_MLX(channels, multiplier=4) + self.down = down + if down: + self.down_conv = nn.Conv2d( + channels, channels, kernel_size=3, stride=2, padding=1, bias=True + ) + self.up_conv = nn.Conv2d( + channels, channels, kernel_size=3, stride=1, padding=1, bias=True + ) + + def forward(self, x, temb): + x_ = x + if self.down: + # transformation for mlx format + x = einops.array_api.rearrange(x, "b c h w -> b h w c") + x = self.down_conv(x) + + T, H, W = x.shape[0] // temb.shape[0], x.shape[2], x.shape[3] + x = einops.array_api.rearrange(x, "(b t) c h w -> (b h w) t c", t=T) + x = self.mlp.forward(self.attn.forward(x, None)) + x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) c h w", h=H, w=W) + + if self.down: + x = self.up_conv(nn.Upsample(scale_factor=2, mode="nearest")(x)) + + x = einops.array_api.rearrange(x, "b h w c -> b c h w") + x = x + x_ + return x + + +class MLP_MLX(nn.Module): # mlx based nn.Module + def __init__(self, channels, multiplier=4): + super().__init__() + ### use mlx layers + self.main = nn.Sequential( + nn.LayerNorm(channels), + nn.Linear(channels, multiplier * channels), + nn.GELU(), + zero_module_mlx(nn.Linear(multiplier * channels, channels)), + ) + + def forward(self, x): + return x + self.main(x) + diff --git a/ml-mdm-matryoshka/tests/test_mlx_unet.py b/ml-mdm-matryoshka/tests/test_mlx_unet.py new file mode 100644 index 0000000..1c96a1e --- /dev/null +++ b/ml-mdm-matryoshka/tests/test_mlx_unet.py @@ -0,0 +1,150 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import mlx.core as mx +import numpy as np +import torch + +from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock +from ml_mdm.models.unet_mlx import ( + MLP_MLX, + SelfAttention1D_MLX, + TemporalAttentionBlock_MLX, +) + + +def test_pytorch_mlp(): + """ + Simple test for our MLP implementations + """ + # Define parameters + channels = 8 # Number of channels + multiplier = 4 # Multiplier for hidden dimensions + + # Create a model instance + pytorch_mlp = MLP(channels=channels, multiplier=multiplier) + mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) + + ## Start by testing pytorch version + + # Set model to evaluation mode + pytorch_mlp.eval() + + # Create a dummy pytorch input tensor (batch size = 2, channels = 8) + input_tensor = torch.randn(2, channels) + + # Pass the input through the model + output = pytorch_mlp(input_tensor) + + # Assertions to validate the output shape and properties + assert output.shape == input_tensor.shape, "Output shape mismatch" + assert torch.allclose( + output, input_tensor, atol=1e-5 + ), "Output should be close to input as the final layer is zero-initialized" + + ## now test mlx version + + # Convert the same input to MLX tensor + mlx_tensor = mx.array(input_tensor.numpy()) + + mlx_mlp.eval() + + mlx_output = mlx_mlp.forward(mlx_tensor) + + assert isinstance(mlx_output, mx.array) + assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" + + # Validate numerical equivalence using numpy + assert np.allclose( + output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch MLP and MLX MLP should match" + + print("Test passed for both PyTorch and MLX MLP!") + + +def test_self_attention_1d(): + # Define parameters + channels = 8 + num_heads = 2 + seq_length = 16 + batch_size = 2 + + # Create a model instance + pytorch_attn = SelfAttention1D(channels=channels, num_heads=num_heads) + mlx_attn = SelfAttention1D_MLX(channels=channels, num_heads=num_heads) + + # Set models to evaluation mode + pytorch_attn.eval() + mlx_attn.eval() + + # Create a dummy input tensor + input_tensor = torch.randn(batch_size, seq_length, channels) + + # Pass the input through the PyTorch model + pytorch_output = pytorch_attn(input_tensor, mask=None) + + # Convert the input to MLX format + mlx_input = mx.array(input_tensor.numpy()) + + # Pass the input through the MLX model + mlx_output = mlx_attn.forward(mlx_input, mask=None) + + # Assertions to validate the output shape and properties + assert pytorch_output.shape == mlx_output.shape, "Output shape mismatch" + assert np.allclose( + pytorch_output.detach().numpy(), np.array(mlx_output), atol=1e-5 + ), "Outputs of PyTorch and MLX SelfAttention1D should match" + + print("Test passed for both PyTorch and MLX SelfAttention1D!") + + +def test_pytorch_mlx_temporal_attention_block(): + """ + Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock + """ + # Define parameters + channels = 8 + num_heads = 2 + batch_size = 2 + time_steps = 4 + height = 16 + width = 16 + + # Create model instances + pytorch_block = TemporalAttentionBlock( + channels=channels, num_heads=num_heads, down=True + ) + + mlx_block = TemporalAttentionBlock_MLX( + channels=channels, num_heads=num_heads, down=True + ) + + # Set models to evaluation mode + pytorch_block.eval() + mlx_block.eval() + + # Create dummy input tensors + pytorch_input = torch.randn(batch_size * time_steps, channels, height, width) + pytorch_temb = torch.randn(batch_size, channels) + + # Pass inputs through PyTorch model + pytorch_output = pytorch_block(pytorch_input, pytorch_temb) + + # Convert to MLX format + mlx_input = mx.array(pytorch_input.numpy()) + mlx_temb = mx.array(pytorch_temb.numpy()) + + # Pass inputs through MLX model + mlx_output = mlx_block.forward(mlx_input, mlx_temb) + + # print output tensors for debug + print("pytorch_output tensor: ", pytorch_output) + print("mlx_output tensor: ", mlx_output) + + # Assertions to validate the output + assert pytorch_output.shape == tuple(mlx_output.shape), "Output shape mismatch" + assert np.allclose( + pytorch_output.detach().numpy(), np.array(mlx_output), rtol=1e-1, atol=1e-1 + ), "Outputs of PyTorch and MLX TemporalAttentionBlock should match" + + print("Test passed for both PyTorch and MLX TemporalAttentionBlock!") \ No newline at end of file From 63d28847d18de5a0c545ecf37b7656a336d2b7bc Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Thu, 20 Mar 2025 17:31:47 -0300 Subject: [PATCH 62/64] temporal attention with psnr comparison: 20dB --- ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py | 21 +++++----- ml-mdm-matryoshka/tests/test_mlx_unet.py | 45 ++++++++++++--------- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py index 639c583..315ffba 100644 --- a/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py +++ b/ml-mdm-matryoshka/ml_mdm/models/unet_mlx.py @@ -117,19 +117,23 @@ def __init__( def forward(self, x, temb): x_ = x if self.down: - # transformation for mlx format x = einops.array_api.rearrange(x, "b c h w -> b h w c") x = self.down_conv(x) + x = einops.array_api.rearrange(x, "b h w c -> b c h w") + x = einops.array_api.rearrange(x, "b c h w -> b h w c") T, H, W = x.shape[0] // temb.shape[0], x.shape[2], x.shape[3] - x = einops.array_api.rearrange(x, "(b t) c h w -> (b h w) t c", t=T) - x = self.mlp.forward(self.attn.forward(x, None)) - x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) c h w", h=H, w=W) + x = einops.array_api.rearrange(x, "(b t) h w c -> (b h w) t c", t=T) + x = self.attn.forward(x, None) + x = self.mlp.forward(x) + x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) h w c", h=H, w=W) + x = einops.array_api.rearrange(x, "b h w c -> b c h w") if self.down: - x = self.up_conv(nn.Upsample(scale_factor=2, mode="nearest")(x)) - - x = einops.array_api.rearrange(x, "b h w c -> b c h w") + x = einops.array_api.rearrange(x, "b c h w -> b h w c") + x = nn.Upsample(scale_factor=2, mode="nearest")(x) + x = self.up_conv(x) + x = einops.array_api.rearrange(x, "b h w c -> b c h w") x = x + x_ return x @@ -146,5 +150,4 @@ def __init__(self, channels, multiplier=4): ) def forward(self, x): - return x + self.main(x) - + return x + self.main(x) \ No newline at end of file diff --git a/ml-mdm-matryoshka/tests/test_mlx_unet.py b/ml-mdm-matryoshka/tests/test_mlx_unet.py index 1c96a1e..e6ff7d8 100644 --- a/ml-mdm-matryoshka/tests/test_mlx_unet.py +++ b/ml-mdm-matryoshka/tests/test_mlx_unet.py @@ -56,7 +56,7 @@ def test_pytorch_mlp(): # Validate numerical equivalence using numpy assert np.allclose( - output.detach().numpy(), np.array(mlx_output), atol=1e-5 + output.detach().numpy(), np.array(mx.stop_gradient(mlx_output)), atol=1e-5 ), "Outputs of PyTorch MLP and MLX MLP should match" print("Test passed for both PyTorch and MLX MLP!") @@ -92,7 +92,7 @@ def test_self_attention_1d(): # Assertions to validate the output shape and properties assert pytorch_output.shape == mlx_output.shape, "Output shape mismatch" assert np.allclose( - pytorch_output.detach().numpy(), np.array(mlx_output), atol=1e-5 + pytorch_output.detach().numpy(), np.array(mx.stop_gradient(mlx_output)), atol=1e-5 ), "Outputs of PyTorch and MLX SelfAttention1D should match" print("Test passed for both PyTorch and MLX SelfAttention1D!") @@ -100,7 +100,7 @@ def test_self_attention_1d(): def test_pytorch_mlx_temporal_attention_block(): """ - Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock + Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock. """ # Define parameters channels = 8 @@ -123,28 +123,37 @@ def test_pytorch_mlx_temporal_attention_block(): pytorch_block.eval() mlx_block.eval() + # Create random arrays with correct shape and dtype + arr_input = np.random.normal(0, 1, (batch_size * time_steps, channels, height, width)).astype(np.float32) + arr_temb = np.random.normal(0, 1, (batch_size, channels)).astype(np.float32) + # Create dummy input tensors - pytorch_input = torch.randn(batch_size * time_steps, channels, height, width) - pytorch_temb = torch.randn(batch_size, channels) + pytorch_input = torch.from_numpy(arr_input) + pytorch_temb = torch.from_numpy(arr_temb) - # Pass inputs through PyTorch model - pytorch_output = pytorch_block(pytorch_input, pytorch_temb) + mlx_input = mx.array(arr_input) + mlx_temb = mx.array(arr_temb) - # Convert to MLX format - mlx_input = mx.array(pytorch_input.numpy()) - mlx_temb = mx.array(pytorch_temb.numpy()) + pytorch_output = pytorch_block(pytorch_input, pytorch_temb) - # Pass inputs through MLX model mlx_output = mlx_block.forward(mlx_input, mlx_temb) - # print output tensors for debug - print("pytorch_output tensor: ", pytorch_output) - print("mlx_output tensor: ", mlx_output) - - # Assertions to validate the output - assert pytorch_output.shape == tuple(mlx_output.shape), "Output shape mismatch" + # Print output tensors for debugging + print("pytorch_output tensor shape: ", pytorch_output.shape) + print("mlx_output tensor shape: ", mlx_output.shape) + print("torch: ", pytorch_output) + print("mlx : ", mlx_output) + print("mean difference: ", np.mean(np.abs(pytorch_output.detach().numpy() - np.array(mx.stop_gradient(mlx_output))))) #0.35 + print("psnr: ", 10 * np.log10(np.max(pytorch_output.detach().numpy())**2 / np.mean((pytorch_output.detach().numpy() - np.array(mx.stop_gradient(mlx_output)))**2))) # 19.2 dB + + assert pytorch_output.shape == tuple(mlx_output.shape), f"Output shape mismatch: {pytorch_output.shape} vs {mlx_output.shape}" + + # Increase tolerance to allow for small discrepancies in floating-point operations assert np.allclose( - pytorch_output.detach().numpy(), np.array(mlx_output), rtol=1e-1, atol=1e-1 + pytorch_output.detach().numpy(), + np.array(mx.stop_gradient(mlx_output)), + rtol=1e-1, # Significantly increased tolerance + atol=1e-1, # Significantly increased tolerance ), "Outputs of PyTorch and MLX TemporalAttentionBlock should match" print("Test passed for both PyTorch and MLX TemporalAttentionBlock!") \ No newline at end of file From 517fd1ff7fec203e5b9543996fb7c5a8d8e3c066 Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Sat, 22 Mar 2025 21:16:00 -0300 Subject: [PATCH 63/64] fixing file name --- ml-mdm-matryoshka/tests/{test_mlx_unet.py => test_unet_mlx.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ml-mdm-matryoshka/tests/{test_mlx_unet.py => test_unet_mlx.py} (100%) diff --git a/ml-mdm-matryoshka/tests/test_mlx_unet.py b/ml-mdm-matryoshka/tests/test_unet_mlx.py similarity index 100% rename from ml-mdm-matryoshka/tests/test_mlx_unet.py rename to ml-mdm-matryoshka/tests/test_unet_mlx.py From fab0d14076e007bd6a6302d5e074eb434ecc0dd2 Mon Sep 17 00:00:00 2001 From: Gabriel Ayres Date: Thu, 27 Mar 2025 18:58:28 -0300 Subject: [PATCH 64/64] fixing dirs structure --- ml_mdm/models/unet_mlx.py | 149 ------------------------------------- tests/test_mlx_unet.py | 150 -------------------------------------- 2 files changed, 299 deletions(-) delete mode 100644 ml_mdm/models/unet_mlx.py delete mode 100644 tests/test_mlx_unet.py diff --git a/ml_mdm/models/unet_mlx.py b/ml_mdm/models/unet_mlx.py deleted file mode 100644 index 578696e..0000000 --- a/ml_mdm/models/unet_mlx.py +++ /dev/null @@ -1,149 +0,0 @@ -# For licensing see accompanying LICENSE file. -# Copyright (C) 2024 Apple Inc. All rights reserved. - -import math - -import einops.array_api - -import mlx.core as mx -import mlx.nn as nn - - -def zero_module_mlx(module): - """ - Zero out the parameters of an MLX module and return it. - """ - # Create a new parameter dictionary with all parameters replaced by zeros - zeroed_params = { - name: mx.zeros(param.shape, dtype=param.dtype) - for name, param in module.parameters().items() - } - # Update the module's parameters with the zeroed parameters - module.update(zeroed_params) - return module - - -class SelfAttention1D_MLX(nn.Module): - def __init__( - self, - channels, - num_heads=8, - num_head_channels=-1, - use_attention_ffn=False, - pos_emb=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - - self.norm = nn.LayerNorm(channels) - self.qkv = nn.Linear(channels, channels * 3) - self.proj_out = zero_module_mlx(nn.Linear(channels, channels)) - if use_attention_ffn: - self.ffn = nn.Sequential( - nn.LayerNorm(channels), - nn.Linear(channels, 4 * channels), - nn.GELU(), - zero_module_mlx(nn.Linear(4 * channels, channels)), - ) - else: - self.ffn = None - if pos_emb: - from mlx.nn import RoPE - - self.pos_emb = RoPE(dim=channels // self.num_heads) - else: - self.pos_emb = None - - def attention(self, q, k, v, mask=None): - bs, length, width = q.shape - ch = width // self.num_heads - scale = 1 / math.sqrt(math.sqrt(ch)) - q = q.reshape(bs, length, self.num_heads, ch) - k = k.reshape(bs, length, self.num_heads, ch) - if self.pos_emb is not None: - q = self.pos_emb.rotate_queries_or_keys(q.permute(0, 2, 1, 3)).permute( - 0, 2, 1, 3 - ) - k = self.pos_emb.rotate_queries_or_keys(k.permute(0, 2, 1, 3)).permute( - 0, 2, 1, 3 - ) - weight = mx.einsum( - "bthc,bshc->bhts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - if mask is not None: - mask = mask.view(mask.size(0), 1, 1, mask.size(1)) - weight = weight.masked_fill(mask == 0, float("-inf")) - weight = mx.softmax(weight, axis=-1) - a = mx.einsum("bhts,bshc->bthc", weight, v.reshape(bs, -1, self.num_heads, ch)) - return a.reshape(bs, length, -1) - - def forward(self, x, mask): - # assert (self.cond_dim is not None) == (cond is not None) - qkv = self.qkv(self.norm(x)) - q, k, v = mx.split(qkv, 3, axis=-1) - h = self.attention(q, k, v, mask) - h = self.proj_out(h) - x = x + h - if self.ffn is not None: - x = x + self.ffn(x) - return x - - -class TemporalAttentionBlock_MLX(nn.Module): - def __init__( - self, channels, num_heads=8, num_head_channels=-1, down=False, pos_emb=False - ): - super().__init__() - self.attn = SelfAttention1D_MLX( - channels, num_heads, num_head_channels, pos_emb=pos_emb - ) - self.mlp = MLP_MLX(channels, multiplier=4) - self.down = down - if down: - self.down_conv = nn.Conv2d( - channels, channels, kernel_size=3, stride=2, padding=1, bias=True - ) - self.up_conv = nn.Conv2d( - channels, channels, kernel_size=3, stride=1, padding=1, bias=True - ) - - def forward(self, x, temb): - x_ = x - if self.down: - # transformation for mlx format - x = einops.array_api.rearrange(x, "b c h w -> b h w c") - x = self.down_conv(x) - - T, H, W = x.shape[0] // temb.shape[0], x.shape[2], x.shape[3] - x = einops.array_api.rearrange(x, "(b t) c h w -> (b h w) t c", t=T) - x = self.mlp.forward(self.attn.forward(x, None)) - x = einops.array_api.rearrange(x, "(b h w) t c -> (b t) c h w", h=H, w=W) - - if self.down: - x = self.up_conv(nn.Upsample(scale_factor=2, mode="nearest")(x)) - - x = einops.array_api.rearrange(x, "b h w c -> b c h w") - x = x + x_ - return x - - -class MLP_MLX(nn.Module): # mlx based nn.Module - def __init__(self, channels, multiplier=4): - super().__init__() - ### use mlx layers - self.main = nn.Sequential( - nn.LayerNorm(channels), - nn.Linear(channels, multiplier * channels), - nn.GELU(), - zero_module_mlx(nn.Linear(multiplier * channels, channels)), - ) - - def forward(self, x): - return x + self.main(x) diff --git a/tests/test_mlx_unet.py b/tests/test_mlx_unet.py deleted file mode 100644 index 1966c09..0000000 --- a/tests/test_mlx_unet.py +++ /dev/null @@ -1,150 +0,0 @@ -# For licensing see accompanying LICENSE file. -# Copyright (C) 2024 Apple Inc. All rights reserved. - -import mlx.core as mx -import numpy as np -import torch - -from ml_mdm.models.unet import MLP, SelfAttention1D, TemporalAttentionBlock -from ml_mdm.models.unet_mlx import ( - MLP_MLX, - SelfAttention1D_MLX, - TemporalAttentionBlock_MLX, -) - - -def test_pytorch_mlp(): - """ - Simple test for our MLP implementations - """ - # Define parameters - channels = 8 # Number of channels - multiplier = 4 # Multiplier for hidden dimensions - - # Create a model instance - pytorch_mlp = MLP(channels=channels, multiplier=multiplier) - mlx_mlp = MLP_MLX(channels=channels, multiplier=multiplier) - - ## Start by testing pytorch version - - # Set model to evaluation mode - pytorch_mlp.eval() - - # Create a dummy pytorch input tensor (batch size = 2, channels = 8) - input_tensor = torch.randn(2, channels) - - # Pass the input through the model - output = pytorch_mlp(input_tensor) - - # Assertions to validate the output shape and properties - assert output.shape == input_tensor.shape, "Output shape mismatch" - assert torch.allclose( - output, input_tensor, atol=1e-5 - ), "Output should be close to input as the final layer is zero-initialized" - - ## now test mlx version - - # Convert the same input to MLX tensor - mlx_tensor = mx.array(input_tensor.numpy()) - - mlx_mlp.eval() - - mlx_output = mlx_mlp.forward(mlx_tensor) - - assert isinstance(mlx_output, mx.array) - assert mlx_output.shape == input_tensor.shape, "MLX MLP: Output shape mismatch" - - # Validate numerical equivalence using numpy - assert np.allclose( - output.detach().numpy(), np.array(mlx_output), atol=1e-5 - ), "Outputs of PyTorch MLP and MLX MLP should match" - - print("Test passed for both PyTorch and MLX MLP!") - - -def test_self_attention_1d(): - # Define parameters - channels = 8 - num_heads = 2 - seq_length = 16 - batch_size = 2 - - # Create a model instance - pytorch_attn = SelfAttention1D(channels=channels, num_heads=num_heads) - mlx_attn = SelfAttention1D_MLX(channels=channels, num_heads=num_heads) - - # Set models to evaluation mode - pytorch_attn.eval() - mlx_attn.eval() - - # Create a dummy input tensor - input_tensor = torch.randn(batch_size, seq_length, channels) - - # Pass the input through the PyTorch model - pytorch_output = pytorch_attn(input_tensor, mask=None) - - # Convert the input to MLX format - mlx_input = mx.array(input_tensor.numpy()) - - # Pass the input through the MLX model - mlx_output = mlx_attn.forward(mlx_input, mask=None) - - # Assertions to validate the output shape and properties - assert pytorch_output.shape == mlx_output.shape, "Output shape mismatch" - assert np.allclose( - pytorch_output.detach().numpy(), np.array(mlx_output), atol=1e-5 - ), "Outputs of PyTorch and MLX SelfAttention1D should match" - - print("Test passed for both PyTorch and MLX SelfAttention1D!") - - -def test_pytorch_mlx_temporal_attention_block(): - """ - Test for verifying parity between PyTorch and MLX implementations of TemporalAttentionBlock - """ - # Define parameters - channels = 8 - num_heads = 2 - batch_size = 2 - time_steps = 4 - height = 16 - width = 16 - - # Create model instances - pytorch_block = TemporalAttentionBlock( - channels=channels, num_heads=num_heads, down=True - ) - - mlx_block = TemporalAttentionBlock_MLX( - channels=channels, num_heads=num_heads, down=True - ) - - # Set models to evaluation mode - pytorch_block.eval() - mlx_block.eval() - - # Create dummy input tensors - pytorch_input = torch.randn(batch_size * time_steps, channels, height, width) - pytorch_temb = torch.randn(batch_size, channels) - - # Pass inputs through PyTorch model - pytorch_output = pytorch_block(pytorch_input, pytorch_temb) - - # Convert to MLX format - mlx_input = mx.array(pytorch_input.numpy()) - mlx_temb = mx.array(pytorch_temb.numpy()) - - # Pass inputs through MLX model - mlx_output = mlx_block.forward(mlx_input, mlx_temb) - - # print output tensors for debug - print("pytorch_output tensor: ", pytorch_output) - print("mlx_output tensor: ", mlx_output) - - # Assertions to validate the output - assert pytorch_output.shape == tuple(mlx_output.shape), "Output shape mismatch" - assert np.allclose( - pytorch_output.detach().numpy(), np.array(mlx_output), rtol=1e-1, atol=1e-1 - ), "Outputs of PyTorch and MLX TemporalAttentionBlock should match" - - print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")