From 43ed004655093cabb86eac8e45d59d6a6ab86413 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Wed, 19 Nov 2025 13:47:24 -0500 Subject: [PATCH 1/2] Start addition of RW site prior --- src/instrument/priors/array_priors.jl | 21 +++++++++++-- src/instrument/priors/independent.jl | 43 ++++++++++++++++++++++++++- src/instrument/site_array.jl | 12 +++++--- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index 9c49dda39..c71d9dc47 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -42,8 +42,9 @@ end # end -struct ObservedArrayPrior{D, S} <: Distributions.ContinuousMultivariateDistribution +struct ObservedArrayPrior{D, DS, S} <: Distributions.ContinuousMultivariateDistribution dists::D + sitedists::DS sitemap::S phase::Bool end @@ -51,6 +52,7 @@ Base.eltype(d::ObservedArrayPrior) = eltype(d.dists) Base.length(d::ObservedArrayPrior) = length(d.dists) Dists._logpdf(d::ObservedArrayPrior, x::AbstractArray{<:Real}) = Dists._logpdf(d.dists, parent(x)) Dists._rand!(rng::Random.AbstractRNG, d::ObservedArrayPrior, x::AbstractArray{<:Real}) = SiteArray(Dists._rand!(rng, d.dists, x), d.sitemap) + function asflat(d::ObservedArrayPrior) d.phase && MarkovInstrumentTransform(asflat(d.dists), d.sitemap) return InstrumentTransform(asflat(d.dists), d.sitemap) @@ -197,7 +199,7 @@ HypercubeTransform.ascube(t::PartiallyConditionedDist) = PartiallyFixedTransform function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroid_station) ts = smap.times ss = smap.sites - # fs = smap.frequencies + fs = smap.frequencies fixedinds, vals = reference_indices(array, smap, refants) if !(centroid_station isa Nothing) @@ -212,9 +214,22 @@ function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroi variateinds = setdiff(eachindex(ts), fixedinds) dist = map(variateinds) do i - getproperty(dists, ss[i]).dist + s = ss[i] + sitedist(getproperty(dists, s), s, ts[i], fs[i], smap) end dist = Dists.product_distribution(dist) length(fixedinds) == 0 && return dist return PartiallyConditionedDist(dist, variateinds, fixedinds, vals) end + +function sitedist(dist::IIDSitePrior, site, time, freq, smap) + return dist.dist +end + +function sitedist(dist::RWSitePrior, site, time, freq, smap) + if time > first(smap.times[smap.lookup[site]]) + return dist.trans + else + return dist.dist0 + end +end diff --git a/src/instrument/priors/independent.jl b/src/instrument/priors/independent.jl index 1c7501576..37fb84ca0 100644 --- a/src/instrument/priors/independent.jl +++ b/src/instrument/priors/independent.jl @@ -1,7 +1,22 @@ -export IIDSitePrior +export IIDSitePrior, RWSitePrior +""" + AbstractSitePrior + +An abstract type for site priors. This defines an abstract distribution for the prior of a single +site (antenna). Specific site priors should subtype this abstract type and then implement +```julia +sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, sitemap::SiteMap) +``` +""" abstract type AbstractSitePrior end +""" + sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, smap) +Get the distribution for a specific site at a specific time and frequency, +""" +function sitedist end + segmentation(d::AbstractSitePrior) = getfield(d, :seg) """ @@ -24,3 +39,29 @@ struct IIDSitePrior{S <: Segmentation, D} <: AbstractSitePrior seg::S dist::D end + +""" + RWSitePrior(seg::Segmentation, dist0, transition) + +Create a site prior that is essentially a random walk from segment to segment. +The `seg` argument is a segmentation object that defines how fine the time +segmentation is. The `dist0` argument is the distribution for the initial scan, +and the `transition` argument is the distribution for the transition between segments. + +This means +x0 ~ dist0 +xt = x0 + ϵt, where ϵt ~ transition + + +## Example + +```julia +A = RWSitePrior(ScanSeg(), Normal(0, 1), Normal(0, 0.1)) +``` + +""" +struct RWSitePrior{S <: Segmentation, D0, DT} <: AbstractSitePrior + seg::S + dist0::D0 + trans::DT +end diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 909ed2506..046d195a8 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -152,10 +152,14 @@ sites(s::SiteLookup) = s.sites frequencies(s::SiteLookup) = s.frequencies lookup(s::SiteLookup) = s.lookup -EnzymeRules.inactive(::typeof(times), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(sites), ::SiteLookup) = nothing -EnzymeRules.inactive(::typeof(lookup), ::SiteLookup) = nothing +times(s::SiteLookup, site::Symbol) = @view(times(s), lookup(s)[site]) +frequencies(s::SiteLookup, site::Symbol) = @view(frequencies(s), lookup(s)[site]) +sites(s::SiteLookup, site::Symbol) = @view(sites(s), lookup(s)[site]) + +EnzymeRules.inactive(::typeof(times), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(sites), ::SiteLookup, ::Any...) = nothing +EnzymeRules.inactive(::typeof(lookup), ::SiteLookup, ::Any...) = nothing function sitemap!(f, out::AbstractArray, gains::AbstractArray, slook::SiteLookup) return map(lookup(slook)) do site From 7b99b8230523f0fa09fbacaf70771db7a84f0804 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 8 Dec 2025 21:07:53 -0500 Subject: [PATCH 2/2] rw --- src/instrument/instrument_transforms.jl | 2 +- src/instrument/priors/array_priors.jl | 2 -- src/instrument/site_array.jl | 6 +++--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/instrument/instrument_transforms.jl b/src/instrument/instrument_transforms.jl index 917ea58cd..130fe002a 100644 --- a/src/instrument/instrument_transforms.jl +++ b/src/instrument/instrument_transforms.jl @@ -35,7 +35,7 @@ struct InstrumentTransform{T, L <: SiteLookup} <: AbstractInstrumentTransform site_map::L end -function TV.inverse_eltype(::AbstractInstrumentTransform, x) +function TV.inverse_eltype(::AbstractInstrumentTransform, x::Type) return eltype(x) end diff --git a/src/instrument/priors/array_priors.jl b/src/instrument/priors/array_priors.jl index c71d9dc47..c2e91d08a 100644 --- a/src/instrument/priors/array_priors.jl +++ b/src/instrument/priors/array_priors.jl @@ -43,10 +43,8 @@ end struct ObservedArrayPrior{D, DS, S} <: Distributions.ContinuousMultivariateDistribution - dists::D sitedists::DS sitemap::S - phase::Bool end Base.eltype(d::ObservedArrayPrior) = eltype(d.dists) Base.length(d::ObservedArrayPrior) = length(d.dists) diff --git a/src/instrument/site_array.jl b/src/instrument/site_array.jl index 046d195a8..4729612b2 100644 --- a/src/instrument/site_array.jl +++ b/src/instrument/site_array.jl @@ -152,9 +152,9 @@ sites(s::SiteLookup) = s.sites frequencies(s::SiteLookup) = s.frequencies lookup(s::SiteLookup) = s.lookup -times(s::SiteLookup, site::Symbol) = @view(times(s), lookup(s)[site]) -frequencies(s::SiteLookup, site::Symbol) = @view(frequencies(s), lookup(s)[site]) -sites(s::SiteLookup, site::Symbol) = @view(sites(s), lookup(s)[site]) +times(s::SiteLookup, site::Symbol) = view(times(s), lookup(s)[site]) +frequencies(s::SiteLookup, site::Symbol) = view(frequencies(s), lookup(s)[site]) +sites(s::SiteLookup, site::Symbol) = view(sites(s), lookup(s)[site]) EnzymeRules.inactive(::typeof(times), ::SiteLookup, ::Any...) = nothing EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup, ::Any...) = nothing