Replies: 5 comments 4 replies
-
Tagging @mamagarobonomon as it may be relevant to his project.
-
I'm not sure what the problem is. Have you tried debugging by hard-coding F and Q to constants?
-
@murphyk I've fixed the issue! As soon as we impose the structured noise via the R matrix, the predictions become dramatically better. I'm really happy with the results: it converges nicely in just 10 iterations (under 2 seconds on my Mac).

```julia
D = 6
H = [1.0, 1.0, 0.0, 1.0, 0.0, 1.0]
X = [[temp] for temp in temperature]

# Block-diagonal transition: level, two harmonic (rotation) blocks, and an AR(1) term.
function transition(F)
    FT = eltype(F)
    M = zeros(FT, 6, 6)
    M[1, 1] = one(FT)                                                # level
    M[2, 2] = F[1]; M[2, 3] = F[2]; M[3, 2] = -F[2]; M[3, 3] = F[1]  # daily seasonal
    M[4, 4] = F[3]; M[4, 5] = F[4]; M[5, 4] = -F[4]; M[5, 5] = F[3]  # weekly seasonal
    M[6, 6] = F[5]                                                   # AR coefficient
    return M
end
```
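Each seasonal block in that transition matrix is a 2×2 rotation, which is what makes the harmonics oscillate without growing or decaying. A standalone sanity check, using only LinearAlgebra (the period s = 24 below is a hypothetical stand-in for the daily seasonality, not a value from the model):

```julia
using LinearAlgebra

# Hypothetical daily period; the learned F[1], F[2] would sit near (cos θ, sin θ).
s = 24
θ = 2π / s
B = [cos(θ) sin(θ); -sin(θ) cos(θ)]   # same layout as M[2:3, 2:3] in transition(F)

# One full season of steps returns the block to the identity.
@assert B^s ≈ Matrix{Float64}(I, 2, 2)

# Unit-modulus eigenvalues: the seasonal component neither grows nor decays.
@assert all(λ -> isapprox(abs(λ), 1.0; atol = 1e-10), eigvals(B))
```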
```julia
@model function rxsts(H, X, y, R, priors)
    τy ~ priors[:τy]
    β ~ priors[:β]
    Q ~ Wishart(priors[:Q].df, priors[:Q].S)
    η ~ MvNormalMeanPrecision(mean(priors[:η]), Q)
    zprev ~ priors[:z0]
    F ~ priors[:F]
    for t in eachindex(y)
        z₁[t] ~ ContinuousTransition(zprev, F, diageye(D))
        z₂[t] ~ R * η
        z[t] ~ z₁[t] + z₂[t]
        μ[t] ~ dot(H, z[t]) + dot(X[t], β)
        y[t] ~ Normal(mean = μ[t], precision = τy)
        zprev = z[t]
    end
end

@constraints function rxsts_constraints()
    q(z, z₁, z₂, zprev, F, Q, η, μ, y, τy, β) = q(z, z₁, z₂, zprev)q(F)q(Q)q(η)q(μ, y)q(τy)q(β)
end

@meta function rxsts_meta()
    ContinuousTransition() -> CTMeta(transition)
end
```
```julia
# Selection matrix: routes the 4-dim noise η into the 6-dim state,
# leaving the sine components (rows 3 and 5) noiseless.
R = [
    1 0 0 0
    0 1 0 0
    0 0 0 0
    0 0 1 0
    0 0 0 0
    0 0 0 1
]
```
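The role of R can be checked in isolation: it embeds a 4-dimensional noise covariance into the 6-dimensional state, so the effective state-noise covariance R\*Q\*R' has zero rows and columns exactly at the two sine components. A minimal sketch, with an arbitrary positive-definite matrix standing in for the learned Q:

```julia
using LinearAlgebra

# Same selection matrix as in the model above.
R = [1 0 0 0; 0 1 0 0; 0 0 0 0; 0 0 1 0; 0 0 0 0; 0 0 0 1]

# Arbitrary 4×4 positive-definite stand-in for the learned noise covariance.
A = randn(4, 4)
Q = A * A' + I

Σ = R * Q * R'   # structured 6×6 state-noise covariance

# The sine components (states 3 and 5) receive no independent noise...
@assert all(iszero, Σ[3, :]) && all(iszero, Σ[:, 5])

# ...while the noisy components carry Q through unchanged.
@assert Σ[[1, 2, 4, 6], [1, 2, 4, 6]] == Q
```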
```julia
priors = Dict(
    :τy => GammaShapeRate(10.0, 1.0),
    :β  => MvNormalMeanPrecision(ones(1), diageye(1)),
    :z0 => MvNormalMeanPrecision(ones(D), diageye(D)),
    :F  => MvNormalMeanPrecision([1.0, 1.0, 1.0, 1.0, 1.0], diageye(5)),
    :Q  => Wishart(4, diagm([1.0, 1.0, 1.0, 1.0])),
    :η  => MvNormalMeanPrecision(zeros(4), diageye(4))
)

@initialization function rxsts_init(priors)
    q(τy) = priors[:τy]
    q(F) = priors[:F]
    q(Q) = priors[:Q]
    q(η) = priors[:η]
    μ(β) = priors[:β]
    μ(zprev) = priors[:z0]
    μ(z) = priors[:z0]
end

n_predict = round(Int, length(demand) * 0.1)  # hold out the last 10% for forecasting

results = infer(
    model          = rxsts(H = H, X = X, R = R, priors = priors),
    data           = (y = [demand[1:end-n_predict]; repeat([missing], n_predict)],),
    constraints    = rxsts_constraints(),
    meta           = rxsts_meta(),
    initialization = rxsts_init(priors),
    returnvars     = KeepLast(),
    iterations     = 10,
    showprogress   = true,
    options        = (limit_stack_depth = 500,)
)
```
@Nimrais this example can be used in the future to check issue #570.
-
BTW, now that the Gaussian case works, it would be cool to try a non-conjugate likelihood, such as Poisson. Just for fun, you could also try a nonlinear link function (e.g. a small neural network) to represent p(y|z) = Cat(y|softmax(MLP(z))). A version of causal impact would also be cool.
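A toy, inference-free sketch of that suggested link function, just to pin down the shapes. The layer sizes, random weights, and K = 3 classes here are arbitrary placeholders, and none of this uses RxInfer's API; it only builds p(y|z) as a Categorical via Distributions.jl:

```julia
using Distributions, LinearAlgebra

# Numerically stable softmax.
softmax(v) = (e = exp.(v .- maximum(v)); e ./ sum(e))

# Hypothetical one-hidden-layer MLP mapping the 6-dim state z to 3 class logits.
function categorical_link(z; W1 = randn(8, 6), b1 = zeros(8), W2 = randn(3, 8), b2 = zeros(3))
    logits = W2 * tanh.(W1 * z .+ b1) .+ b2
    return Categorical(softmax(logits))   # p(y | z) = Cat(y | softmax(MLP(z)))
end

p = categorical_link(randn(6))
@assert isapprox(sum(probs(p)), 1.0; atol = 1e-9)
@assert ncategories(p) == 3
```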
-
Hi! I've been exploring the excellent sts-jax repo by @xinglong-li and @murphyk, and thought it would be interesting to implement something similar in RxInferExamples.jl with a twist: learning the transition matrix F parameters rather than fixing them.
I started with the electricity demand example but added priors on the seasonal frequencies and the AR coefficient. However, my forecasts are quite off compared to what I'd expect (see attached plot).
Model structure:

```julia
# state: [level, daily_cos, daily_sin, weekly_cos, weekly_sin, ar]
F  ~ MvNormal([1.0, 0.1, 1.0, 0.1, -1.0], Σ_F)
QR ~ Wishart(8, diag([1e2, 1.0, 1.0, 1.0, 1.0, 1.0]))
```

Suspected issue: In the original sts-jax implementation, you use a selection matrix R to create the structured covariance R*Q*R'. I'm using an unstructured QR ~ Wishart(...) directly, which might be causing the level component to dominate while the seasonals collapse. I tried to work around this with diagonal scaling in the Wishart prior, but it might not be enough.
Dataset: electricity data gist
Minimal code:
Any thoughts on what might be going wrong?
I think supporting this type of model would be great, as we could run different what-if scenarios!
P.S. I'm also concerned about the number of iterations needed here...