From 341d03b4057929d4e114d7ddc2b07fce3fb1157b Mon Sep 17 00:00:00 2001 From: "John D. Pope" Date: Wed, 27 Mar 2024 10:31:15 +1100 Subject: [PATCH] Update diffusion.py fixes problem with device mismatch. --- diffusion.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/diffusion.py b/diffusion.py index bf517e10..19afd56d 100644 --- a/diffusion.py +++ b/diffusion.py @@ -64,6 +64,8 @@ def get_p_params(self, xt, timesteps, nn_out): else: eps_pred, nu = nn_out.chunk(2, 1) nu = (nu + 1) / 2 + self.beta = self.beta.to(xt.device) + self.log_beta_tilde_clipped = self.log_beta_tilde_clipped.to(xt.device) p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps]) p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred) @@ -72,12 +74,16 @@ def get_p_params(self, xt, timesteps, nn_out): def get_q_params(self, xt, timesteps, eps_pred=None, x0=None): if x0 is None: # predict x0 from xt and eps_pred + self.coef1_x0=self.coef1_x0.to(xt.device) + self.coef2_x0=self.coef2_x0.to(xt.device) coef1_x0 = self.expand(self.coef1_x0[timesteps]) coef2_x0 = self.expand(self.coef2_x0[timesteps]) x0 = coef1_x0 * xt - coef2_x0 * eps_pred x0 = x0.clamp(-1, 1) # q(x_{t-1} | x_t, x_0) + self.coef1_q=self.coef1_q.to(xt.device) + self.coef2_q=self.coef2_q.to(xt.device) coef1_q = self.expand(self.coef1_q[timesteps]) coef2_q = self.expand(self.coef2_q[timesteps]) q_mean = coef1_q * x0 + coef2_q * xt @@ -142,4 +148,4 @@ def get_spaced_beta(self): def expand(self, arr, dim=4): while arr.dim() < dim: arr = arr[:, None] - return arr.to(self.device) \ No newline at end of file + return arr.to(self.device)