From c5df7677e673cb4b3f462aae2a2c0d8fcaf4b9a6 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 18 Sep 2025 14:28:34 -0400 Subject: [PATCH] first push --- ext/ComradeDynamicHMCExt.jl | 62 +++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 ext/ComradeDynamicHMCExt.jl diff --git a/ext/ComradeDynamicHMCExt.jl b/ext/ComradeDynamicHMCExt.jl new file mode 100644 index 000000000..91000a977 --- /dev/null +++ b/ext/ComradeDynamicHMCExt.jl @@ -0,0 +1,62 @@ +module ComradeDynamicHMCExt + +using Comrade +using DynamicHMC + +using DocStringExtensions +using HypercubeTransform +using Random +using LogDensityProblems +using Serialization +using StatsBase + + +function DynamicHMC.mcmc_with_warmup( + rng, post::Comrade.VLBIPosterior, N; + kwargs... + ) + + if isnothing(Comrade.admode(post)) + throw(ArgumentError("You must specify an automatic differentiation type in VLBIPosterior with admode kwarg")) + else + tpost = asflat(post) + end + + results = mcmc_with_warmup(rng, tpost, N; kwargs...) + + return PosteriorSamples( + transform.(Ref(tpost), eachcol(results.posterior_matrix)), + results.tree_statistics; + metadata = Dict( + :sampler = :DynamicHMC, :ϵ => results.ϵ, + :mass_matrix => results.κ + ) + ) +end + +function DynamicHMC.mcmc_with_warmup( + rng, post::Comrade.VLBIPosterior, N, output::DiskStore; + kwargs... + ) + + (; name, stride) = output + stride = min(stride, N) + nscans = nsamples ÷ output_stride + (nsamples % output_stride != 0 ? 1 : 0) + outbase = joinpath(name, "samples", "output_scan_") + + + tpost = asflat(post) + results = mcmc_with_warmup(rng, tpost, N; kwargs...) + + return PosteriorSamples( + transform.(Ref(tpost), eachcol(results.posterior_matrix)), + results.tree_statistics; + metadata = Dict( + :sampler = :DynamicHMC, :ϵ => results.ϵ, + :mass_matrix => results.κ + ) + ) +end + + +end