diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 917ea58cd..130fe002a 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -35,7 +35,7 @@ struct InstrumentTransform{T, L <: SiteLookup} <: AbstractInstrumentTransform site_map::L end -function TV.inverse_eltype(::AbstractInstrumentTransform, x) +function TV.inverse_eltype(::AbstractInstrumentTransform, x::Type) return eltype(x) end diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index 9c49dda39..c2e91d08a 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -42,15 +42,15 @@ end # end -struct ObservedArrayPrior{D, S} <: Distributions.ContinuousMultivariateDistribution - dists::D +struct ObservedArrayPrior{D, DS, S} <: Distributions.ContinuousMultivariateDistribution + sitedists::DS sitemap::S - phase::Bool end Base.eltype(d::ObservedArrayPrior) = eltype(d.dists) Base.length(d::ObservedArrayPrior) = length(d.dists) Dists._logpdf(d::ObservedArrayPrior, x::AbstractArray{<:Real}) = Dists._logpdf(d.dists, parent(x)) Dists._rand!(rng::Random.AbstractRNG, d::ObservedArrayPrior, x::AbstractArray{<:Real}) = SiteArray(Dists._rand!(rng, d.dists, x), d.sitemap) + function asflat(d::ObservedArrayPrior) d.phase && MarkovInstrumentTransform(asflat(d.dists), d.sitemap) return InstrumentTransform(asflat(d.dists), d.sitemap) @@ -197,7 +197,7 @@ HypercubeTransform.ascube(t::PartiallyConditionedDist) = PartiallyFixedTransform function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroid_station) ts = smap.times ss = smap.sites - # fs = smap.frequencies + fs = smap.frequencies fixedinds, vals = reference_indices(array, smap, refants) if !(centroid_station isa Nothing) @@ -212,9 +212,22 @@ function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroi variateinds = setdiff(eachindex(ts), fixedinds) dist = map(variateinds) do i - getproperty(dists, ss[i]).dist + s = ss[i] + sitedist(getproperty(dists, s), s, ts[i], fs[i], smap) end dist = Dists.product_distribution(dist) length(fixedinds) == 0 && return dist return PartiallyConditionedDist(dist, variateinds, fixedinds, vals) end + +function sitedist(dist::IIDSitePrior, site, time, freq, smap) + return dist.dist +end + +function sitedist(dist::RWSitePrior, site, time, freq, smap) + if time > first(smap.times[smap.lookup[site]]) + return dist.trans + else + return dist.dist0 + end +end diff --git a/src/instrument/priors/independent.jl b/src/instrument/priors/independent.jl index 1c7501576..37fb84ca0 100644 --- a/src/instrument/priors/independent.jl +++ b/src/instrument/priors/independent.jl @@ -1,7 +1,22 @@ -export IIDSitePrior +export IIDSitePrior, RWSitePrior +""" + AbstractSitePrior + +An abstract type for site priors. This defines an abstract distribution for the prior of a single +site (antenna). Specific site priors should subtype this abstract type and then implement +```julia +sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, sitemap::SiteMap) +``` +""" abstract type AbstractSitePrior end +""" + sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, smap) +Get the distribution for a specific site at a specific time and frequency, +""" +function sitedist end + segmentation(d::AbstractSitePrior) = getfield(d, :seg) """ @@ -24,3 +39,29 @@ struct IIDSitePrior{S <: Segmentation, D} <: AbstractSitePrior seg::S dist::D end + +""" + RWSitePrior(seg::Segmentation, dist0, transition) + +Create a site prior that is essentially a random walk from segment to segment. +The `seg` argument is a segmentation object that defines how fine the time +segmentation is. The `dist0` argument is the distribution for the initial scan, +and the `transition` argument is the distribution for the transition between segments. + +This means +x0 ~ dist0 +xt = x0 + ϵt, where ϵt ~ transition + + +## Example + +```julia +A = RWSitePrior(ScanSeg(), Normal(0, 1), Normal(0, 0.1)) +``` + +""" +struct RWSitePrior{S <: Segmentation, D0, DT} <: AbstractSitePrior + seg::S + dist0::D0 + trans::DT +end diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 909ed2506..4729612b2 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -152,10 +152,14 @@ sites(s::SiteLookup) = s.sites frequencies(s::SiteLookup) = s.frequencies lookup(s::SiteLookup) = s.lookup -EnzymeRules.inactive(::typeof(times), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(sites), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(lookup), ::SiteLookup) = nothing +times(s::SiteLookup, site::Symbol) = view(times(s), lookup(s)[site]) +frequencies(s::SiteLookup, site::Symbol) = view(frequencies(s), lookup(s)[site]) +sites(s::SiteLookup, site::Symbol) = view(sites(s), lookup(s)[site]) + +EnzymeRules.inactive(::typeof(times), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(sites), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(lookup), ::SiteLookup, ::Any...) = nothing function sitemap!(f, out::AbstractArray, gains::AbstractArray, slook::SiteLookup) return map(lookup(slook)) do site