Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08"
FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please implement it as a julia extension. You can read about julia extensions here https://discourse.julialang.org/t/quick-tutorial-on-package-extensions/130923.

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Expand All @@ -22,6 +23,9 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"

[extensions]
ClosedFormExpectationsExt = "ClosedFormExpectations"

Expand All @@ -30,6 +34,7 @@ BayesBase = "1.5.0"
Bumper = "0.6"
ClosedFormExpectations = "0.3.0"
Distributions = "0.25"
DomainSets = "0.7.16"
ExponentialFamily = "2.0.0"
ExponentialFamilyManifolds = "3.0.3"
FastCholesky = "1.3"
Expand All @@ -47,13 +52,11 @@ StaticArrays = "1.9"
StatsFuns = "1.3"
julia = "1.10"

[weakdeps]
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand All @@ -66,4 +69,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]
test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "DomainSets", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"]
2 changes: 2 additions & 0 deletions src/ExponentialFamilyProjection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module ExponentialFamilyProjection

using ForwardDiff

using ExponentialFamily,
ExponentialFamilyManifolds,
BayesBase,
Expand Down
39 changes: 37 additions & 2 deletions src/strategies/bonnet/gauss_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,43 @@ end

get_nsamples(strategy::GaussNewton) = strategy.nsamples

preprocess_strategy_argument(strategy::GaussNewton{S,TL}, argument::Any) where {S,TL} =
(strategy, convert(TL, argument))
get_default_InplaceLogpdfGradHess(argument::ContinuousUnivariateLogPdf) = begin
function __logpdf!(out, x)
out[1] = argument.logpdf(x)
return out
end
function __grad_hess!(out_grad, out_hess, x)
# Note that for univariate, the gradient is the derivative
# and the Hessian is the second derivative
out_grad .= ForwardDiff.derivative(argument.logpdf, x)
out_hess .= ForwardDiff.derivative(x -> ForwardDiff.derivative(argument.logpdf, x), x)
return out_grad, out_hess
end
default_inplace = InplaceLogpdfGradHess(__logpdf!, __grad_hess!)
return default_inplace
end

get_default_InplaceLogpdfGradHess(argument::ContinuousMultivariateLogPdf) = begin
function __logpdf!(out, x)
out[1] = argument.logpdf(x)
return out
end
function __grad_hess!(out_grad, out_hess, x)
out_grad .= ForwardDiff.gradient(argument.logpdf, x)
out_hess .= ForwardDiff.hessian(argument.logpdf, x)
return out_grad, out_hess
end
default_inplace = InplaceLogpdfGradHess(__logpdf!, __grad_hess!)
return default_inplace
end

preprocess_strategy_argument(strategy::GaussNewton{S,TL}, argument::Any) where {S,TL} = begin
if argument isa Union{ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf}
return (strategy, get_default_InplaceLogpdfGradHess(argument))
else
(strategy, convert(TL, argument))
end
end
preprocess_strategy_argument(::GaussNewton, argument::AbstractArray) = error(
lazy"The `GaussNewton` strategy requires the projection argument to be a callable object (e.g. `Function`) or an `InplaceLogpdfGradHess`. Got `$(typeof(argument))` instead.",
)
Expand Down
45 changes: 45 additions & 0 deletions test/strategies/gauss_newton_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,48 @@ end
end
end
end

@testitem "Test default GaussNewton strategy given ContinuousUni-/Multi-variateLogPdf" begin
using ExponentialFamily, BayesBase
import ExponentialFamilyProjection:
GaussNewton,
getstrategy,
preprocess_strategy_argument,
get_default_InplaceLogpdfGradHess
import ExponentialFamily:
NormalMeanVariance,
MvNormalMeanCovariance
import BayesBase: ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf
import LinearAlgebra: Diagonal
import DomainSets: ℝ3

# univariate case
a1 = NormalMeanVariance(-10,0.1)
my_logpdf(x) = logpdf(a1, x)
my_uni_continuous_logpdf = ContinuousUnivariateLogPdf(my_logpdf)
default_uni_inplace = get_default_InplaceLogpdfGradHess(my_uni_continuous_logpdf)
params = ProjectionParameters(niterations=2000, strategy = GaussNewton())
prj = ProjectedTo(NormalMeanVariance; parameters = params)

test_uni_cont_proj = project_to(prj, my_uni_continuous_logpdf)
test_uni_inplace_proj = project_to(prj, default_uni_inplace)
@test test_uni_cont_proj ≈ test_uni_inplace_proj atol=1e-6
@test test_uni_cont_proj isa NormalMeanVariance
@test mean(test_uni_cont_proj) ≈ mean(a1) atol=1e-6
@test var(test_uni_cont_proj) ≈ var(a1) atol=1e-6

# multivariate case
a2 = MvNormalMeanCovariance([1.3, -5, 30.0], Diagonal([0.5, 2.0, 1.0]))
my_logpdf(x) = logpdf(a2, x)
my_mv_continuous_logpdf = ContinuousMultivariateLogPdf(ℝ3, my_logpdf)
default_mv_inplace = get_default_InplaceLogpdfGradHess(my_mv_continuous_logpdf)
params = ProjectionParameters(niterations=2000, strategy = GaussNewton())
prj = ProjectedTo(MvNormalMeanCovariance, 3; parameters = params)

test_mv_cont_proj = project_to(prj, my_mv_continuous_logpdf)
test_mv_inplace_proj = project_to(prj, default_mv_inplace)
@test test_mv_cont_proj ≈ test_mv_inplace_proj atol=1e-6
@test test_mv_cont_proj isa MvNormalMeanCovariance
@test mean(test_mv_cont_proj) ≈ mean(a2) atol=1e-5
@test cov(test_mv_cont_proj) ≈ cov(a2) atol=1e-6
end
Loading