Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 65 additions & 24 deletions src/data/state.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Unicode: normalize

export HistoryState, State
export solution, state, vectorfield
export solution, state, timevariable, vectorfield


# The `_state` function returns an appropriate empty state for a given state variable.
Expand Down Expand Up @@ -43,6 +43,7 @@ It stores all the information that is required to uniquely determine the state o
in particular all state variables and their corresponding vector fields.
"""
struct State{
timeType<:TimeVariable,
stateType<:NamedTuple,
solutionType<:NamedTuple,
vectorfieldType<:NamedTuple,
Expand All @@ -51,20 +52,24 @@ struct State{
vectorfieldKeys
}

time::timeType
state::stateType
solution::solutionType
vectorfield::vectorfieldType

function State(state, solution, vectorfield)
function State(time, state, solution, vectorfield)
stateKeys = Val.(keys(state))
solutionKeys = Val.(keys(solution))
vectorfieldKeys = Val.(keys(vectorfield))

new{typeof(state),typeof(solution),typeof(vectorfield),stateKeys,solutionKeys,vectorfieldKeys}(state, solution, vectorfield)
new{typeof(time),typeof(state),typeof(solution),typeof(vectorfield),stateKeys,solutionKeys,vectorfieldKeys}(time, state, solution, vectorfield)
end
end

function State(ics::NamedTuple; initialize=true)
function State(initialtime::Real, ics::NamedTuple; initialize=true)
# create time variable
time = TimeVariable(zero(initialtime))

# create solution tuple for all variables in ics
solution = map(x -> _state(x), ics)

Expand All @@ -82,55 +87,69 @@ function State(ics::NamedTuple; initialize=true)
state_fields = merge(solution, vectorfield_dots)

# create state
state = State(state_fields, solution, vectorfield_filtered)
state = State(time, state_fields, solution, vectorfield_filtered)

# copy initial conditions to state if initialize == true
initialize && copy!(time, initialtime)
initialize && copy!(state, ics)

return state
end

State(initialtime::TimeVariable, ics::NamedTuple; kwargs...) = State(value(initialtime), ics; kwargs...)
State(st::State; kwargs...) = State(time(st), solution(st); kwargs...)
State(st::StateWithError; kwargs...) = State(state(st); kwargs...)



"""
keys(st::State)

Return the keys of all the state variables in the `State`.
"""
Base.keys(::State{stT,solT,vecT,stKeys}) where {stT,solT,vecT,stKeys} = stKeys
Base.keys(st::State) = keys(state(st))

solutionkeys(::State{stT,solT,vecT,stKeys,solKeys,vecKeys}) where {stT,solT,vecT,stKeys,solKeys,vecKeys} = solKeys
vectorfieldkeys(::State{stT,solT,vecT,stKeys,solKeys,vecKeys}) where {stT,solT,vecT,stKeys,solKeys,vecKeys} = vecKeys
statekeys(::State{TT,stT,solT,vecT,stKeys,solKeys,vecKeys}) where {TT,stT,solT,vecT,stKeys,solKeys,vecKeys} = stKeys
solutionkeys(::State{TT,stT,solT,vecT,stKeys,solKeys,vecKeys}) where {TT,stT,solT,vecT,stKeys,solKeys,vecKeys} = solKeys
vectorfieldkeys(::State{TT,stT,solT,vecT,stKeys,solKeys,vecKeys}) where {TT,stT,solT,vecT,stKeys,solKeys,vecKeys} = vecKeys

"""
haskey(st::State, ::Val{s})
haskey(st::State, s::Symbol)

Checks if `s` is a valid state variable in the `State`.
"""
Base.haskey(st::State, s::Val) = s ∈ keys(st)
Base.haskey(st::State, s::Symbol) = haskey(st, Val(s))
# Base.haskey(st::State, s::Val) = s ∈ keys(st)
# Base.haskey(st::State, s::Symbol) = haskey(st, Val(s))
Base.haskey(st::State, s::Symbol) = haskey(state(st), s)


function Base.hasproperty(st::State, s::Symbol)
haskey(st, s) || hasfield(State, s)
s === :t || haskey(st, s) || hasfield(State, s)
end

function Base.getproperty(st::State, s::Symbol)
if haskey(st, s)
return value(getfield(st, :state)[s])
if haskey(getfield(st, :state), s)
return getfield(st, :state)[s]
elseif s === :t
return value(getfield(st, :time))
else
return getfield(st, s)
end
end

function Base.setproperty!(st::State, s::Symbol, x)
if haskey(st, s)
if haskey(getfield(st, :state), s)
return copy!(getfield(st, :state)[s], x)
elseif s === :t
return copy!(getfield(st, :time), x)
else
return setfield!(st, s, x)
end
end

timevariable(st::State) = st.time
Base.time(st::State) = value(timevariable(st))
state(st::State) = st.state
solution(st::State) = st.solution
vectorfield(st::State) = st.vectorfield
Expand Down Expand Up @@ -158,30 +177,52 @@ Base.isnan(st::State) = mapfoldl(isnan, |, variables(st))


"""
copy!(st::State, sol::NamedTuple)
initialize!(st::State, ics::NamedTuple)

Copy the values from a `NamedTuple` `sol` to the `State` `st`.
Copy the values from a `NamedTuple` `ics` to the `State` `st`.

The keys of `sol` must be a subset of the keys of the state.
The keys of `ics` must be the same as the solution keys of the state.

# Arguments
- `st`: the state to copy into
- `sol`: the named tuple containing the solution values to copy
- `ics`: the named tuple containing the initial values to copy
"""
function Base.copy!(st::State, sol::NamedTuple)
@assert keys(sol) ⊆ keys(state(st))
map(k -> copy!(st[k], sol[k]), keys(sol))
function copy!(st::State, sol::NamedTuple)
# @assert keys(sol) == keys(solution(st))
# map((x, y) -> copy!(x, y), values(solution(st)), values(sol))
# println(" keys(sol) = ", keys(sol))
# println(" keys(st) = ", keys(st))
@assert keys(sol) ⊆ keys(st)
map(k -> copy!(state(st)[k], sol[k]), keys(sol))
return st
end

function copy!(st::State, t::Union{Real,TimeVariable}, sol::NamedTuple)
copy!(st, sol)
copy!(st.time, t)
return st
end

"""
copy!(dst::State, src::State)

Copy the values from one `State` `src` to another `State` `dst`.

The keys of `src` and `dst` must identical.

# Arguments
- `src`: the state to copy into
- `dst`: the state containing the solution values to copy
"""
function Base.copy!(dst::State, src::State)
@assert keys(dst) == keys(src)
map((x,y) -> copy!(x,y), variables(dst), variables(src))
copy!(dst.time, src.time)
map((x, y) -> copy!(x, y), variables(dst), variables(src))
return dst
end

function Base.copy(oldstate::State)
newstate = State(solution(oldstate))
newstate = State(time(oldstate), solution(oldstate))
copy!(newstate, oldstate)
return newstate
end
Expand All @@ -203,5 +244,5 @@ function HistoryState(st::State)
history_vectorfield = NamedTuple{_add_bar.(keys(vectorfield(st)))}(values(vectorfield(st)))

# create history state
State(history_state, history_solution, history_vectorfield)
State(timevariable(st), history_state, history_solution, history_vectorfield)
end
44 changes: 25 additions & 19 deletions src/data/state_variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,9 @@ abstract type AbstractVariable{DT,N} <: AbstractArray{DT,N} end

parent(::AV) where {AV<:AbstractVariable} = error("parent() method not implemented for abstract variable ", AV)


# The `value` function returns a processable value for any given `AbstractVariable`.
# In most cases, this is just the variable itself, but for an `AbstractScalarVariable`
# it is the scalar value stored in the 0d array (see `TimeVariable` for more details).
value(a::AbstractVariable) = a
value(x::Missing) = x

isnan(a::AbstractVariable) = any(isnan, parent(a))


axes(a::AbstractVariable, ind...) = axes(parent(a), ind...)
size(a::AbstractVariable, ind...) = size(parent(a), ind...)
eachindex(a::AbstractVariable) = eachindex(parent(a))
Expand All @@ -39,6 +33,11 @@ Base.:(==)(x::AbstractVariable, y::AbstractVariable) = parent(x) == parent(y)
Base.:(==)(x::AbstractVariable, y::AbstractArray) = parent(x) == y
Base.:(==)(x::AbstractArray, y::AbstractVariable) = y == x

# Base.:(≠)(x::AV, y::AV) where {AV<:AbstractVariable} = parent(x) ≠ parent(y)
Base.:(≠)(x::AbstractVariable, y::AbstractVariable) = parent(x) ≠ parent(y)
Base.:(≠)(x::AbstractVariable, y::AbstractArray) = parent(x) ≠ y
Base.:(≠)(x::AbstractArray, y::AbstractVariable) = y ≠ x

# Base.:(≈)(x::AV, y::AV, args...; kwargs...) where {AV<:AbstractVariable} = ≈(parent(x), parent(y), args...; kwargs...)
Base.:(≈)(x::AbstractVariable, y::AbstractVariable, args...; kwargs...) = ≈(parent(x), parent(y), args...; kwargs...)
Base.:(≈)(x::AbstractVariable, y::AbstractArray, args...; kwargs...) = ≈(parent(x), y, args...; kwargs...)
Expand Down Expand Up @@ -69,6 +68,11 @@ abstract type AbstractStateVariable{DT,N,AT} <: AbstractVariable{DT,N} end

parenttype(::AbstractStateVariable{DT,N,AT}) where {DT,N,AT} = AT

Base.broadcasted(::typeof(:(+)), a::AbstractStateVariable, b::AbstractStateVariable) = parent(a) .+= parent(b)
Base.broadcasted(::typeof(:(+)), a::AbstractStateVariable, b::AbstractArray) = parent(a) .+= b
# Base.broadcasted(::typeof(:(-)), a::AbstractStateVariable, b::AbstractStateVariable) = parent(a) .-= parent(b)
# Base.broadcasted(::typeof(:(-)), a::AbstractStateVariable, b::AbstractArray) = parent(a) .-= b

function copy!(dst::AbstractStateVariable{DT,N,AT}, src::AT) where {DT,N,AT<:AbstractArray{DT,N}}
@assert axes(dst) == axes(src)
copy!(parent(dst), src)
Expand Down Expand Up @@ -105,6 +109,8 @@ Base.convert(::Type{T}, x::TimeVariable{T}) where {T<:Number} = value(x)

Base.broadcasted(::typeof(:(+)), a::TimeVariable, b::TimeVariable) = parent(a) .+= parent(b)
Base.broadcasted(::typeof(:(+)), a::TimeVariable, b::Number) = parent(a) .+= b
# Base.broadcasted(::typeof(:(-)), a::TimeVariable, b::TimeVariable) = parent(a) .-= parent(b)
# Base.broadcasted(::typeof(:(-)), a::TimeVariable, b::Number) = parent(a) .-= b

add!(t::TimeVariable{DT}, Δt::DT) where {DT} = t .+= Δt
copy!(t::TimeVariable{DT}, y::DT) where {DT} = t .= y
Expand Down Expand Up @@ -187,18 +193,6 @@ verifyrange(s::StateVariable) = BitArray(verifyrange(s::StateVariable, i) for i
Base.:(==)(x::StateVariable, y::StateVariable) = parent(x) == parent(y) && range(x) == range(y) && periodic(x) == periodic(y)


struct VectorfieldVariable{DT,N,AT<:AbstractArray{DT,N}} <: AbstractStateVariable{DT,N,AT}
value::AT
end
VectorfieldVariable(x::VectorfieldVariable) = VectorfieldVariable(parent(x))
VectorfieldVariable(x::StateVariable) = VectorfieldVariable(zero(parent(x)))

parent(v::VectorfieldVariable) = v.value

copy(v::VectorfieldVariable) = VectorfieldVariable(copy(parent(v)))
zero(v::VectorfieldVariable) = VectorfieldVariable(zero(parent(v)))


struct AlgebraicVariable{DT,N,AT<:AbstractArray{DT,N}} <: AbstractStateVariable{DT,N,AT}
value::AT
end
Expand Down Expand Up @@ -319,6 +313,18 @@ function add!(s::StateWithError{DT,N,VT}, Δs::Increment{DT,N,VT}) where {DT,N,V
end


struct VectorfieldVariable{DT,N,AT<:AbstractArray{DT,N}} <: AbstractStateVariable{DT,N,AT}
value::AT
end
VectorfieldVariable(x::VectorfieldVariable) = VectorfieldVariable(parent(x))
VectorfieldVariable(x::AbstractStateVariable) = VectorfieldVariable(zero(parent(x)))

parent(v::VectorfieldVariable) = v.value

copy(v::VectorfieldVariable) = VectorfieldVariable(copy(parent(v)))
zero(v::VectorfieldVariable) = VectorfieldVariable(zero(parent(v)))


"""
`StateVector{DT,VT}` is a vector of [`StateVariable`](@ref)s, where `DT` is the datatype of the state and `VT` is the
type of the vector.
Expand Down
3 changes: 3 additions & 0 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ function timespan end
function timestep end
function timesteps end

function initialstate end

initialtime(x) = timespan(x)[begin]
finaltime(x) = timespan(x)[end]

Expand All @@ -47,3 +49,4 @@ function description end
function reference end

function value end
function variables end
2 changes: 1 addition & 1 deletion src/utils/norms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end

function L2norm(x, y)
@assert axes(x) == axes(y)
mapreduce((xᵢ, yᵢ) -> (xᵢ - yᵢ)^2, +, x, y)
mapfoldl(z -> (z[1] - z[2])^2, +, zip(x, y))
end

l2norm(x) = sqrt(L2norm(x))
Expand Down
Loading
Loading