Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ makedocs(
"https://royalsocietypublishing.org/doi/10.1098/rspa.2020.0279",
"https://www.pnas.org/doi/10.1073/pnas.1517384113",
"https://link.springer.com/article/10.1007/s00332-015-9258-5",
# SciML's hosted docs reject Documenter's linkcheck crawler with HTTP 403,
# though the cross-doc links resolve fine in a browser.
r"^https://docs\.sciml\.ai/.*",
],
format = Documenter.HTML(
assets = ["assets/favicon.ico"],
Expand Down
4 changes: 2 additions & 2 deletions docs/src/libs/datadrivensr/example_01.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ prob = ContinuousDataDrivenProblem(X, t, U = U)

#md plot(prob)

# To solve our problem, we will use [`EQSearch`](@ref), which provides a wrapper for the [symbolic regression interface](https://ai.damtp.cam.ac.uk/symbolicregression/stable/api/#Options).
# To solve our problem, we will use [`EQSearch`](@ref), which provides a wrapper for the [symbolic regression interface](https://docs.sciml.ai/SymbolicRegression/stable/api/#Options).
# We will stick to simple operations, use a `L1DistLoss`, and limit the verbosity of the algorithm.

eqsearch_options = SymbolicRegression.Options(
Expand All @@ -46,7 +46,7 @@ eqsearch_options = SymbolicRegression.Options(
alg = EQSearch(eq_options = eqsearch_options)

# Again, we `solve` the problem to obtain a [`DataDrivenSolution`](@ref). Note that any additional keyword arguments are passed onto
# symbolic regressions [`EquationSearch`](https://ai.damtp.cam.ac.uk/symbolicregression/stable/api/#EquationSearch) with the exception of `niterations` which
# symbolic regressions [`equation_search`](https://docs.sciml.ai/SymbolicRegression/stable/api/#equation_search) with the exception of `niterations` which
# is `maxiters`

res = solve(prob, alg, options = DataDrivenCommonOptions(maxiters = 100))
Expand Down
2 changes: 1 addition & 1 deletion docs/src/libs/datadrivensr/example_02.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ prob = DataDrivenProblem(sol)

#md plot(prob)

# To solve our problem, we will use [`EQSearch`](@ref), which provides a wrapper for the [symbolic regression interface](https://ai.damtp.cam.ac.uk/symbolicregression/stable/api/#Options).
# To solve our problem, we will use [`EQSearch`](@ref), which provides a wrapper for the [symbolic regression interface](https://docs.sciml.ai/SymbolicRegression/stable/api/#Options).
# We will stick to simple operations, use a `L1DistLoss`, and limit the verbosity of the algorithm.
# Note that we do not include `sin`, but rather lift the search space of variables.

Expand Down
62 changes: 44 additions & 18 deletions lib/DataDrivenLux/src/algorithms/common.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,45 @@
@kwdef @concrete struct CommonAlgOptions
populationsize::Int = 100
functions = (sin, exp, cos, log, +, -, /, *)
arities = (1, 1, 1, 1, 2, 2, 2, 2)
n_layers::Int = 1
skip::Bool = true
simplex <: AbstractSimplex = Softmax()
loss = aicc
keep <: Union{Real, Int} = 0.1
use_protected::Bool = true
distributed::Bool = false
threaded::Bool = false
rng <: AbstractRNG = Random.default_rng()
optimizer = LBFGS()
optim_options <: Optim.Options = Optim.Options()
optimiser <: Union{Nothing, Optimisers.AbstractRule} = nothing
observed <: Union{ObservedModel, Nothing} = nothing
alpha::Real = 0.999f0
@concrete struct CommonAlgOptions
populationsize::Int
functions
arities
n_layers::Int
skip::Bool
simplex <: AbstractSimplex
loss
keep <: Union{Real, Int}
use_protected::Bool
distributed::Bool
threaded::Bool
rng <: AbstractRNG
optimizer
optim_options <: Optim.Options
optimiser <: Union{Nothing, Optimisers.AbstractRule}
observed <: Union{ObservedModel, Nothing}
alpha::Real
end

function CommonAlgOptions(;
populationsize::Int = 100,
functions = (sin, exp, cos, log, +, -, /, *),
arities = (1, 1, 1, 1, 2, 2, 2, 2),
n_layers::Int = 1,
skip::Bool = true,
simplex::AbstractSimplex = Softmax(),
loss = aicc,
keep::Union{Real, Int} = 0.1,
use_protected::Bool = true,
distributed::Bool = false,
threaded::Bool = false,
rng::AbstractRNG = Random.default_rng(),
optimizer = LBFGS(),
optim_options::Optim.Options = Optim.Options(),
optimiser::Union{Nothing, Optimisers.AbstractRule} = nothing,
observed::Union{ObservedModel, Nothing} = nothing,
alpha::Real = 0.999f0
)
return CommonAlgOptions(
populationsize, functions, arities, n_layers, skip, simplex, loss, keep,
use_protected, distributed, threaded, rng, optimizer, optim_options,
optimiser, observed, alpha
)
end
22 changes: 11 additions & 11 deletions lib/DataDrivenLux/src/caches/candidate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,27 @@ $(FIELDS)
"""
@concrete struct Candidate <: StatsBase.StatisticalModel
"Random seed"
rng <: AbstractRNG
rng
"The current state"
st <: NamedTuple
st
"The current parameters"
ps <: AbstractVector
ps
"Incoming paths"
incoming_path <: Vector{<:AbstractPathState}
incoming_path
"Outgoing path"
outgoing_path <: Vector{<:AbstractPathState}
outgoing_path
"Statistics"
statistics <: PathStatistics
statistics
"The observed model"
observed <: ObservedModel
observed
"The parameter distribution"
parameterdist <: ParameterDistributions
parameterdist
"The optimal scales"
scales <: AbstractVector
scales
"The optimal parameters"
parameters <: AbstractVector
parameters
"The component model"
model <: ComponentModel
model
end

function (c::Candidate)(dataset::Dataset{T}, ps = c.ps, p = c.parameters) where {T}
Expand Down
2 changes: 1 addition & 1 deletion lib/DataDrivenLux/src/caches/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function Dataset(
x_intervals = interval.(map(extrema, eachrow(X)))
y_intervals = interval.(map(extrema, eachrow(Y)))
u_intervals = interval.(map(extrema, eachrow(U)))
t_intervals = isempty(t) ? Interval{T}(zero(T), zero(T)) : interval(extrema(t))
t_intervals = isempty(t) ? interval(zero(T), zero(T)) : interval(extrema(t))
return Dataset{T}(X, Y, U, t, x_intervals, y_intervals, u_intervals, t_intervals)
end

Expand Down
7 changes: 5 additions & 2 deletions lib/DataDrivenLux/src/custom_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ function ObservedDistribution(
return ObservedDistribution{fixed, D}(errormodel, latent_scale, transform)
end

function Base.summary(io::IO, ::ObservedDistribution{fixed, D}) where {fixed, D}
return print(io, "$E : $D() with $(fixed ? "fixed" : "variable") scale.")
function Base.summary(io::IO, d::ObservedDistribution{fixed, D}) where {fixed, D}
return print(
io,
"$(nameof(typeof(d.errormodel))) : $D() with $(fixed ? "fixed" : "variable") scale."
)
end

get_init(d::ObservedDistribution) = d.latent_scale
Expand Down
8 changes: 5 additions & 3 deletions lib/DataDrivenLux/src/lux/path_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ function update_path(::Nothing, id::Tuple{Int, Int}, state::PathState{T}) where
end

function update_path(
f::F where {F <: Function}, id::Tuple{Int, Int}, states::PathState{T}...
f::F where {F <: Function}, id::Tuple{Int, Int},
state1::PathState{T}, states::PathState{T}...
) where {T}
allstates = (state1, states...)
return PathState{T}(
f(get_interval.(states)...), (f, tuplejoin(map(get_operators, states)...)...),
(id, tuplejoin(map(get_nodes, states)...)...)
f(get_interval.(allstates)...), (f, tuplejoin(map(get_operators, allstates)...)...),
(id, tuplejoin(map(get_nodes, allstates)...)...)
)
end

Expand Down
2 changes: 2 additions & 0 deletions lib/DataDrivenLux/test/qa/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
DataDrivenLux = "47881146-99d0-492a-8425-8f2f33327637"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DataDrivenDiffEq = {path = "../../../.."}
DataDrivenLux = {path = "../.."}

[compat]
Expand Down
12 changes: 12 additions & 0 deletions lib/DataDrivenLux/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ const GROUP = get(ENV, "DATADRIVENDIFFEQ_TEST_GROUP", get(ENV, "GROUP", "All"))

function activate_qa_env()
Pkg.activate(joinpath(@__DIR__, "qa"))
# On Julia < 1.11 the qa env's [sources] table is ignored, so the in-repo
# DataDrivenLux/DataDrivenDiffEq would resolve as registered packages and QA
# would analyze stale released code. Develop the local paths to restore the
# 1.11+ [sources] behavior (no-op effect on >= 1.11, which honors [sources]).
if VERSION < v"1.11.0-DEV.0"
Pkg.develop(
[
Pkg.PackageSpec(path = joinpath(@__DIR__, "..", "..", "..")),
Pkg.PackageSpec(path = joinpath(@__DIR__, "..")),
]
)
end
return Pkg.instantiate()
end

Expand Down
4 changes: 2 additions & 2 deletions lib/DataDrivenSR/src/DataDrivenSR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Reexport
"""
$(TYPEDEF)
Options for using SymbolicRegression.jl within the `solve` function.
Automatically creates [`Options`](https://ai.damtp.cam.ac.uk/symbolicregression/stable/api/#Options) with the given specification.
Automatically creates [`Options`](https://docs.sciml.ai/SymbolicRegression/stable/api/#Options) with the given specification.
Sorts the operators stored in `functions` into unary and binary operators on conversion.

# Fields
Expand Down Expand Up @@ -89,7 +89,7 @@ end
is_success(k::SRResult) = getfield(k, :retcode) == DDReturnCode(1)

# StatsBase Overload
StatsBase.coef(x::SRResult) = getfield(x, :k)
StatsBase.coef(x::SRResult) = get_parameter_values(getfield(x, :basis))

StatsBase.rss(x::SRResult) = getfield(x, :rss)

Expand Down
2 changes: 2 additions & 0 deletions lib/DataDrivenSR/test/qa/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataDrivenDiffEq = "2445eb08-9709-466a-b3fc-47e12bd697a2"
DataDrivenSR = "7fed8a53-d475-4873-af3a-ba53cceea094"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DataDrivenDiffEq = {path = "../../../.."}
DataDrivenSR = {path = "../.."}

[compat]
Expand Down
12 changes: 12 additions & 0 deletions lib/DataDrivenSR/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ const GROUP = get(ENV, "DATADRIVENDIFFEQ_TEST_GROUP", get(ENV, "GROUP", "All"))

function activate_qa_env()
Pkg.activate(joinpath(@__DIR__, "qa"))
# On Julia < 1.11 the qa env's [sources] table is ignored, so the in-repo
# DataDrivenSR/DataDrivenDiffEq would resolve as registered packages and QA
# would analyze stale released code. Develop the local paths to restore the
# 1.11+ [sources] behavior (no-op effect on >= 1.11, which honors [sources]).
if VERSION < v"1.11.0-DEV.0"
Pkg.develop(
[
Pkg.PackageSpec(path = joinpath(@__DIR__, "..", "..", "..")),
Pkg.PackageSpec(path = joinpath(@__DIR__, "..")),
]
)
end
return Pkg.instantiate()
end

Expand Down
29 changes: 29 additions & 0 deletions src/basis/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,20 @@ function (f::DataDrivenFunction{true, false})(
return _apply_vec_function(f, du, u, p, t, __EMPTY_MATRIX)
end

function (f::DataDrivenFunction{true, true})(
du::AbstractMatrix, u::AbstractMatrix, p::P,
t::AbstractVector,
c::AbstractMatrix
) where {
P <:
Union{
AbstractArray,
Tuple,
},
}
return _apply_vec_function(f, du, u, p, t, c)
end

## IIP

function (f::DataDrivenFunction{false, false})(
Expand Down Expand Up @@ -302,3 +316,18 @@ function (f::DataDrivenFunction{true, false})(
}
return _apply_vec_function!(f, res, du, u, p, t, __EMPTY_MATRIX)
end

function (f::DataDrivenFunction{true, true})(
res::AbstractMatrix, du::AbstractMatrix,
u::AbstractMatrix, p::P,
t::AbstractVector,
c::AbstractMatrix
) where {
P <:
Union{
AbstractArray,
Tuple,
},
}
return _apply_vec_function!(f, res, du, u, p, t, c)
end
Loading