From d21aa5d946f05843acf6c5596e9d0158295b9444 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Thu, 19 Sep 2024 18:33:27 -0400 Subject: [PATCH] Replace setup.py with pyproject.toml This allows creating lock files for isolated environments. For example with: uv venv && uv lock -U && uv sync. Also, fixed the Ruff configuration and autofixed the errors with the latest Ruff. --- diffusionjax/utils.py | 8 ++--- examples/example.py | 2 +- examples/example1.py | 4 +-- examples/example2.py | 4 +-- pyproject.toml | 56 +++++++++++++++++++++++++++++++++ ruff.toml | 5 +-- setup.py | 72 ------------------------------------------- 7 files changed, 68 insertions(+), 83 deletions(-) delete mode 100644 setup.py diff --git a/diffusionjax/utils.py b/diffusionjax/utils.py index 7c9c169..c86e33d 100644 --- a/diffusionjax/utils.py +++ b/diffusionjax/utils.py @@ -53,7 +53,7 @@ def beta(t): return beta_min + t * (beta_max - beta_min) def mean_coeff(t): - """..math: exp(-0.5 * \int_{0}^{t} \beta(s) ds)""" + """..math: exp(-0.5 * \\int_{0}^{t} \\beta(s) ds)""" return jnp.exp(-0.5 * t * beta_min - 0.25 * t**2 * (beta_max - beta_min)) return beta, mean_coeff @@ -69,7 +69,7 @@ def get_cosine_beta_function(beta_max, offset=0.08): offset: https://arxiv.org/abs/2102.09672 "Use a small offset to prevent $\beta(t)$ from being too small near $t = 0$, since we found that having tiny amounts of noise at the beginning - of the process made it hard for the network to predict $\epsilon$ + of the process made it hard for the network to predict $\\epsilon$ accurately enough" """ @@ -78,7 +78,7 @@ def beta(t): return jnp.clip(jnp.sin((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / (jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) + 1e-5) * jnp.pi * (1.0 / (1.0 + offset)), a_max=beta_max) def mean_coeff(t): - """..math: -0.5 * \int_{0}^{t} \beta(s) ds""" + """..math: -0.5 * \\int_{0}^{t} \\beta(s) ds""" return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) # return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / jnp.cos(offset / (1.0 + offset) * 0.5 * jnp.pi) @@ -128,7 +128,7 @@ def gamma(sigmas): def get_times(num_steps=1000, dt=None, t0=None): - """ + r""" Get linear, monotonically increasing time schedule. Args: num_steps: number of discretization time steps. diff --git a/examples/example.py b/examples/example.py index fb48ffa..c099878 100644 --- a/examples/example.py +++ b/examples/example.py @@ -144,7 +144,7 @@ def main(argv): ) def log_hat_pt(x, t): - """Empirical distribution score. + r"""Empirical distribution score. Args: x: One location in $\mathbb{R}^2$ diff --git a/examples/example1.py b/examples/example1.py index a848e90..75cffe9 100644 --- a/examples/example1.py +++ b/examples/example1.py @@ -58,7 +58,7 @@ def __call__(self, x, t): @partial(jit, static_argnums=[4]) def update_step(params, rng, batch, opt_state, loss): - """ + r""" Takes the gradient of the loss function and updates the model weights (params) using it. Args: params: the current weights of the model @@ -168,7 +168,7 @@ def nabla_log_pt(x, t): Returns: The true log density. .. math:: - \nabla_{x} \log p_{t}(x) + \nabla_{x} \\log p_{t}(x) """ x_shape = x.shape v_t = sde.variance(t) diff --git a/examples/example2.py b/examples/example2.py index 9fa2d0f..774a14d 100644 --- a/examples/example2.py +++ b/examples/example2.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from flax import serialization from functools import partial -from diffusionjax.plot import plot_samples, plot_heatmap, plot_samples_1D, plot_samples +from diffusionjax.plot import plot_heatmap, plot_samples_1D, plot_samples from diffusionjax.utils import ( get_score, get_loss, @@ -63,7 +63,7 @@ def __call__(self, x, t): @partial(jit, static_argnums=[4]) def update_step(params, rng, batch, opt_state, loss): - """ + r""" Takes the gradient of the loss function and updates the model weights (params) using it. Args: params: the current weights of the model diff --git a/pyproject.toml b/pyproject.toml index fce4ade..9afeacf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,5 +7,61 @@ requires = [ "numpy>=1.16", ] +[tool.setuptools] +py-modules = [] + [tool.setuptools_scm] write_to = "diffusionjax/_version.py" + +[project] +name = "diffusionjax" +description = "diffusionjax is a simple and accessible diffusion models package in JAX" +readme = 'README.md' +requires-python = '>=3.9, <3.13' +license = {file = 'LICENSE.rst'} +authors = [{name = 'Benjamin Boys'}, {name = 'Jakiw Pidstrigach'}] +maintainers = [{name = 'Benjamin Boys'}, {name = 'Jakiw Pidstrigach'}] +dependencies = [ + "numpy", + "scipy", + "matplotlib", + "flax", + "ml_collections", + "tqdm", + "absl-py", + "wandb", +] +dynamic = ['version'] + +[project.optional-dependencies] +linting = [ + "flake8", + "pylint", + "mypy", + "typing-extensions", + "pre-commit", + "ruff", + 'jaxtyping', +] +testing = [ + "optax", + "orbax-checkpoint", + "torch", + "pytest", + "pytest-xdist", + "pytest-cov", + "coveralls", + "jax>=0.4.1", + "jaxlib>=0.4.1", + "setuptools_scm[toml]", + "setuptools_scm_git_archive", +] +examples = [ + "optax", + "orbax-checkpoint", + "torch", + "mlkernels", +] + +[project.urls] +repository = "https://github.com/bb515/diffusionjax" diff --git a/ruff.toml b/ruff.toml index ec54cf2..dec9c27 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,5 +1,6 @@ -tab-size = 2 +indent-width = 2 +[lint] select = [ "F", "W6", @@ -30,7 +31,7 @@ exclude = [ "wandb/", ] -[per-file-ignores] +[lint.per-file-ignores] "test/*" = [ "F401", "F403", diff --git a/setup.py b/setup.py deleted file mode 100644 index 296c502..0000000 --- a/setup.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Setup script for diffusionjax. - -This setup is required or else - >> ModuleNotFoundError: No module named 'diffusionjax' -will occur. -""" -from setuptools import setup, find_packages -import pathlib - - -# The directory containing this file -HERE = pathlib.Path(__file__).parent - -# The text of the README file -README = (HERE / "README.md").read_text() - -# The text of the LICENSE file -LICENSE = (HERE / "LICENSE.rst").read_text() - -setup( - name="diffusionjax", - # python_requires=">=3.8", - description="diffusionjax is a simple and accessible diffusion models package in JAX", - long_description=README, - long_description_content_type="text/markdown", - url="https://github.com/bb515/diffusionjax", - author="Benjamin Boys and Jakiw Pidstrigach", - license="MIT", - license_file=LICENSE, - packages=find_packages(exclude=["*.test"]), - install_requires=[ - "numpy", - "scipy", - "matplotlib", - "flax", - "ml_collections", - "tqdm", - "absl-py", - "wandb", - ], - extras_require={ - 'linting': [ - "flake8", - "pylint", - "mypy", - "typing-extensions", - "pre-commit", - "ruff", - 'jaxtyping', - ], - 'testing': [ - "optax", - "orbax-checkpoint", - "torch", - "pytest", - "pytest-xdist", - "pytest-cov", - "coveralls", - "jax>=0.4.1", - "jaxlib>=0.4.1", - "setuptools_scm[toml]", - "setuptools_scm_git_archive", - ], - 'examples': [ - "optax", - "orbax-checkpoint", - "torch", - "mlkernels", - ], - }, - include_package_data=True)