From 1e4a8a27d5fa3c7043514015df521540ee976f48 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 27 Jun 2023 20:49:06 +0000 Subject: [PATCH 01/14] Add factorized log probability functionality --- gensn/distributions.py | 107 +++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 46 deletions(-) diff --git a/gensn/distributions.py b/gensn/distributions.py index 8841602..288dfae 100644 --- a/gensn/distributions.py +++ b/gensn/distributions.py @@ -24,6 +24,12 @@ def log_prob(self, *obs, cond=None): x, y = obs[: self.prior.n_rvs], obs[self.prior.n_rvs :] return self.prior(*x, cond=cond) + self.conditional(*y, cond=x) + def factorized_log_prob(self, *obs, cond=None): + x, y = obs[: self.prior.n_rvs], obs[self.prior.n_rvs :] + return self.prior.factorized_log_prob( + *x, cond=cond + ) + self.conditional.factorized_log_prob(*y, cond=x) + def forward(self, *obs, cond=None): return self.log_prob(*obs, cond=cond) @@ -38,23 +44,6 @@ def rsample(self, sample_shape=torch.Size([]), cond=None): return turn_to_tuple(x_samples) + turn_to_tuple(y_samples) -# class DeltaDistribution(nn.Module): -# def __init__(self, value): -# self.value = value - -# def log_prob(self, obs, cond=None): -# # TODO: write to deal with more than one rvs -# return torch.log(self.prob(obs, cond=cond)) - -# def prob(self, obs, cond=None): -# # TODO: write to deal with more than one rvs -# return torch.where( -# torch.equal(obs, parse_attr(self.value, cond=cond)), 0, 1 -# ).to(obs.device) - -# def sample(self, sample_shape=torch.size([]), cond=None): - - class TrainableDistribution(nn.Module, ABC): """ Here we are providing the proper abstract base class for the @@ -110,26 +99,6 @@ def __init__(self, distribution_class, *dist_args, _parameters=None, **dist_kwar if _parameters is not None: self.parameter_generator = _parameters - # overwrite extra_repr to include the distribution class - # TODO: consider adding the parameters as well - def extra_repr(self): - repr = f"distribution_class={self.distribution_class!r}" - - if self.param_counts > 0: - repr += ", " + ", ".join( - f"{getattr(self, f'_arg{pos}')!r}" for pos in range(self.param_counts) - ) - - if len(self.param_keys) > 0: - repr += ", " + ", ".join( - f"{k}={getattr(self, k)!r}" for k in self.param_keys - ) - - if hasattr(self, "parameter_genrator"): - repr += ", " + f"_parameters={self.parameter_generator!r}" - - return repr - def distribution(self, cond=None): cond = turn_to_tuple(cond) @@ -162,18 +131,25 @@ def sample(self, sample_shape=torch.Size([]), cond=None): def rsample(self, sample_shape=torch.Size([]), cond=None): return self.distribution(cond=cond).rsample(sample_shape=sample_shape) + # overwrite extra_repr to include the distribution class + # TODO: consider adding the parameters as well + def extra_repr(self): + repr = f"distribution_class={self.distribution_class!r}" -# def wrap_with_indep(distribution_class, event_dims=1): -# """ -# Wrap the construction of the target distribution `distr_class` with -# D.Independent. The returned function can be used as if it is -# a constructor for an indepdenent version of the distribution -# """ + if self.param_counts > 0: + repr += ", " + ", ".join( + f"{getattr(self, f'_arg{pos}')!r}" for pos in range(self.param_counts) + ) -# def indep_distr(*args, **kwargs): -# return D.Independent(distribution_class(*args, **kwargs), event_dims) + if len(self.param_keys) > 0: + repr += ", " + ", ".join( + f"{k}={getattr(self, k)!r}" for k in self.param_keys + ) -# return indep_distr + if hasattr(self, "parameter_genrator"): + repr += ", " + f"_parameters={self.parameter_generator!r}" + + return repr class IndependentTrainableDistributionAdapter(TrainableDistributionAdapter): @@ -193,6 +169,12 @@ def __init__( def distribution(self, cond=None): return D.Independent(super().distribution(cond=cond), self.event_dims) + def factorized_distribution(self, cond=None): + return super().distribution(cond=cond) + + def factorized_log_prob(self, *obs, cond=None): + return self.factorized_distribution(cond=cond).log_prob(*obs) + def extra_repr(self): return super().extra_repr() + f", event_dims={self.event_dims}" @@ -212,6 +194,9 @@ def forward(self, *obs, cond=None): def log_prob(self, *obs, cond=None): return self.trainable_distribution.log_prob(*obs, cond=cond) + def factorized_log_prob(self, *obs, cond=None): + return self.trainable_distribution.factorized_log_prob(*obs, cond=cond) + def sample(self, sample_shape=torch.Size([]), cond=None): return self.trainable_distribution.sample(sample_shape=sample_shape, cond=cond) @@ -421,3 +406,33 @@ def __init__(self, loc=None, scale=None, _parameters=None, event_dims=1): **kwargs, _parameters=_parameters, ) + + +# class DeltaDistribution(nn.Module): +# def __init__(self, value): +# self.value = value + +# def log_prob(self, obs, cond=None): +# # TODO: write to deal with more than one rvs +# return torch.log(self.prob(obs, cond=cond)) + +# def prob(self, obs, cond=None): +# # TODO: write to deal with more than one rvs +# return torch.where( +# torch.equal(obs, parse_attr(self.value, cond=cond)), 0, 1 +# ).to(obs.device) + +# def sample(self, sample_shape=torch.size([]), cond=None): + + +# def wrap_with_indep(distribution_class, event_dims=1): +# """ +# Wrap the construction of the target distribution `distr_class` with +# D.Independent. The returned function can be used as if it is +# a constructor for an indepdenent version of the distribution +# """ + +# def indep_distr(*args, **kwargs): +# return D.Independent(distribution_class(*args, **kwargs), event_dims) + +# return indep_distr From be5da398171342b2703987b7183d45d2467e243f Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 27 Jun 2023 20:50:11 +0000 Subject: [PATCH 02/14] Change marginal -> factorized and add factorized log det functionality --- gensn/transforms/invertible.py | 151 ++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 68 deletions(-) diff --git a/gensn/transforms/invertible.py b/gensn/transforms/invertible.py index 71755ba..fad269e 100644 --- a/gensn/transforms/invertible.py +++ b/gensn/transforms/invertible.py @@ -30,6 +30,20 @@ def inverse(self, y, cond=None): logL += logDet return y, logL + def factorized_forward(self, x, cond=None): + logL = 0 + for t in self.transforms: + x, logDet = t.factorized_forward(x, cond=cond) + logL += logDet + return x, logL + + def factorized_inverse(self, y, cond=None): + logL = 0 + for t in self.transforms[::-1]: + y, logDet = t.factorized_inverse(y, cond=cond) + logL += logDet + return y, logL + class InverseTransform(nn.Module): def __init__(self, transform): @@ -42,6 +56,12 @@ def forward(self, x, cond=None): def inverse(self, y, cond=None): return self.transform.forward(y, cond=cond) + def factorized_forward(self, x, cond=None): + return self.transform.factorized_inverse(x, cond=cond) + + def factorized_inverse(self, y, cond=None): + return self.transform.factorized_forward(y, cond=cond) + class ConditionalShift(nn.Module): def __init__(self, conditional_shift): @@ -66,11 +86,11 @@ def inverse(self, x, cond=None): # return x, logL - log_det_f_prime -class MarginalTransform(nn.Module): - """Defines marginal transform template""" +class FactorizedTransform(nn.Module): + """Defines factorized transform template""" def __init__(self, dim=-1): - """Initialize the marginal transform. Specify the dimension to be + """Initialize the factorized transform. Specify the dimension to be collapsed over when computing the log determinant. Args: @@ -79,174 +99,169 @@ def __init__(self, dim=-1): super().__init__() self.dim = dim - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): pass - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): pass - def get_log_det(self, x): + def get_log_det(self, x, cond=None): pass + def factorized_forward(self, x, cond=None): + return self.factorized_transform(x, cond=cond), self.get_log_det(x, cond=cond) + def forward(self, x, cond=None): - # TODO: deal with passing the cond through - return self.marginal_forward(x), self.get_log_det(x).sum(dim=self.dim) + y, log_det = self.factorized_forward(x, cond=cond) + return y, log_det.sum(dim=self.dim) + + def factorized_inverse(self, y, cond=None): + x = self.factorized_inverse_transform(y, cond=cond) + return x, -self.get_log_det(x, cond=cond) def inverse(self, y, cond=None): - x = self.marginal_inverse(y) - return x, -self.get_log_det(x).sum(dim=self.dim) + x, log_det = self.factorized_inverse(y, cond=cond) + return x, log_det.sum(dim=self.dim) -class IndependentAffine(MarginalTransform): +class IndependentAffine(FactorizedTransform): def __init__(self, input_dim=1, dim=-1): super().__init__(dim=dim) self.input_dim = input_dim self.weight = nn.Parameter(torch.empty(input_dim)) self.bias = nn.Parameter(torch.empty(input_dim)) - def get_log_det(self, x): + def get_log_det(self, x, cond=None): return torch.log( abs(self.weight) + torch.finfo(self.weight.dtype).tiny ) * torch.ones_like(x) - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return x * self.weight + self.bias - def marginal_inverse(self, y): - return (y - self.bias) / self.weight + def factorized_inverse_transform(self, y, cond=None): + return (y - self.bias) / (self.weight + torch.finfo(self.weight.dtype).tiny) -class ELU(MarginalTransform): - def __init__(self, alpha=1.0, dim=-1): +class ELU(FactorizedTransform): + def __init__(self, alpha=1.0, offset=0.0, dim=-1): super().__init__(dim=dim) self.alpha = alpha + self.offset = offset - def get_log_det(self, x): + def get_log_det(self, x, cond=None): return torch.where(x > 0, torch.zeros(1).to(x.device), x + math.log(self.alpha)) - def marginal_forward(self, x, cond=None): - return F.elu(x, self.alpha) + def factorized_transform(self, x, cond=None): + return F.elu(x, self.alpha) + self.offset - def marginal_inverse(self, y, cond=None): + def factorized_inverse_transform(self, y, cond=None): # TODO: check if use of finfo is meaningful finfo = torch.finfo(y.dtype) + y = y - self.offset return torch.where( y > 0, y, torch.log((y / self.alpha + 1).clamp(min=finfo.tiny)) ) -class OffsetELU(nn.Module): - def __init__(self, alpha=1.0, offset=0.0, dim=-1): - super().__init__() - self.elu = ELU(alpha, dim=dim) - self.offset = offset - - def forward(self, x, cond=None): - y, logL = self.elu(x, cond=cond) - return y + self.offset, logL - - def inverse(self, y, cond=None): - return self.elu.inverse(y - self.offset, cond=cond) - - -class ELUplus1(OffsetELU): +class ELUplus1(ELU): def __init__(self, alpha=1.0, dim=-1): super().__init__(alpha=alpha, offset=1.0, dim=dim) -class Softplus(MarginalTransform): - def get_log_det(self, x): +class Softplus(FactorizedTransform): + def get_log_det(self, x, cond=None): # TODO: get the implementation for softplus return -F.softplus(-x) - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return F.softplus(x) - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): # TODO: consider providing inverse_softplus return (-y).expm1().neg().clamp(min=torch.finfo(y.dtype).tiny).log() + y -class Exp(MarginalTransform): - def get_log_det(self, x): +class Exp(FactorizedTransform): + def get_log_det(self, x, cond=None): return x - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return x.exp() - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): return y.clamp(min=torch.finfo(y.dtype).tiny).log() -class Tanh(MarginalTransform): - def get_log_det(self, x): +class Tanh(FactorizedTransform): + def get_log_det(self, x, cond=None): # using numerically stable formula from TF implementation: # https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/bijectors/tanh.py#L69-L80 return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x)) - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return torch.tanh(x) # switchable safeguarding? # publicized? # this is silently failing for atanh out of bounds inputs - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): eps = torch.finfo(y.dtype).eps return y.clamp(min=-1 + eps, max=1 - eps).atanh() -class Sigmoid(MarginalTransform): +class Sigmoid(FactorizedTransform): # TODO: implement scaling and also top & bottom offset - def get_log_det(self, x): + def get_log_det(self, x, cond=None): return -F.softplus(-x) - F.softplus(x) - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return torch.sigmoid(x) - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): finfo = torch.finfo(y.dtype) y = torch.clamp(y, min=finfo.tiny, max=1.0 - finfo.eps) return y.log() - (-y).log1p() -class Log(MarginalTransform): - def get_log_det(self, x): +class Log(FactorizedTransform): + def get_log_det(self, x, cond=None): return -torch.log(abs(x).clamp(min=torch.finfo(x.dtype).tiny)) - def marginal_forward(self, x, cond=None): + def factorized_forward(self, x, cond=None): return x.clamp(min=torch.finfo(x.dtype).tiny).log() - def marginal_inverse(self, y, cond=None): + def factorized_inverse_transform(self, y, cond=None): return y.exp() -class Pow(MarginalTransform): - def get_log_det(self, x): +class Pow(FactorizedTransform): + def __init__(self, exponent, dim=-1): + super().__init__(dim=dim) + self.exponent = exponent + + def get_log_det(self, x, cond=None): # TODO: deal with number/tensor conversion better here # currently using torch.zeros to ensure sum is a tensor return torch.log( torch.zeros([]) + abs(self.exponent) + torch.finfo(x.dtype).tiny ) + (self.exponent - 1) * torch.log(abs(x) + torch.finfo(x.dtype).tiny) - def __init__(self, exponent, dim=-1): - super().__init__(dim=dim) - self.exponent = exponent - - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): return x.pow(self.exponent) - def marginal_inverse(self, y): + def factorized_inverse_transform(self, y, cond=None): return y.pow(1 / (self.exponent + torch.finfo(y.dtype).tiny)) -class Sqrt(MarginalTransform): - def get_log_det(self, x): +class Sqrt(FactorizedTransform): + def get_log_det(self, x, cond=None): # TODO: replace with torch.log return -math.log(2.0) - 0.5 * torch.log(abs(x) + torch.finfo(x.dtype).tiny) - def marginal_forward(self, x): + def factorized_transform(self, x, cond=None): # TODO: Evaluate if this clamping is a good idea return x.clamp(min=0).sqrt() - def marginal_inverse(self, z): + def factorized_inverse_transform(self, z, cond=None): return z.pow(2) From 3393c67484d14583dc0eecd1b9eebe5575541efd Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 27 Jun 2023 20:50:45 +0000 Subject: [PATCH 03/14] Add factorized log probability functionality --- gensn/flow.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/gensn/flow.py b/gensn/flow.py index 6d5f5cf..bd4ee81 100644 --- a/gensn/flow.py +++ b/gensn/flow.py @@ -21,6 +21,13 @@ def log_prob(self, *obs, cond=None): x, logL = self.transform(*obs, cond=cond) return self.base_distribution.log_prob(*turn_to_tuple(x), cond=cond) + logL + def factorized_log_prob(self, *obs, cond=None): + x, logL = self.transform.factorized_forward(*obs, cond=cond) + return ( + self.base_distribution.factorized_log_prob(*turn_to_tuple(x), cond=cond) + + logL + ) + def sample(self, sample_shape=torch.Size([]), cond=None): samples = self.base_distribution.sample(sample_shape=sample_shape, cond=cond) y, _ = self.transform.inverse(samples, cond=cond) From ee2d3b73a2cde2c0a4f6fad70e8699a1181dc68f Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 27 Jun 2023 20:51:30 +0000 Subject: [PATCH 04/14] Add factorized elbo for vardequant dist and iw bound computation method --- gensn/variational.py | 62 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/gensn/variational.py b/gensn/variational.py index 2d9369b..be8f07b 100644 --- a/gensn/variational.py +++ b/gensn/variational.py @@ -1,3 +1,4 @@ +import numpy as np import torch from torch import nn @@ -35,12 +36,19 @@ def __init__(self, joint, posterior, n_samples=1): def n_rvs(self): return self.joint.n_rvs - self.posterior.n_rvs - def forward(self, *obs, cond=None): - return self.elbo(*obs, cond=cond) + def forward(self, *obs, cond=None, n_samples=None): + return self.elbo(*obs, cond=cond, n_samples=n_samples) - def elbo(self, *obs, cond=None): + def elbo(self, *obs, cond=None, n_samples=None): # TODO: deal with conditioning correctly - return ELBO_joint(self.joint, self.posterior, *obs, n_samples=self.n_samples) + if n_samples is None: + n_samples = self.n_samples + return ELBO_joint( + self.joint, + self.posterior, + *obs, + n_samples=n_samples, + ) def log_prob(self, *obs): # TODO: let this be implemented as an "approximation" with ELBO @@ -77,21 +85,59 @@ def __init__(self, prior, dequantizer, quantizer=None, n_samples=1): def n_rvs(self): return self.prior.n_rvs - def forward(self, *obs, cond=None): - return self.elbo(*obs, cond=cond) + def forward(self, *obs, cond=None, n_samples=None): + return self.elbo(*obs, cond=cond, n_samples=n_samples) - def elbo(self, *obs, cond=None): - z_samples = self.dequantizer.rsample((self.n_samples,), cond=obs) + def elbo(self, *obs, cond=None, n_samples=None): + if n_samples is None: + n_samples = self.n_samples + z_samples = self.dequantizer.rsample((n_samples,), cond=obs) elbo = -self.dequantizer(*turn_to_tuple(z_samples), cond=obs) # TODO: rewrite this so that quantizer can be used as is for joint & elbo elbo += self.prior(*turn_to_tuple(z_samples)) return elbo.mean(dim=0) # average over samples + def factorized_elbo(self, *obs, cond=None, n_samples=None): + if n_samples is None: + n_samples = self.n_samples + z_samples = self.dequantizer.rsample((n_samples,), cond=obs) + log_prob_posterior = self.dequantizer.factorized_log_prob( + *turn_to_tuple(z_samples), cond=obs + ) + log_prob_prior = self.prior.factorized_log_prob(*turn_to_tuple(z_samples)) + elbo = log_prob_prior - log_prob_posterior + return elbo.mean(dim=0) # average over samples + + def iw_bound(self, *obs, cond=None, n_samples=None): + if n_samples is None: + n_samples = self.n_samples + z_samples = self.dequantizer.rsample((n_samples,), cond=obs) + log_prob_posterior = self.dequantizer(*turn_to_tuple(z_samples), cond=obs) + log_prob_prior = self.prior(*turn_to_tuple(z_samples)) + return torch.logsumexp(log_prob_prior - log_prob_posterior, dim=0) - np.log( + n_samples + ) # average over samples + + def factorized_iw_bound(self, *obs, cond=None, n_samples=None): + if n_samples is None: + n_samples = self.n_samples + z_samples = self.dequantizer.rsample((n_samples,), cond=obs) + log_prob_posterior = self.dequantizer.factorized_log_prob( + *turn_to_tuple(z_samples), cond=obs + ) + log_prob_prior = self.prior.factorized_log_prob(*turn_to_tuple(z_samples)) + return torch.logsumexp(log_prob_prior - log_prob_posterior, dim=0) - np.log( + n_samples + ) # average over samples + def log_prob(self, *obs): # TODO: let this be implemented as an "approximation" with ELBO # but with ample warnings pass + def factorized_log_prob(self, *obs): + pass + def sample(self, sample_shape=torch.Size([]), cond=None): samples = self.prior.sample(sample_shape=sample_shape, cond=cond) return self.quantizer(samples) From 3acf902ef19c4fbc1163a9267a8fb0ea9f5c4953 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 27 Jun 2023 20:52:03 +0000 Subject: [PATCH 05/14] Demonstrate factorized log prob functionality --- notebooks/factorized_log_prob.ipynb | 635 ++++++++++++++++++++++++++++ 1 file changed, 635 insertions(+) create mode 100644 notebooks/factorized_log_prob.ipynb diff --git a/notebooks/factorized_log_prob.ipynb b/notebooks/factorized_log_prob.ipynb new file mode 100644 index 0000000..0bf015b --- /dev/null +++ b/notebooks/factorized_log_prob.ipynb @@ -0,0 +1,635 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## This notebook demonstrates the use of the factorized log prob method in gensn" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.distributions as D\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "from gensn.distributions import TrainableDistributionAdapter, Joint\n", + "from gensn.parameters import TransformedParameter\n", + "from gensn.variational import ELBOMarginal, VariationalDequantizedDistribution\n", + "from gensn.flow import FlowDistribution\n", + "import gensn.transforms.invertible as T\n", + "import gensn.distributions as G\n", + "import torch.distributions as D\n", + "\n", + "from gensn.utils import squeeze_tuple, turn_to_tuple\n", + "\n", + "\n", + "seed = 100\n", + "torch.manual_seed(seed);" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Factorized log prob for independent distributions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a 2d independent distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "n_dims = 2\n", + "loc = torch.zeros(n_dims)\n", + "scale = torch.ones(n_dims)\n", + "dist = G.IndependentNormal(loc, scale)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Call log prob and factorized log prob" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.8379) torch.Size([])\n" + ] + } + ], + "source": [ + "x = torch.zeros(n_dims)\n", + "lp = dist.log_prob(x)\n", + "print(lp, lp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-0.9189, -0.9189]) torch.Size([2])\n" + ] + } + ], + "source": [ + "flp = dist.factorized_log_prob(x)\n", + "print(flp, flp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.isclose(flp.sum(), lp)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Factorized log prob for flow distributions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### First create factorized transformations" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "affine = T.IndependentAffine(n_dims)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare an initializer\n", + "init_std = 0.1\n", + "def init_module(module):\n", + " if isinstance(module, T.IndependentAffine):\n", + " module.weight.data.normal_(mean=1.0, std=init_std)\n", + " module.bias.data.normal_(std=init_std * 0.1)\n", + "\n", + "affine.apply(init_module);" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Call forward and factorized forward" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-0.0039, 0.0024], grad_fn=) torch.Size([2])\n", + "tensor(0.0064, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "y, log_det = affine(x)\n", + "print(y, y.shape)\n", + "print(log_det, log_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-0.0039, 0.0024], grad_fn=) torch.Size([2])\n", + "tensor([ 0.0354, -0.0290], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "fy, flog_det = affine.factorized_forward(x)\n", + "print(fy, fy.shape)\n", + "print(flog_det, flog_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.isclose(flog_det.sum(), log_det)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a inverse transform" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "inv_affine = T.InverseTransform(T.IndependentAffine(n_dims))\n", + "\n", + "inv_affine.apply(init_module);" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0.0037, 0.0113], grad_fn=) torch.Size([2])\n", + "tensor(0.4120, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "y, log_det = inv_affine(x)\n", + "print(y, y.shape)\n", + "print(log_det, log_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0.0037, 0.0113], grad_fn=) torch.Size([2])\n", + "tensor([0.1489, 0.2631], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "fy, flog_det = inv_affine.factorized_forward(x)\n", + "print(fy, fy.shape)\n", + "print(flog_det, flog_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.isclose(flog_det.sum(), log_det)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a sequential transform" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "transform_sequence = [\n", + " T.InverseTransform(T.Softplus()),\n", + " T.IndependentAffine(n_dims),\n", + " T.ELU(),\n", + " T.IndependentAffine(n_dims),\n", + " T.ELU(),\n", + " T.IndependentAffine(n_dims),\n", + "]\n", + "\n", + "sequential = T.SequentialTransform(*transform_sequence)\n", + "\n", + "sequential.apply(init_module);" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-0.5234, -0.6372], grad_fn=) torch.Size([2])\n", + "tensor(-14.7422, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "y, log_det = sequential(x)\n", + "print(y, y.shape)\n", + "print(log_det, log_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-0.5234, -0.6372], grad_fn=) torch.Size([2])\n", + "tensor([-16.1689, 1.4267], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "fy, flog_det = sequential.factorized_forward(x)\n", + "print(fy, fy.shape)\n", + "print(flog_det, flog_det.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.isclose(flog_det.sum(), log_det)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create a flow distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "flow_base_dist = G.IndependentNormal(loc, scale)\n", + "flow_dist = FlowDistribution(flow_base_dist, sequential)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-16.9201, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "lp = flow_dist.log_prob(x)\n", + "print(lp, lp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-17.2248, 0.3048], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "flp = flow_dist.factorized_log_prob(x)\n", + "print(flp, flp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "assert torch.isclose(flp.sum(), lp)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Factorized elbo for variational dequantized distributions with normalizing flows" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "prior_base_dist = G.IndependentNormal(\n", + " loc=torch.zeros(n_dims), scale=torch.ones(n_dims)\n", + ")\n", + "prior_dist = FlowDistribution(prior_base_dist, sequential)\n", + "\n", + "dequant_base_dist = G.IndependentLaplace(\n", + " loc=torch.zeros(n_dims),\n", + " scale=torch.ones(n_dims),\n", + ")\n", + "dequant_dist = FlowDistribution(dequant_base_dist, sequential)\n", + "\n", + "vdd = VariationalDequantizedDistribution(\n", + " prior=prior_dist,\n", + " dequantizer=dequant_dist,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-0.1279, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "elbo = vdd(x, n_samples=10_000_000)\n", + "print(elbo, elbo.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-17.2248, 0.3048], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "felbo = vdd.factorized_elbo(x, n_samples=10_000_000)\n", + "print(flp, flp.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-0.0010, grad_fn=)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "felbo.sum() - elbo" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39massert\u001b[39;00m torch\u001b[39m.\u001b[39misclose(felbo\u001b[39m.\u001b[39msum(), elbo)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "assert torch.isclose(felbo.sum(), elbo)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.1083, grad_fn=) torch.Size([])\n" + ] + } + ], + "source": [ + "iw_bound = vdd.iw_bound(x, n_samples=10_000_000)\n", + "print(iw_bound, iw_bound.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0.0466, 0.0617], grad_fn=) torch.Size([2])\n" + ] + } + ], + "source": [ + "fiw_bound = vdd.factorized_iw_bound(x, n_samples=10_000_000)\n", + "print(fiw_bound, fiw_bound.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-2.6703e-05, grad_fn=)" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fiw_bound.sum() - iw_bound" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[31], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39massert\u001b[39;00m torch\u001b[39m.\u001b[39misclose(fiw_bound\u001b[39m.\u001b[39msum(), iw_bound)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "assert torch.isclose(fiw_bound.sum(), iw_bound)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From b92ecf45eba902ad691dee7e1bc996a9417c0a1f Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Wed, 28 Jun 2023 10:32:45 +0000 Subject: [PATCH 06/14] Rollback ELU implementation and add factorized methods for ConditionalShift --- gensn/transforms/invertible.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/gensn/transforms/invertible.py b/gensn/transforms/invertible.py index fad269e..c70a5c3 100644 --- a/gensn/transforms/invertible.py +++ b/gensn/transforms/invertible.py @@ -72,10 +72,16 @@ def forward(self, x, cond=None): x = x + invoke_with_cond(self.conditional_shift, cond=cond) return x, 0 + def factorized_forward(self, x, cond=None): + return self(x, cond=cond) + def inverse(self, x, cond=None): x = x - invoke_with_cond(self.conditional_shift, cond=cond) return x, 0 + def factorized_inverse(self, x, cond=None): + return self.inverse(x, cond=cond) + # # conceptual template # class InvertibleTransform(nn.Module): @@ -144,27 +150,39 @@ def factorized_inverse_transform(self, y, cond=None): class ELU(FactorizedTransform): - def __init__(self, alpha=1.0, offset=0.0, dim=-1): + def __init__(self, alpha=1.0, dim=-1): super().__init__(dim=dim) self.alpha = alpha - self.offset = offset def get_log_det(self, x, cond=None): return torch.where(x > 0, torch.zeros(1).to(x.device), x + math.log(self.alpha)) def factorized_transform(self, x, cond=None): - return F.elu(x, self.alpha) + self.offset + return F.elu(x, self.alpha) def factorized_inverse_transform(self, y, cond=None): # TODO: check if use of finfo is meaningful finfo = torch.finfo(y.dtype) - y = y - self.offset return torch.where( y > 0, y, torch.log((y / self.alpha + 1).clamp(min=finfo.tiny)) ) -class ELUplus1(ELU): +class OffsetELU(nn.Module): + def __init__(self, alpha=1.0, offset=0.0, dim=-1): + super().__init__() + self.elu = ELU(alpha, dim=dim) + self.offset = offset + + def forward(self, x, cond=None): + y, logL = self.elu(x, cond=cond) + return y + self.offset, logL + + def inverse(self, y, cond=None): + return self.elu.inverse(y - self.offset, cond=cond) + + +class ELUplus1(OffsetELU): def __init__(self, alpha=1.0, dim=-1): super().__init__(alpha=alpha, offset=1.0, dim=dim) From 7c357e094a0d9a3716f6697317a48b5768bb66fe Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Mon, 24 Jul 2023 08:57:05 +0000 Subject: [PATCH 07/14] Add independent poisson --- gensn/distributions.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/gensn/distributions.py b/gensn/distributions.py index 8841602..e33a02a 100644 --- a/gensn/distributions.py +++ b/gensn/distributions.py @@ -421,3 +421,33 @@ def __init__(self, loc=None, scale=None, _parameters=None, event_dims=1): **kwargs, _parameters=_parameters, ) + + +class IndependentPoisson(WrappedTrainableDistribution): + """ + A trainable distribution that wraps a D.Independent(D.Poisson) distribution + """ + + def __init__(self, rate=None, _parameters=None, event_dims=1): + """ + Args: + rate : torch.Tensor or nn.Parameter or None + The rate parameter of the poisson distribution. If None, _parameters must be provided + _parameters : callable or None + A function that takes in the conditioning variable and returns a dictionary of parameters + for the poisson distribution. If None, rate must be provided + event_dims : int + The number of dimensions to be considered as the event dimensions + """ + super().__init__() + if rate is None and _parameters is None: + raise ValueError("If rate is unspecificed, _parameters must be provided") + kwargs = {} + if rate is not None: + kwargs["rate"] = rate + self.trainable_distribution = IndependentTrainableDistributionAdapter( + D.Poisson, + event_dims=event_dims, + **kwargs, + _parameters=_parameters, + ) From 04df8625fef48a186490008e850317d4bdcd6103 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 20 Feb 2024 12:37:26 +0000 Subject: [PATCH 08/14] Add dtype specific eps to prevent numerical underflow --- gensn/parameters.py | 65 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/gensn/parameters.py b/gensn/parameters.py index 748bb18..c944657 100644 --- a/gensn/parameters.py +++ b/gensn/parameters.py @@ -3,6 +3,27 @@ class TransformedParameter(nn.Module): + """ + A module for applying a transformation function to a torch.nn.Parameter. + This can be useful for ensuring that a parameter adheres to certain constraints + (e.g., positivity) after transformation. The forward method applies the transformation + function to the parameter. + + Attributes: + parameter (nn.Parameter): The parameter to be transformed. + transform_fn (Callable): The transformation function to be applied to the parameter. + value (torch.Tensor): The transformed parameter value. + + Args: + tensor (torch.Tensor): The initial tensor to be wrapped as an nn.Parameter. + transform_fn (Callable, optional): The transformation function to be applied. + If None, the identity function is used, meaning no transformation is applied. + + Returns: + torch.Tensor: The transformed parameter as a tensor, with shape dependent on the + transformation function and the initial tensor shape. + """ + def __init__(self, tensor, transform_fn=None): super().__init__() self.parameter = nn.Parameter(tensor) @@ -19,14 +40,34 @@ def forward(self, *args): class Covariance(nn.Module): - def __init__(self, n_dims, rank=None, eps=1e-16): + """ + A module to represent a covariance matrix as a parameterized entity in a neural network. + This implementation ensures the covariance matrix is positive semi-definite by constructing + it as A @ A.T + epsilon * I, where A is a parameter matrix, and epsilon is a small positive + constant added for numerical stability. + + Attributes: + A (nn.Parameter): The parameter matrix used to construct the covariance matrix. + eps (float): A small positive constant added to the diagonal for numerical stability. + value (torch.Tensor): The covariance matrix. + + Args: + n_dims (int): The dimensionality of the square covariance matrix. + rank (int, optional): The rank of the matrix A used in constructing the covariance matrix. + If None, it defaults to n_dims, resulting in a full-rank covariance matrix. + + Returns: + torch.Tensor: The covariance matrix, with shape (n_dims, n_dims). + """ + + def __init__(self, n_dims, rank=None): super().__init__() if rank is None: rank = n_dims self.n_dims = n_dims self.rank = rank - self.eps = eps self.A = nn.Parameter(torch.randn(n_dims, rank)) + self.eps = torch.finfo(self.A.dtype).eps def forward(self, *args): return self.A @ self.A.T + torch.eye(self.n_dims) * self.eps @@ -38,11 +79,29 @@ def value(self): # TODO: generalize this so that positiveness can arise from other functions class PositiveDiagonal(nn.Module): + """ + A module for representing a diagonal matrix with positive diagonal elements. This is achieved + by squaring the elements of a parameter vector D and adding a small positive constant epsilon + to each squared element for numerical stability. + + Attributes: + D (nn.Parameter): The parameter vector whose squared elements form the diagonal of the matrix. + eps (float): A small positive constant added to each element of the squared D for numerical stability. + value (torch.Tensor): The resulting diagonal matrix with positive diagonal elements. + + Args: + n_dims (int): The dimensionality of the square diagonal matrix. + eps (float, optional): A small positive constant added for numerical stability. Defaults to 1e-16. + + Returns: + torch.Tensor: The diagonal matrix with positive diagonal elements, with shape (n_dims, n_dims). + """ + def __init__(self, n_dims, eps=1e-16): super().__init__() self.n_dims = n_dims - self.eps = eps self.D = nn.Parameter(torch.randn(n_dims)) + self.eps = torch.finfo(self.D.dtype).eps def forward(self, *args): return torch.diag(self.D**2 + self.eps) From df28be397da9761e3750bf0d2531e0c8f2ef2861 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Wed, 22 May 2024 18:25:20 +0000 Subject: [PATCH 09/14] Add dense affine --- gensn/transforms/invertible.py | 57 ++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/gensn/transforms/invertible.py b/gensn/transforms/invertible.py index c70a5c3..0f5363a 100644 --- a/gensn/transforms/invertible.py +++ b/gensn/transforms/invertible.py @@ -130,25 +130,6 @@ def inverse(self, y, cond=None): return x, log_det.sum(dim=self.dim) -class IndependentAffine(FactorizedTransform): - def __init__(self, input_dim=1, dim=-1): - super().__init__(dim=dim) - self.input_dim = input_dim - self.weight = nn.Parameter(torch.empty(input_dim)) - self.bias = nn.Parameter(torch.empty(input_dim)) - - def get_log_det(self, x, cond=None): - return torch.log( - abs(self.weight) + torch.finfo(self.weight.dtype).tiny - ) * torch.ones_like(x) - - def factorized_transform(self, x, cond=None): - return x * self.weight + self.bias - - def factorized_inverse_transform(self, y, cond=None): - return (y - self.bias) / (self.weight + torch.finfo(self.weight.dtype).tiny) - - class ELU(FactorizedTransform): def __init__(self, alpha=1.0, dim=-1): super().__init__(dim=dim) @@ -283,3 +264,41 @@ def factorized_transform(self, x, cond=None): def factorized_inverse_transform(self, z, cond=None): return z.pow(2) + + +class IndependentAffine(FactorizedTransform): + def __init__(self, input_dim=1, dim=-1): + super().__init__(dim=dim) + self.input_dim = input_dim + self.weight = nn.Parameter(torch.empty(input_dim)) + self.bias = nn.Parameter(torch.empty(input_dim)) + + def get_log_det(self, x, cond=None): + return torch.log( + abs(self.weight) + torch.finfo(self.weight.dtype).tiny + ) * torch.ones_like(x) + + def factorized_transform(self, x, cond=None): + return x * self.weight + self.bias + + def factorized_inverse_transform(self, y, cond=None): + return (y - self.bias) / (self.weight + torch.finfo(self.weight.dtype).tiny) + + +class Affine(nn.Module): + def __init__(self, input_dim=1): + super().__init__() + self.input_dim = input_dim + self.weight = nn.Parameter(torch.empty(input_dim, input_dim)) + self.bias = nn.Parameter(torch.empty(input_dim)) + + def get_log_det(self, x, cond=None): + return torch.slogdet(self.weight).logabsdet * torch.ones(x.shape[:-1]) + + def forward(self, x, cond=None): + return x @ self.weight + self.bias, self.get_log_det(x, cond=cond) + + def inverse(self, y, cond=None): + return (y - self.bias) @ torch.inverse(self.weight), -self.get_log_det( + y, cond=cond + ) From 1250b6cb0dbf70663457adf42647388aaf87d6dd Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Fri, 31 May 2024 15:47:13 +0000 Subject: [PATCH 10/14] Add affine coupling and leaky relu layers --- gensn/distributions.py | 9 ++---- gensn/transforms/invertible.py | 54 ++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/gensn/distributions.py b/gensn/distributions.py index 23bf5d2..c693d93 100644 --- a/gensn/distributions.py +++ b/gensn/distributions.py @@ -62,19 +62,16 @@ def n_rvs(self): ... @abstractmethod - def log_prob(self, *obs, cond=None): - ... + def log_prob(self, *obs, cond=None): ... def forward(self, *obs, cond=None): return self.log_prob(*obs, cond=cond) @abstractmethod - def sample(self, sample_shape=torch.Size([]), cond=None): - ... + def sample(self, sample_shape=torch.Size([]), cond=None): ... @abstractmethod - def rsample(self, sample_shape=torch.Size([]), cond=None): - ... + def rsample(self, sample_shape=torch.Size([]), cond=None): ... class TrainableDistributionAdapter(nn.Module): diff --git a/gensn/transforms/invertible.py b/gensn/transforms/invertible.py index 0f5363a..8d1720c 100644 --- a/gensn/transforms/invertible.py +++ b/gensn/transforms/invertible.py @@ -223,6 +223,24 @@ def factorized_inverse_transform(self, y, cond=None): return y.log() - (-y).log1p() +class LeakyReLU(FactorizedTransform): + def __init__(self, alpha=0.01, dim=-1): + super().__init__(dim=dim) + self.alpha = alpha + + def get_log_det(self, x, cond=None): + # Log determinant is log(1) for x > 0 and log(alpha) for x <= 0 + return torch.where( + x > 0, torch.zeros_like(x), torch.log(torch.full_like(x, self.alpha)) + ) + + def factorized_transform(self, x, cond=None): + return F.leaky_relu(x, self.alpha) + + def factorized_inverse_transform(self, y, cond=None): + return torch.where(y > 0, y, y / self.alpha) + + class Log(FactorizedTransform): def get_log_det(self, x, cond=None): return -torch.log(abs(x).clamp(min=torch.finfo(x.dtype).tiny)) @@ -298,7 +316,43 @@ def get_log_det(self, x, cond=None): def forward(self, x, cond=None): return x @ self.weight + self.bias, self.get_log_det(x, cond=cond) + # TODO: use QR decomposition for more stable inversion def inverse(self, y, cond=None): return (y - self.bias) @ torch.inverse(self.weight), -self.get_log_det( y, cond=cond ) + + +class AffineCoupling(nn.Module): + def __init__(self, transform, mask): + super().__init__() + self.transform = transform + self.mask = mask + + def forward(self, x, cond=None): + x0, x1 = x[:, self.mask], x[:, ~self.mask] + y0 = x0 + shift, log_scale = self.transform(x0, cond=cond) + y1 = x1 * log_scale.exp() + shift + return torch.cat([y0, y1], dim=1), log_scale + + def inverse(self, y, cond=None): + y0, y1 = y[:, self.mask], y[:, ~self.mask] + x0 = y0 + shift, log_scale = self.transform(y0, cond=cond) + x1 = (y1 - shift) / log_scale.exp() + return torch.cat([x0, x1], dim=1), log_scale + + def factorized_forward(self, x, cond=None): + x0, x1 = x[:, self.mask], x[:, ~self.mask] + y0 = x0 + shift, log_scale = self.transform.factorized_forward(x0, cond=cond) + y1 = x1 * log_scale.exp() + shift + return torch.cat([y0, y1], dim=1), log_scale + + def factorized_inverse(self, y, cond=None): + y0, y1 = y[:, self.mask], y[:, ~self.mask] + x0 = y0 + shift, log_scale = self.transform.factorized_inverse(y0, cond=cond) + x1 = (y1 - shift) / log_scale.exp() + return torch.cat([x0, x1], dim=1), log_scale From 194da4ffc60e12aa3e00b3f6a7a6ea4d456f1370 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Wed, 10 Jul 2024 19:25:25 +0000 Subject: [PATCH 11/14] Improve numerical stability of Covariance --- gensn/parameters.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gensn/parameters.py b/gensn/parameters.py index c944657..923d40b 100644 --- a/gensn/parameters.py +++ b/gensn/parameters.py @@ -66,11 +66,17 @@ def __init__(self, n_dims, rank=None): rank = n_dims self.n_dims = n_dims self.rank = rank - self.A = nn.Parameter(torch.randn(n_dims, rank)) - self.eps = torch.finfo(self.A.dtype).eps + self.A = nn.Parameter( + torch.randn(n_dims, rank) * 0.1 + ) # Initialize with small values to achieve positive definiteness + # self.eps = torch.finfo(self.A.dtype).eps + self.eps = 1e-4 # For numerical stability of positive definiteness def forward(self, *args): - return self.A @ self.A.T + torch.eye(self.n_dims) * self.eps + return ( + self.A @ self.A.T + + torch.eye(self.n_dims).to(device=self.A.device) * self.eps + ) @property def value(self): From 4e4f45784e1e27a0867f2f0a664caa3cbf146619 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Wed, 10 Jul 2024 19:26:21 +0000 Subject: [PATCH 12/14] Bug fix: add element to device --- gensn/transforms/invertible.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gensn/transforms/invertible.py b/gensn/transforms/invertible.py index 8d1720c..d1c9123 100644 --- a/gensn/transforms/invertible.py +++ b/gensn/transforms/invertible.py @@ -245,7 +245,7 @@ class Log(FactorizedTransform): def get_log_det(self, x, cond=None): return -torch.log(abs(x).clamp(min=torch.finfo(x.dtype).tiny)) - def factorized_forward(self, x, cond=None): + def factorized_transform(self, x, cond=None): return x.clamp(min=torch.finfo(x.dtype).tiny).log() def factorized_inverse_transform(self, y, cond=None): @@ -311,7 +311,9 @@ def __init__(self, input_dim=1): self.bias = nn.Parameter(torch.empty(input_dim)) def get_log_det(self, x, cond=None): - return torch.slogdet(self.weight).logabsdet * torch.ones(x.shape[:-1]) + return torch.slogdet(self.weight).logabsdet * torch.ones(x.shape[:-1]).to( + device=x.device + ) def forward(self, x, cond=None): return x @ self.weight + self.bias, self.get_log_det(x, cond=cond) From cf00b9964a62250782e4f03ddab7125819bcac27 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Tue, 13 Aug 2024 11:19:03 +0000 Subject: [PATCH 13/14] Add general variational bound class --- gensn/variational.py | 162 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/gensn/variational.py b/gensn/variational.py index be8f07b..2de8fc5 100644 --- a/gensn/variational.py +++ b/gensn/variational.py @@ -145,3 +145,165 @@ def sample(self, sample_shape=torch.Size([]), cond=None): def rsample(self, sample_shape=torch.Size([]), cond=None): samples = self.prior.rsample(sample_shape=sample_shape, cond=cond) return self.quantizer(samples) + + +def iw_joint(joint, posterior, *obs, n_samples=1): + """ + Estimate the expectation using importance sampling to compute the importance weighted bound (IW bound). + + Args: + joint (callable): The joint distribution function `p(x, z)`. + posterior (callable): The posterior distribution function `q(z|x)`. + obs (tuple): Observations `x`, passed as variable length arguments. + n_samples (int, optional): The number of samples to draw for importance sampling. Default is 1. + + Returns: + torch.Tensor: The log of the estimated expectation using importance sampling. + """ + # Draw samples from the posterior distribution + z_samples = posterior.rsample((n_samples,), cond=obs) + + # Compute the log probability under the posterior distribution + log_prob_posterior = posterior(*turn_to_tuple(z_samples), cond=obs) + + # Compute the log probability under the joint distribution + log_prob_joint = joint(*turn_to_tuple(z_samples), *obs) + + # Compute the importance weighted bound + return torch.logsumexp(log_prob_joint - log_prob_posterior, dim=0) - np.log( + n_samples + ) + + +class VariationalBound(nn.Module): + """ + A class to compute various variational bounds on the log-likelihood, including ELBO and IW bound. + + Args: + joint (callable): The joint distribution function `p(x, z)`. + posterior (callable): The posterior distribution function `q(z|x)`. + bound_type (str, optional): The type of variational bound to compute. Options are "elbo" and "iw". + Default is "elbo". + n_samples (int, optional): The number of samples to draw for the IW bound. Default is 1. + + Attributes: + joint (callable): The joint distribution function `p(x, z)`. + posterior (callable): The posterior distribution function `q(z|x)`. + n_samples (int): The number of samples to draw for the IW bound. + bound_type (str): The type of variational bound to compute. + bound_fn (callable): The function used to compute the variational bound. + """ + + def __init__(self, joint, posterior, bound_type="elbo", n_samples=1): + super().__init__() + self.joint = joint + self.posterior = posterior + self.n_samples = n_samples + self.bound_type = bound_type + + if bound_type == "elbo": + self.bound_fn = self.elbo + elif bound_type == "iw": + self.bound_fn = self.iw_bound + elif bound_type == "renyi": + raise NotImplementedError("Renyi bound not implemented yet") + elif bound_type == "forward_kl": + raise NotImplementedError("Forward KL bound not implemented yet") + else: + raise ValueError("Unknown bound type") + + @property + def n_rvs(self): + """ + Compute the number of random variables (RVs) in the evidence + + Returns: + int: The number of random variables in the evidence. + """ + return self.joint.n_rvs - self.posterior.n_rvs + + def forward(self, *obs, cond=None, n_samples=None): + """ + Forward pass to compute the selected variational bound. + + Args: + obs (tuple): Observations `x`, passed as variable length arguments. + cond (optional): Conditioning variables for the bound function. + n_samples (int, optional): The number of samples to draw for the IW bound. Default is `self.n_samples`. + + Returns: + torch.Tensor: The computed variational bound. + """ + return self.bound_fn(*obs, cond=cond, n_samples=n_samples) + + def elbo(self, *obs, cond=None, n_samples=None): + """ + Compute the Evidence Lower Bound (ELBO). + + Args: + obs (tuple): Observations `x`, passed as variable length arguments. + cond (optional): Conditioning variables for the ELBO. + n_samples (int, optional): The number of samples to draw. Default is `self.n_samples`. + + Returns: + torch.Tensor: The computed ELBO. + """ + if n_samples is None: + n_samples = self.n_samples + + return ELBO_joint( + self.joint, + self.posterior, + *obs, + n_samples=n_samples, + ) + + def iw_bound(self, *obs, cond=None, n_samples=None): + """ + Compute the Importance Weighted (IW) bound. + + Args: + obs (tuple): Observations `x`, passed as variable length arguments. + cond (optional): Conditioning variables for the IW bound. + n_samples (int, optional): The number of samples to draw. Default is `self.n_samples`. + + Returns: + torch.Tensor: The computed IW bound. + """ + if n_samples is None: + n_samples = self.n_samples + + return iw_joint( + self.joint, + self.posterior, + *obs, + n_samples=n_samples, + ) + + def sample(self, sample_shape=torch.Size([]), cond=None): + """ + Sample from the joint distribution `p(x, z)`. + + Args: + sample_shape (torch.Size, optional): The shape of the sample. Default is an empty shape. + cond (optional): Conditioning variables for the joint distribution. + + Returns: + tuple: The sampled variables corresponding to the random variables in the evidence. + """ + samples = self.joint.sample(sample_shape=sample_shape, cond=cond) + return squeeze_tuple(samples[-self.n_rvs :]) + + def rsample(self, sample_shape=torch.Size([]), cond=None): + """ + Reparameterized sample from the joint distribution `p(x, z)`. + + Args: + sample_shape (torch.Size, optional): The shape of the sample. Default is an empty shape. + cond (optional): Conditioning variables for the joint distribution. + + Returns: + tuple: The reparameterized sampled variables corresponding to the random variables for the evidence. + """ + samples = self.joint.rsample(sample_shape=sample_shape, cond=cond) + return squeeze_tuple(samples[-self.n_rvs :]) From 484ad7f8fe5c543ac2668106fb11879aa24c55c4 Mon Sep 17 00:00:00 2001 From: Suhas Shrinivasan Date: Wed, 14 Aug 2024 11:37:09 +0000 Subject: [PATCH 14/14] Add bound compute method based on string arg --- gensn/variational.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/gensn/variational.py b/gensn/variational.py index 2de8fc5..05826e9 100644 --- a/gensn/variational.py +++ b/gensn/variational.py @@ -280,6 +280,32 @@ def iw_bound(self, *obs, cond=None, n_samples=None): n_samples=n_samples, ) + def compute_bound(self, *obs, cond=None, n_samples=None, bound_type=None): + """ + Compute the selected variational bound based on string input. + + Args: + obs (tuple): Observations `x`, passed as variable length arguments. + cond (optional): Conditioning variables for the bound function. + n_samples (int, optional): The number of samples to draw for the IW bound. Default is `self.n_samples`. + bound_type (str, optional): The type of variational bound to compute. + + Returns: + torch.Tensor: The computed variational bound. + """ + if bound_type is None: + bound_type = self.bound_type + if bound_type == "elbo": + return self.elbo(*obs, cond=cond, n_samples=n_samples) + elif bound_type == "iw": + return self.iw_bound(*obs, cond=cond, n_samples=n_samples) + elif bound_type == "renyi": + raise NotImplementedError("Renyi bound not implemented yet") + elif bound_type == "forward_kl": + raise NotImplementedError("Forward KL bound not implemented yet") + else: + raise ValueError("Unknown bound type") + def sample(self, sample_shape=torch.Size([]), cond=None): """ Sample from the joint distribution `p(x, z)`.