Consider the following code snippet. It defines a model that uses the `pyro.deterministic` primitive and then conditions on that deterministic variable. When the log-probabilities are queried afterwards, they do not reflect the conditioning: every entry of `Y`'s `log_prob` is zero, even for samples where `X` is 0 and therefore disagrees with the observed value `Y = 1`. This has been remedied in Chirho using `KernelSoftConditionReparam` and `AutoSoftConditioning`, but the code below should raise an explicit error so that users cannot silently forget to reparameterize.
```python
import pyro
import torch
import pyro.distributions as dist
from chirho.observational.handlers import condition


def model():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    # Y is a deterministic function of X
    Y = pyro.deterministic("Y", X)
    return {"X": X, "Y": Y}


# Condition directly on the deterministic site Y, with no soft-conditioning reparameterization
conditioned_model = condition(data={"Y": torch.tensor(1.0)})(model)

with pyro.plate("sample", size=5):
    with pyro.poutine.trace() as tr:
        conditioned_model()

tr.trace.compute_log_prob()
print("X values: ", tr.trace.nodes["X"]["value"])
print("Y log_probs: ", tr.trace.nodes["Y"]["log_prob"])
```
Output:

```
X values: tensor([1., 1., 0., 0., 1.])
Y log_probs: tensor([0., 0., 0., 0., 0.])
```