Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions diffusionjax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def beta(t):
return beta_min + t * (beta_max - beta_min)

def mean_coeff(t):
"""..math: exp(-0.5 * \int_{0}^{t} \beta(s) ds)"""
"""..math: exp(-0.5 * \\int_{0}^{t} \\beta(s) ds)"""
return jnp.exp(-0.5 * t * beta_min - 0.25 * t**2 * (beta_max - beta_min))

return beta, mean_coeff
Expand All @@ -69,7 +69,7 @@ def get_cosine_beta_function(beta_max, offset=0.08):
offset: https://arxiv.org/abs/2102.09672 "Use a small offset to prevent
$\beta(t)$ from being too small near
$t = 0$, since we found that having tiny amounts of noise at the beginning
of the process made it hard for the network to predict $\epsilon$
of the process made it hard for the network to predict $\\epsilon$
accurately enough"
"""

Expand All @@ -78,7 +78,7 @@ def beta(t):
return jnp.clip(jnp.sin((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / (jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) + 1e-5) * jnp.pi * (1.0 / (1.0 + offset)), a_max=beta_max)

def mean_coeff(t):
"""..math: -0.5 * \int_{0}^{t} \beta(s) ds"""
"""..math: -0.5 * \\int_{0}^{t} \\beta(s) ds"""
return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi)
# return jnp.cos((t + offset) / (1.0 + offset) * 0.5 * jnp.pi) / jnp.cos(offset / (1.0 + offset) * 0.5 * jnp.pi)

Expand Down Expand Up @@ -128,7 +128,7 @@ def gamma(sigmas):


def get_times(num_steps=1000, dt=None, t0=None):
"""
r"""
Get linear, monotonically increasing time schedule.
Args:
num_steps: number of discretization time steps.
Expand Down
2 changes: 1 addition & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def main(argv):
)

def log_hat_pt(x, t):
"""Empirical distribution score.
r"""Empirical distribution score.

Args:
x: One location in $\mathbb{R}^2$
Expand Down
4 changes: 2 additions & 2 deletions examples/example1.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __call__(self, x, t):

@partial(jit, static_argnums=[4])
def update_step(params, rng, batch, opt_state, loss):
"""
r"""
Takes the gradient of the loss function and updates the model weights (params) using it.
Args:
params: the current weights of the model
Expand Down Expand Up @@ -168,7 +168,7 @@ def nabla_log_pt(x, t):
Returns:
The true log density.
.. math::
\nabla_{x} \log p_{t}(x)
\nabla_{x} \\log p_{t}(x)
"""
x_shape = x.shape
v_t = sde.variance(t)
Expand Down
4 changes: 2 additions & 2 deletions examples/example2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
from flax import serialization
from functools import partial
from diffusionjax.plot import plot_samples, plot_heatmap, plot_samples_1D, plot_samples
from diffusionjax.plot import plot_heatmap, plot_samples_1D, plot_samples
from diffusionjax.utils import (
get_score,
get_loss,
Expand Down Expand Up @@ -63,7 +63,7 @@ def __call__(self, x, t):

@partial(jit, static_argnums=[4])
def update_step(params, rng, batch, opt_state, loss):
"""
r"""
Takes the gradient of the loss function and updates the model weights (params) using it.
Args:
params: the current weights of the model
Expand Down
56 changes: 56 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,61 @@ requires = [
"numpy>=1.16",
]

[tool.setuptools]
py-modules = []

[tool.setuptools_scm]
write_to = "diffusionjax/_version.py"

[project]
name = "diffusionjax"
description = "diffusionjax is a simple and accessible diffusion models package in JAX"
readme = 'README.md'
requires-python = '>=3.9, <3.13'
license = {file = 'LICENSE.rst'}
authors = [{name = 'Benjamin Boys'}, {name = 'Jakiw Pidstrigach'}]
maintainers = [{name = 'Benjamin Boys'}, {name = 'Jakiw Pidstrigach'}]
dependencies = [
"numpy",
"scipy",
"matplotlib",
"flax",
"ml_collections",
"tqdm",
"absl-py",
"wandb",
]
dynamic = ['version']

[project.optional-dependencies]
linting = [
"flake8",
"pylint",
"mypy",
"typing-extensions",
"pre-commit",
"ruff",
'jaxtyping',
]
testing = [
"optax",
"orbax-checkpoint",
"torch",
"pytest",
"pytest-xdist",
"pytest-cov",
"coveralls",
"jax>=0.4.1",
"jaxlib>=0.4.1",
"setuptools_scm[toml]",
"setuptools_scm_git_archive",
]
examples = [
"optax",
"orbax-checkpoint",
"torch",
"mlkernels",
]

[project.urls]
repository = "https://github.com/bb515/diffusionjax"
5 changes: 3 additions & 2 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
tab-size = 2
indent-width = 2

[lint]
select = [
"F",
"W6",
Expand Down Expand Up @@ -30,7 +31,7 @@ exclude = [
"wandb/",
]

[per-file-ignores]
[lint.per-file-ignores]
"test/*" = [
"F401",
"F403",
Expand Down
72 changes: 0 additions & 72 deletions setup.py

This file was deleted.