diff --git a/src/datafit.jl b/src/datafit.jl index b40fb0f..12bf499 100644 --- a/src/datafit.jl +++ b/src/datafit.jl @@ -1,3 +1,23 @@ +""" + _pprior_samples(chain, i) + +Extract the posterior samples of `pprior[i]` from a Turing sampling result, supporting +both chain backends across Turing versions: the `FlexiChains.VNChain` returned by newer +Turing (indexed by `@varname(pprior[i])`) and the legacy `MCMCChains.Chains` +(indexed by the `"pprior[i]"` string key). +""" +function _pprior_samples(chain, i) + vn = @varname(pprior[i]) + samples = try + chain[vn] + catch err + err isa Union{MethodError, KeyError, ArgumentError} || + rethrow(err) + chain["pprior[" * string(i) * "]"] + end + return collect(samples)[:] +end + function l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}}) p = Pair.(pkeys, pvals) prob = remake(prob, tspan = (prob.tspan[1], t[end]), p = p) @@ -6,7 +26,7 @@ function l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}}) for pairs in data tot_loss += sum((sol[pairs.first] .- pairs.second) .^ 2) end - return tot_loss, sol + return tot_loss end function l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) @@ -22,7 +42,7 @@ function l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) for i in 1:length(ts) tot_loss += sum((sol(ts[i]; idxs = datakeys[i]) .- timeseries[i]) .^ 2) end - return tot_loss, sol + return tot_loss end function relative_l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}}) @@ -33,7 +53,7 @@ function relative_l2loss(pvals, (prob, pkeys, t, data)::Tuple{Vararg{Any, 4}}) for pairs in data tot_loss += sum(((sol[pairs.first] .- pairs.second) ./ sol[pairs.first]) .^ 2) end - return tot_loss, sol + return tot_loss end function relative_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) @@ -50,7 +70,7 @@ function relative_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) vals = sol(ts[i]; idxs = datakeys[i]) tot_loss += sum(((vals .- timeseries[i]) ./ vals) .^ 2) end - return tot_loss, sol + return tot_loss end """ @@ -223,7 +243,7 @@ Turing.@model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior) prob = remake(prob, tspan = (prob.tspan[1], t[end]), p = Pair.(pkeys, pprior)) sol = solve(prob, saveat = t) if !SciMLBase.successful_retcode(sol) - Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) + Turing.@addlogprob! -Inf return nothing end for i in eachindex(data) @@ -249,7 +269,7 @@ Turing.@model function bayesianODE( prob = remake(prob, tspan = (prob.tspan[1], lastt), p = Pair.(pkeys, pprior)) sol = solve(prob) if !SciMLBase.successful_retcode(sol) - Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf) + Turing.@addlogprob! -Inf return nothing end for i in eachindex(datakeys) @@ -308,7 +328,7 @@ function bayesian_datafit( progress = true ) return [ - Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:]) + Pair(p[i].first, _pprior_samples(chain, i)) for i in eachindex(p) ] end @@ -333,7 +353,7 @@ function bayesian_datafit( progress = true ) return [ - Pair(p[i].first, collect(chain["pprior[" * string(i) * "]"])[:]) + Pair(p[i].first, _pprior_samples(chain, i)) for i in eachindex(p) ] end diff --git a/src/ensemble.jl b/src/ensemble.jl index ded0389..77e6e99 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -19,7 +19,7 @@ dataset on which the ensembler should be trained on. function ensemble_weights(sol::EnsembleSolution, data_ensem) obs = first.(data_ensem) predictions = reduce( - vcat, reduce(hcat, [sol[i][s] for i in 1:length(sol)]) for s in obs + vcat, reduce(hcat, [sol.u[i][s] for i in 1:length(sol.u)]) for s in obs ) data = reduce( vcat, @@ -31,6 +31,23 @@ function ensemble_weights(sol::EnsembleSolution, data_ensem) return weights = predictions \ data end +""" + EnsembleProbForwarder(all_probs) + +Callable used as the `prob_func` of the `EnsembleProblem` returned by +[`bayesian_ensemble`](@ref). It selects the per-trajectory problem from the stored +`all_probs` vector. It supports both the `prob_func(prob, ctx)` interface of newer +SciMLBase (selecting via `ctx.sim_id`) and the legacy `prob_func(prob, i, repeat)` +interface (selecting via the integer index). Storing `all_probs` lets callers recover +the number of trajectories via `enprob.prob_func.all_probs`. +""" +struct EnsembleProbForwarder{P} + all_probs::P +end + +(f::EnsembleProbForwarder)(prob, i::Integer, repeat) = f.all_probs[i] +(f::EnsembleProbForwarder)(prob, ctx) = f.all_probs[ctx.sim_id] + function bayesian_ensemble( probs, ps, datas; noise_prior = InverseGamma(2, 3), @@ -56,5 +73,7 @@ function bayesian_ensemble( @info "$(length(all_probs)) total models" - return enprob = EnsembleProblem(all_probs) + return enprob = EnsembleProblem( + all_probs[1]; prob_func = EnsembleProbForwarder(all_probs) + ) end diff --git a/src/sensitivity.jl b/src/sensitivity.jl index ea52199..4f5f0ed 100644 --- a/src/sensitivity.jl +++ b/src/sensitivity.jl @@ -2,7 +2,8 @@ function _get_sensitivity(prob, t, x, pbounds; samples) boundvals = getfield.(pbounds, :second) boundkeys = getfield.(pbounds, :first) f = function (p) - prob_func(prob, i, repeat) = remake(prob; p = Pair.(boundkeys, p[:, i])) + prob_func(prob, i::Integer, repeat) = remake(prob; p = Pair.(boundkeys, p[:, i])) + prob_func(prob, ctx) = remake(prob; p = Pair.(boundkeys, p[:, ctx.sim_id])) ensemble_prob = EnsembleProblem(prob, prob_func = prob_func) sol = solve( ensemble_prob, nothing, EnsembleThreads(); saveat = t, @@ -11,11 +12,11 @@ function _get_sensitivity(prob, t, x, pbounds; samples) out = zeros(size(p, 2)) if x isa Function for i in 1:size(p, 2) - out[i] = x(sol[i]) + out[i] = x(sol.u[i]) end else for i in 1:size(p, 2) - out[i] = sol[i](t; idxs = x) + out[i] = sol.u[i](t; idxs = x) end end return out diff --git a/src/threshold.jl b/src/threshold.jl index 3ab0039..9a22f0e 100644 --- a/src/threshold.jl +++ b/src/threshold.jl @@ -32,6 +32,30 @@ function get_threshold(prob, obs, threshold; alg = nothing, kw...) return sol.t[end] end +# Decompose a symbolic threshold inequality (e.g. `x > 10.0`) into the state it +# constrains, the numeric bound, and whether the violating side is the upper one +# (`maximum(state) > bound`) or the lower one (`minimum(state) < bound`). +# Symbolics canonicalizes comparisons, so `x > 10.0` is stored as `<(10.0, x)`: +# the constant can land on either side of the operator, which this normalizes. +function _threshold_violation(threshold) + v = ModelingToolkit.value(threshold) + op = operation(v) + args = arguments(v) + isconst(z) = ModelingToolkit.value(z) isa Number + if isconst(args[1]) + bound = ModelingToolkit.value(args[1]) + state = args[2] + # `bound op state`: `bound < state` ⟺ `state > bound` (upper violation). + upper = (op === <) || (op === <=) + else + bound = ModelingToolkit.value(args[2]) + state = args[1] + # `state op bound`: `state > bound`/`state >= bound` is the upper violation. + upper = (op === >) || (op === >=) + end + return state, bound, upper +end + """ prob_violating_thresholdd(prob, p, thresholds) @@ -46,16 +70,15 @@ function prob_violating_threshold(prob, p, thresholds) h(x, u, p) = u, remake(prob, p = Pair.(pkeys, [x...])).p # remake does not work well with static arrays function g(sol, p) for threshold in thresholds - if (threshold.val.f == >) || (threshold.val.f == >=) - if maximum(sol[threshold.val.arguments[1]]) > threshold.val.arguments[2] + state, bound, upper = _threshold_violation(threshold) + if upper + if maximum(sol[state]) > bound return 1.0 end - elseif (threshold.val.f == <) || (threshold.val.f == <=) - if minimum(sol[threshold.val.arguments[1]]) < threshold.val.arguments[2] + else + if minimum(sol[state]) < bound return 1.0 end - else - error() end end return 0.0 diff --git a/test/ensemble.jl b/test/ensemble.jl index 08bfa7e..7f1749b 100644 --- a/test/ensemble.jl +++ b/test/ensemble.jl @@ -52,15 +52,18 @@ eqs = [ @mtkbuild sys3 = ODESystem(eqs, t) prob3 = ODEProblem(sys3, [], tspan); -enprob = EnsembleProblem([prob, prob2, prob3]) +probs = [prob, prob2, prob3] +prob_func(prob, i::Integer, repeat) = probs[i] +prob_func(prob, ctx) = probs[ctx.sim_id] +enprob = EnsembleProblem(probs[1]; prob_func = prob_func) -sol = solve(enprob; saveat = 1); +sol = solve(enprob, Tsit5(); saveat = 1, trajectories = length(probs)); weights = [0.2, 0.5, 0.3] -fullS = vec(sum(stack(weights .* sol[S, :]), dims = 2)) -fullI = vec(sum(stack(weights .* sol[I, :]), dims = 2)) -fullR = vec(sum(stack(weights .* sol[R, :]), dims = 2)) +fullS = vec(sum(stack(weights .* [sol.u[i][S] for i in 1:length(sol.u)]), dims = 2)) +fullI = vec(sum(stack(weights .* [sol.u[i][I] for i in 1:length(sol.u)]), dims = 2)) +fullR = vec(sum(stack(weights .* [sol.u[i][R] for i in 1:length(sol.u)]), dims = 2)) t_train = 0:14 data_train = [ @@ -81,14 +84,16 @@ data_forecast = [ R => (t_forecast, fullR), ] -sol = solve(enprob; saveat = t_ensem); +sol = solve(enprob, Tsit5(); saveat = t_ensem, trajectories = length(probs)); @test ensemble_weights(sol, data_ensem) ≈ [0.2, 0.5, 0.3] -probs = [prob, prob2, prob3] ps = [[β => Uniform(0.01, 10.0), γ => Uniform(0.01, 10.0)] for i in 1:3] datas = [data_train, data_train, data_train] enprobs = bayesian_ensemble(probs, ps, datas) -sol = solve(enprobs; saveat = t_ensem); +sol = solve( + enprobs, Tsit5(); saveat = t_ensem, + trajectories = length(enprobs.prob_func.all_probs) +); ensemble_weights(sol, data_ensem)