diff --git a/Project.toml b/Project.toml index a57a60e..fd8f663 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b" ExponentialFamilyManifolds = "5c9727c4-3b82-4ab3-b165-76e2eb971b08" FastCholesky = "2d5283b6-8564-42b6-bb00-83ed8e915756" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" @@ -22,6 +23,9 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[weakdeps] +ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" + [extensions] ClosedFormExpectationsExt = "ClosedFormExpectations" @@ -30,6 +34,7 @@ BayesBase = "1.5.0" Bumper = "0.6" ClosedFormExpectations = "0.3.0" Distributions = "0.25" +DomainSets = "0.7.16" ExponentialFamily = "2.0.0" ExponentialFamilyManifolds = "3.0.3" FastCholesky = "1.3" @@ -47,13 +52,11 @@ StaticArrays = "1.9" StatsFuns = "1.3" julia = "1.10" -[weakdeps] -ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" - [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ClosedFormExpectations = "70ff922c-62d4-418d-abfc-e284e489b734" +DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -66,4 +69,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] +test = ["Test", "Aqua", "BenchmarkTools", "ClosedFormExpectations", "DomainSets", "Hwloc", "Plots", "Printf", "ForwardDiff", "Manifolds", "ReTestItems", "RollingFunctions", "JET", "StableRNGs"] diff --git a/src/ExponentialFamilyProjection.jl b/src/ExponentialFamilyProjection.jl index 39adb1a..c55cfcd 100644 --- a/src/ExponentialFamilyProjection.jl +++ b/src/ExponentialFamilyProjection.jl @@ -1,5 +1,7 @@ module ExponentialFamilyProjection +using ForwardDiff + using ExponentialFamily, ExponentialFamilyManifolds, BayesBase, diff --git a/src/strategies/bonnet/gauss_newton.jl b/src/strategies/bonnet/gauss_newton.jl index f1da51a..b77c40a 100644 --- a/src/strategies/bonnet/gauss_newton.jl +++ b/src/strategies/bonnet/gauss_newton.jl @@ -17,8 +17,43 @@ end get_nsamples(strategy::GaussNewton) = strategy.nsamples -preprocess_strategy_argument(strategy::GaussNewton{S,TL}, argument::Any) where {S,TL} = - (strategy, convert(TL, argument)) +get_default_InplaceLogpdfGradHess(argument::ContinuousUnivariateLogPdf) = begin + function __logpdf!(out, x) + out[1] = argument.logpdf(x) + return out + end + function __grad_hess!(out_grad, out_hess, x) + # Note that for univariate, the gradient is the derivative + # and the Hessian is the second derivative + out_grad .= ForwardDiff.derivative(argument.logpdf, x) + out_hess .= ForwardDiff.derivative(x -> ForwardDiff.derivative(argument.logpdf, x), x) + return out_grad, out_hess + end + default_inplace = InplaceLogpdfGradHess(__logpdf!, __grad_hess!) + return default_inplace +end + +get_default_InplaceLogpdfGradHess(argument::ContinuousMultivariateLogPdf) = begin + function __logpdf!(out, x) + out[1] = argument.logpdf(x) + return out + end + function __grad_hess!(out_grad, out_hess, x) + out_grad .= ForwardDiff.gradient(argument.logpdf, x) + out_hess .= ForwardDiff.hessian(argument.logpdf, x) + return out_grad, out_hess + end + default_inplace = InplaceLogpdfGradHess(__logpdf!, __grad_hess!) + return default_inplace +end + +preprocess_strategy_argument(strategy::GaussNewton{S,TL}, argument::Any) where {S,TL} = begin + if argument isa Union{ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf} + return (strategy, get_default_InplaceLogpdfGradHess(argument)) + else + (strategy, convert(TL, argument)) + end +end preprocess_strategy_argument(::GaussNewton, argument::AbstractArray) = error( lazy"The `GaussNewton` strategy requires the projection argument to be a callable object (e.g. `Function`) or an `InplaceLogpdfGradHess`. Got `$(typeof(argument))` instead.", ) diff --git a/test/strategies/gauss_newton_tests.jl b/test/strategies/gauss_newton_tests.jl index 9a8cbe6..7e367e6 100644 --- a/test/strategies/gauss_newton_tests.jl +++ b/test/strategies/gauss_newton_tests.jl @@ -321,3 +321,48 @@ end end end end + +@testitem "Test default GaussNewton strategy given ContinuousUni-/Multi-variateLogPdf" begin + using ExponentialFamily, BayesBase + import ExponentialFamilyProjection: + GaussNewton, + getstrategy, + preprocess_strategy_argument, + get_default_InplaceLogpdfGradHess + import ExponentialFamily: + NormalMeanVariance, + MvNormalMeanCovariance + import BayesBase: ContinuousUnivariateLogPdf, ContinuousMultivariateLogPdf + import LinearAlgebra: Diagonal + import DomainSets: ℝ3 + + # univariate case + a1 = NormalMeanVariance(-10,0.1) + my_logpdf(x) = logpdf(a1, x) + my_uni_continuous_logpdf = ContinuousUnivariateLogPdf(my_logpdf) + default_uni_inplace = get_default_InplaceLogpdfGradHess(my_uni_continuous_logpdf) + params = ProjectionParameters(niterations=2000, strategy = GaussNewton()) + prj = ProjectedTo(NormalMeanVariance; parameters = params) + + test_uni_cont_proj = project_to(prj, my_uni_continuous_logpdf) + test_uni_inplace_proj = project_to(prj, default_uni_inplace) + @test test_uni_cont_proj ≈ test_uni_inplace_proj atol=1e-6 + @test test_uni_cont_proj isa NormalMeanVariance + @test mean(test_uni_cont_proj) ≈ mean(a1) atol=1e-6 + @test var(test_uni_cont_proj) ≈ var(a1) atol=1e-6 + + # multivariate case + a2 = MvNormalMeanCovariance([1.3, -5, 30.0], Diagonal([0.5, 2.0, 1.0])) + my_logpdf(x) = logpdf(a2, x) + my_mv_continuous_logpdf = ContinuousMultivariateLogPdf(ℝ3, my_logpdf) + default_mv_inplace = get_default_InplaceLogpdfGradHess(my_mv_continuous_logpdf) + params = ProjectionParameters(niterations=2000, strategy = GaussNewton()) + prj = ProjectedTo(MvNormalMeanCovariance, 3; parameters = params) + + test_mv_cont_proj = project_to(prj, my_mv_continuous_logpdf) + test_mv_inplace_proj = project_to(prj, default_mv_inplace) + @test test_mv_cont_proj ≈ test_mv_inplace_proj atol=1e-6 + @test test_mv_cont_proj isa MvNormalMeanCovariance + @test mean(test_mv_cont_proj) ≈ mean(a2) atol=1e-5 + @test cov(test_mv_cont_proj) ≈ cov(a2) atol=1e-6 +end \ No newline at end of file