From 7eb3b452a0ae507e061fc252e9b8b7ba8b3893a7 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 13 Feb 2025 20:03:31 +0330 Subject: [PATCH 1/7] test via zygote --- src/base_icnf.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/base_icnf.jl b/src/base_icnf.jl index cb45c820..f3fb9878 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -18,7 +18,16 @@ function construct( FillArrays.Zeros{data_type}(nvars + naugmented), FillArrays.Eye{data_type}(nvars + naugmented), ), - sol_kwargs::NamedTuple = (;), + sol_kwargs::NamedTuple = (; + sensealg = SciMLSensitivity.QuadratureAdjoint(; + autodiff = true, + autojacvec = SciMLSensitivity.ZygoteVJP(), + ), + save_everystep = false, + reltol = sqrt(eps(one(Float32))), + abstol = eps(one(Float32)), + maxiters = typemax(Int32), + ), rng::Random.AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} convert(data_type, 1.0e-2) From 18a936cb0fa998ac2bb090e8c9cf956b82eac58f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 13 Feb 2025 23:05:19 +0330 Subject: [PATCH 2/7] have alg --- src/base_icnf.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/base_icnf.jl b/src/base_icnf.jl index f3fb9878..36eefa17 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -19,6 +19,7 @@ function construct( FillArrays.Eye{data_type}(nvars + naugmented), ), sol_kwargs::NamedTuple = (; + alg = OrdinaryDiffEqDefault.Vern6(), sensealg = SciMLSensitivity.QuadratureAdjoint(; autodiff = true, autojacvec = SciMLSensitivity.ZygoteVJP(), From d50726316d188346b93eef49210425344d9e87f7 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 14 Feb 2025 00:33:26 +0330 Subject: [PATCH 3/7] use defaultalg --- src/base_icnf.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 36eefa17..c4b5cc06 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -19,15 +19,11 @@ function construct( FillArrays.Eye{data_type}(nvars + naugmented), ), sol_kwargs::NamedTuple = (; - alg = OrdinaryDiffEqDefault.Vern6(), - sensealg = SciMLSensitivity.QuadratureAdjoint(; + alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), + sensealg = SciMLSensitivity.InterpolatingAdjoint(; autodiff = true, autojacvec = SciMLSensitivity.ZygoteVJP(), ), - save_everystep = false, - reltol = sqrt(eps(one(Float32))), - abstol = eps(one(Float32)), - maxiters = typemax(Int32), ), rng::Random.AbstractRNG = rng_AT(resource), λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} From 68c1ee20b4d488ce09b403dd2bf1f7b0e11faa2c Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 14 Feb 2025 12:06:58 +0330 Subject: [PATCH 4/7] disable lux nested --- test/LocalPreferences.toml | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 test/LocalPreferences.toml diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml new file mode 100644 index 00000000..da147a8a --- /dev/null +++ b/test/LocalPreferences.toml @@ -0,0 +1,2 @@ +[Lux] +automatic_nested_ad_switching = false From f7fae08f679d82d368b1d9f5a8d6615787d39f2d Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 14 Feb 2025 13:10:14 +0330 Subject: [PATCH 5/7] back to DI --- test/regression_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regression_tests.jl b/test/regression_tests.jl index 6b59b355..98381b76 100644 --- a/test/regression_tests.jl +++ b/test/regression_tests.jl @@ -11,7 +11,7 @@ Test.@testset "Regression Tests" begin nn, nvars, naugs; - compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), + compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2, From a5842e83b49e34bb20a775850d8ec095967d88ee Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Fri, 14 Feb 2025 16:45:22 +0330 Subject: [PATCH 6/7] simpler --- src/base_icnf.jl | 3 +-- test/LocalPreferences.toml | 2 -- test/regression_tests.jl | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) delete mode 100644 test/LocalPreferences.toml diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 469666ae..07070a1a 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -20,8 +20,7 @@ function construct( ), sol_kwargs::NamedTuple = (; alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.InterpolatingAdjoint(; - autodiff = true, + sensealg = SciMLSensitivity.BacksolveAdjoint(; autojacvec = SciMLSensitivity.ZygoteVJP(), ), ), diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml deleted file mode 100644 index da147a8a..00000000 --- a/test/LocalPreferences.toml +++ /dev/null @@ -1,2 +0,0 @@ -[Lux] -automatic_nested_ad_switching = false diff --git a/test/regression_tests.jl b/test/regression_tests.jl index 98381b76..6b59b355 100644 --- a/test/regression_tests.jl +++ b/test/regression_tests.jl @@ -11,7 +11,7 @@ Test.@testset "Regression Tests" begin nn, nvars, naugs; - compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()), tspan = (0.0f0, 13.0f0), steer_rate = 1.0f-1, λ₃ = 1.0f-2, From 65c31851bf679e80e57e7901d5ddcfc39c5656a8 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 20 Feb 2025 12:30:55 +0330 Subject: [PATCH 7/7] just sensealg --- src/base_icnf.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/base_icnf.jl b/src/base_icnf.jl index 07070a1a..fe3ebd62 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -19,8 +19,7 @@ function construct( FillArrays.Eye{data_type}(nvars + naugmented), ), sol_kwargs::NamedTuple = (; - alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(), - sensealg = SciMLSensitivity.BacksolveAdjoint(; + sensealg = SciMLSensitivity.InterpolatingAdjoint(; autojacvec = SciMLSensitivity.ZygoteVJP(), ), ),