-
Notifications
You must be signed in to change notification settings - Fork 4
Open
Description
The code for `logpartition` and `gradlogpartition` is quite suboptimal, and it can be improved in a fairly straightforward way.
For example, the current implementation of `logpartition` is:
k = length(η₁)
Cinv, l = FastCholesky.cholinv_logdet(-η₂)
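return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2

It can be rewritten to reuse a single Cholesky factorization for both the log-determinant and a linear solve, instead of forming the inverse explicitly: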
k = length(η₁)
F = FastCholesky.fastcholesky(-η₂)
l = logdet(F)
sol = F \ η₁
return (dot(η₁, sol) / 2 - (k * log(2) + l)) / 2

The script to compare them:
using LinearAlgebra
using Random
using BenchmarkTools
using FastCholesky
using Plots
using ProgressMeter
"""
lp_old(η₁, η₂)
Reference implementation using inverse + logdet from a Cholesky factorization.
"""
function lp_old(η₁::AbstractVector{T}, η₂::AbstractMatrix{T}) where {T<:Real}
k = length(η₁)
Cinv, l = FastCholesky.cholinv_logdet(-η₂)
return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2
end
"""
lp_new(η₁, η₂)
Optimized implementation using a single Cholesky factorization and triangular solves.
"""
function lp_new(η₁::AbstractVector{T}, η₂::AbstractMatrix{T}) where {T<:Real}
k = length(η₁)
F = FastCholesky.fastcholesky(-η₂)
l = logdet(F)
sol = F \ η₁
return (dot(η₁, sol) / 2 - (k * log(2) + l)) / 2
end
"""
rand_spd(k; T=Float64)
Generate a random k×k symmetric positive definite matrix.
"""
function rand_spd(k::Int; T::Type{<:Real}=Float64)
X = randn(T, k, k)
return Symmetric(X' * X + k * I)
end
mkpath("./benchmark_logs")
dims = collect(8:8:48)
T = Float64
rng = Random.default_rng()
Random.seed!(rng, 20251118)
times_old = similar(dims, Float64)
times_new = similar(dims, Float64)
println("Benchmarking log-partition across dimensions: $(first(dims))..$(last(dims))")
@showprogress 1 "Sweeping dimensions..." for (i, k) in enumerate(dims)
η₁ = randn(rng, T, k)
A = rand_spd(k; T)
η₂ = -Matrix{T}(A) # ensure -η₂ is SPD
t_old = @belapsed lp_old($η₁, $η₂)
t_new = @belapsed lp_new($η₁, $η₂)
times_old[i] = t_old
times_new[i] = t_new
end
plt = plot(
dims, times_old;
label = "inverse-based",
xlabel = "dimension k",
ylabel = "time (s)",
legend = :topleft,
lw = 2,
marker = :circle,
)
plot!(plt, dims, times_new; label = "cholesky-solve", lw = 2, marker = :square)
savepath = "./benchmark_logs/mvnormal_logpartition_vs_dim.png"
savefig(plt, savepath)
println("Saved plot to: $savepath")
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels