Absence of reparameterization for pyro.deterministic goes unnoticed #560

@PoorvaGarg

Description

Consider the code snippet below. The model uses the primitive pyro.deterministic and then conditions on that variable, but when the log-probabilities are queried they do not reflect the conditioning. This can be remedied in Chirho using KernelSoftConditionReparam and AutoSoftConditioning, but code like the following should raise an explicit error so that users are not silently left without a reparameterization.

import pyro
import torch
import pyro.distributions as dist
from chirho.observational.handlers import condition

def model():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.deterministic("Y", X)  # Y is a deterministic function of X
    return {"X": X, "Y": Y}

# Condition on the deterministic site Y
conditioned_model = condition(data={"Y": torch.tensor(1.0)})(model)

with pyro.plate("sample", size=5):
    with pyro.poutine.trace() as tr:
        conditioned_model()

# The conditioning leaves no imprint on the recorded log-probabilities
tr.trace.compute_log_prob()
print("X values: ", tr.trace.nodes["X"]["value"])
print("Y log_probs: ", tr.trace.nodes["Y"]["log_prob"])

Output (note that X is 0. in some draws even though Y = X is conditioned to be 1., and the log-probabilities of Y are all zero):

X values:  tensor([1., 1., 0., 0., 1.])
Y log_probs:  tensor([0., 0., 0., 0., 0.])
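The zeros arise because pyro.deterministic records its site as a Delta distribution with mask(False), so its log-density is identically zero no matter what value is observed there. Until an explicit error exists upstream, a guard of roughly the following shape could catch the situation. This is only a sketch: the helper name is hypothetical, and it operates on a simplified dict of nodes with the same keys a Pyro trace uses ("type", "is_observed", and "infer", where pyro.deterministic marks its sites with infer={"_deterministic": True}).

```python
def assert_no_observed_deterministic(nodes):
    """Raise if any observed sample site was created by pyro.deterministic.

    `nodes` maps site names to node dicts shaped like the entries of
    trace.nodes in a Pyro trace (simplified here for illustration).
    """
    for name, node in nodes.items():
        if (
            node.get("type") == "sample"
            and node.get("is_observed")
            and node.get("infer", {}).get("_deterministic")
        ):
            raise ValueError(
                f"Site '{name}' was created by pyro.deterministic but is "
                "being conditioned on; its log_prob is identically zero. "
                "Reparameterize it (e.g. with KernelSoftConditionReparam) "
                "before conditioning."
            )


# Simplified shape of the trace produced by the snippet above:
nodes = {
    "X": {"type": "sample", "is_observed": False, "infer": {}},
    "Y": {
        "type": "sample",
        "is_observed": True,
        "infer": {"_deterministic": True},
    },
}

try:
    assert_no_observed_deterministic(nodes)
except ValueError as e:
    print(e)  # flags site 'Y'
```

Running such a check right after tracing the conditioned model would turn the silent zero log-probabilities into a loud failure.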
