[Task]: Optimize logpartition and gradlogpartition for MvNormal distribution #271

@Nimrais

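The current code for logpartition and gradlogpartition is far from optimal: it forms an explicit matrix inverse where a linear solve against the Cholesky factor suffices. A straightforward change improves it.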

For example, the current implementation of logpartition is:

k = length(η₁)
Cinv, l = FastCholesky.cholinv_logdet(-η₂)
return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2

It can be rewritten as follows:

k = length(η₁)
F = FastCholesky.fastcholesky(-η₂)
l = logdet(F)
sol = F \ η₁
return (dot(η₁, sol) / 2 - (k * log(2) + l)) / 2
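
The two forms should agree because dot(η₁, inv(-η₂), η₁) equals dot(η₁, (-η₂) \ η₁); the solve only skips building the k×k inverse. A minimal sketch to check this numerically is below. It assumes LinearAlgebra.cholesky and inv as stand-ins for FastCholesky.fastcholesky and FastCholesky.cholinv_logdet, and the names lp_inverse and lp_solve are hypothetical, chosen only for this illustration.

```julia
using LinearAlgebra

# Assumption: LinearAlgebra.cholesky stands in for FastCholesky.fastcholesky,
# and inv(F) stands in for the inverse returned by FastCholesky.cholinv_logdet.

# Inverse-based form: builds the full k×k inverse of -η₂.
function lp_inverse(η₁, η₂)
    k = length(η₁)
    F = cholesky(Symmetric(-η₂))
    Cinv = inv(F)
    return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + logdet(F))) / 2
end

# Solve-based form: reuses the factorization, no explicit inverse.
function lp_solve(η₁, η₂)
    k = length(η₁)
    F = cholesky(Symmetric(-η₂))
    sol = F \ η₁
    return (dot(η₁, sol) / 2 - (k * log(2) + logdet(F))) / 2
end

# Small fixed case in which -η₂ is symmetric positive definite.
η₁ = [1.0, 2.0, 3.0]
η₂ = -[4.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 2.0]

# Both forms should match up to floating-point rounding.
agree = lp_inverse(η₁, η₂) ≈ lp_solve(η₁, η₂)
```

In the one-dimensional case η₁ = [2.0], η₂ = [-1.0] both forms reduce to (2 - log(2)) / 2, which makes a convenient hand check.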

A script to benchmark the two versions:

using LinearAlgebra
using Random
using BenchmarkTools
using FastCholesky
using Plots
using ProgressMeter

"""
    lp_old(η₁, η₂)

Reference implementation using inverse + logdet from a Cholesky factorization.
"""
function lp_old(η₁::AbstractVector{T}, η₂::AbstractMatrix{T}) where {T<:Real}
    k = length(η₁)
    Cinv, l = FastCholesky.cholinv_logdet(-η₂)
    return (dot(η₁, Cinv, η₁) / 2 - (k * log(2) + l)) / 2
end

"""
    lp_new(η₁, η₂)

Optimized implementation using a single Cholesky factorization and triangular solves.
"""
function lp_new(η₁::AbstractVector{T}, η₂::AbstractMatrix{T}) where {T<:Real}
    k = length(η₁)
    F = FastCholesky.fastcholesky(-η₂)
    l = logdet(F)
    sol = F \ η₁
    return (dot(η₁, sol) / 2 - (k * log(2) + l)) / 2
end

"""
    rand_spd(k; T=Float64)

Generate a random k×k symmetric positive definite matrix.
"""
function rand_spd(k::Int; T::Type{<:Real}=Float64)
    X = randn(T, k, k)
    return Symmetric(X' * X + k * I)
end

mkpath("./benchmark_logs")

dims = collect(8:8:48)
T = Float64
rng = Random.default_rng()
Random.seed!(rng, 20251118)

times_old = similar(dims, Float64)
times_new = similar(dims, Float64)

println("Benchmarking log-partition across dimensions: $(first(dims))..$(last(dims))")
@showprogress 1 "Sweeping dimensions..." for (i, k) in enumerate(dims)
    η₁ = randn(rng, T, k)
    A  = rand_spd(k; T)
    η₂ = -Matrix{T}(A)  # ensure -η₂ is SPD

    t_old = @belapsed lp_old($η₁, $η₂)
    t_new = @belapsed lp_new($η₁, $η₂)

    times_old[i] = t_old
    times_new[i] = t_new
end

plt = plot(
    dims, times_old;
    label = "inverse-based",
    xlabel = "dimension k",
    ylabel = "time (s)",
    legend = :topleft,
    lw = 2,
    marker = :circle,
)
plot!(plt, dims, times_new; label = "cholesky-solve", lw = 2, marker = :square)

savepath = "./benchmark_logs/mvnormal_logpartition_vs_dim.png"
savefig(plt, savepath)
println("Saved plot to: $savepath")

[Image: benchmark plot, time (s) vs. dimension k, inverse-based vs. cholesky-solve]
