Adding support for load_state_dict with assign=True for priors of Transformed distributions #2691
hvarfner wants to merge 2 commits into cornellius-gp:main
Conversation
gpytorch/priors/prior.py
Outdated
tensor_value = (
    value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
)
should be a noop if already a tensor...
Suggested change:
- tensor_value = (
-     value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
- )
+ tensor_value = torch.as_tensor(value)
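As a quick illustration of the reviewer's point (a plain fact about torch.as_tensor, not code from this PR): torch.as_tensor returns its input unchanged when it is already a tensor and no dtype or device conversion is requested, so the isinstance check is redundant.

```python
import torch

# torch.as_tensor is a no-op (returns the same object) for tensors when no
# dtype/device conversion is needed, so wrapping it in an isinstance check
# adds nothing.
t = torch.tensor([1.0, 2.0])
assert torch.as_tensor(t) is t
```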
# Update the base attribute in the base distribution
self.base_dist.__setattr__(base_attr_name, tensor_value)
# Update the transformed attribute as well
super().__setattr__(name, tensor_value)
Wouldn't we have to save the untransformed value for the attribute of the base distribution here? Seems odd that we assign the same value to both base dist and dist...
@Balandat _transformed_ is just a buffer copy of the base attribute, indicating that it comes from a torch.distributions.TransformedDistribution, not a transformed version. I could change the name to something like _buffered_ instead.
Recall that the issue is that the base attributes (loc, scale, etc.) are a @property of the base distribution, e.g. torch.distributions.Normal, so we can't bufferize these on the LogNormalPrior. Thus, we need to bufferize an attribute containing the same info, which we can then use to set the loc and scale on the base distribution.
@Balandat Added a note on this in a new version:
# Note: "_transformed_" is just an indicator that this attribute belongs to a
# TransformedDistribution, the value itself is not transformed.
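To make the constraint concrete, here is a minimal sketch assuming a composition-based module (the TransformedPriorSketch class and its buffer names are hypothetical, not the actual gpytorch implementation) of why loc/scale cannot simply be assigned on a LogNormal and how a separately named buffer can carry the same value:

```python
import torch
from torch.distributions import LogNormal

# On LogNormal (a TransformedDistribution), loc and scale are read-only
# @property accessors that delegate to the base Normal; assigning to dist.loc
# raises AttributeError because the property has no setter.
dist = LogNormal(torch.tensor(0.0), torch.tensor(1.0))
assert isinstance(type(dist).loc, property)

# Illustrative workaround (names are hypothetical): keep buffers with
# different names that mirror the base attributes, and push updates from the
# buffers onto the underlying Normal, which does allow plain assignment.
class TransformedPriorSketch(torch.nn.Module):
    def __init__(self, loc, scale):
        super().__init__()
        self.dist = LogNormal(torch.as_tensor(loc), torch.as_tensor(scale))
        self.register_buffer("_transformed_loc", torch.as_tensor(loc))
        self.register_buffer("_transformed_scale", torch.as_tensor(scale))

    def update_base(self):
        # mirror the buffered values back onto the underlying Normal
        self.dist.base_dist.loc = self._transformed_loc
        self.dist.base_dist.scale = self._transformed_scale
```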
Force-pushed from 6ce060b to 5c174a1
buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
for buffered_attr in buffered_attrs:
    base_attr_name = buffered_attr.replace(buffered_str, "")
    setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
no need to traverse twice here
Suggested change:
- buffered_attrs = [attr for attr in dir(module) if buffered_str in attr]
- for buffered_attr in buffered_attrs:
-     base_attr_name = buffered_attr.replace(buffered_str, "")
-     setattr(module.base_dist, base_attr_name, getattr(module, buffered_attr))
+ for attr in dir(module):
+     if buffered_str in attr:
+         base_attr_name = attr.replace(buffered_str, "")
+         setattr(module.base_dist, base_attr_name, getattr(module, attr))
# TransformedDistribution, NOT that the value itself is transformed.
# The _buffered_ buffer is simply a copy of the base_dist attribute,
# so we assign the same value to both.
if hasattr(self, name) and "_buffered_" in name:
let's make "_buffered_" a constant so that there is a single source of truth
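A minimal sketch of that suggestion (the constant name and helper below are hypothetical, not necessarily what the PR ends up using): define the marker string once at module level and reference it everywhere it is checked.

```python
# Hypothetical module-level constant; the actual name chosen in gpytorch may differ.
TRANSFORMED_BUFFER_PREFIX = "_buffered_"

def is_buffered_attr(name: str) -> bool:
    # single source of truth for recognizing buffered attribute names
    return TRANSFORMED_BUFFER_PREFIX in name
```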
Required changes in GPyTorch to unblock meta-pytorch/botorch#3080. When load_state_dict is called with assign=True, setattr is called on _transformed attributes of the prior at the PyTorch level. This was not the intended use of the _transformed attribute, but it seems like we have to enable its modification directly.
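For context, a minimal sketch of the PyTorch-level mechanism described above (the Toy module and the _transformed_loc buffer name are illustrative, and assign=True requires a PyTorch version that supports it, 2.1+): with assign=True, load_state_dict assigns the incoming tensors via setattr rather than copying them into the existing buffers, so the assignment passes through the module's __setattr__.

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # illustrative buffer name mirroring the _transformed_* buffers on the prior
        self.register_buffer("_transformed_loc", torch.tensor(0.0))

    def __setattr__(self, name, value):
        if "_transformed_" in name and isinstance(value, torch.Tensor):
            print(f"__setattr__ hit for {name}")  # this is the path assign=True takes
        super().__setattr__(name, value)

m = Toy()
# default: copies the new value into the existing buffer, __setattr__ is not called
m.load_state_dict({"_transformed_loc": torch.tensor(1.0)})
# assign=True: the tensor is assigned via setattr, so overrides like
# Prior.__setattr__ must be able to handle these attributes
m.load_state_dict({"_transformed_loc": torch.tensor(2.0)}, assign=True)
```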