Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/instrument/instrument_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct InstrumentTransform{T, L <: SiteLookup} <: AbstractInstrumentTransform
site_map::L
end

function TV.inverse_eltype(::AbstractInstrumentTransform, x)
function TV.inverse_eltype(::AbstractInstrumentTransform, x::Type)
return eltype(x)
end

Expand Down
23 changes: 18 additions & 5 deletions src/instrument/priors/array_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ end
# end


struct ObservedArrayPrior{D, S} <: Distributions.ContinuousMultivariateDistribution
dists::D
struct ObservedArrayPrior{D, DS, S} <: Distributions.ContinuousMultivariateDistribution
sitedists::DS
sitemap::S
phase::Bool
end
Base.eltype(d::ObservedArrayPrior) = eltype(d.dists)
Base.length(d::ObservedArrayPrior) = length(d.dists)
Dists._logpdf(d::ObservedArrayPrior, x::AbstractArray{<:Real}) = Dists._logpdf(d.dists, parent(x))
Dists._rand!(rng::Random.AbstractRNG, d::ObservedArrayPrior, x::AbstractArray{<:Real}) = SiteArray(Dists._rand!(rng, d.dists, x), d.sitemap)

function asflat(d::ObservedArrayPrior)
d.phase && MarkovInstrumentTransform(asflat(d.dists), d.sitemap)
return InstrumentTransform(asflat(d.dists), d.sitemap)
Expand Down Expand Up @@ -197,7 +197,7 @@ HypercubeTransform.ascube(t::PartiallyConditionedDist) = PartiallyFixedTransform
function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroid_station)
ts = smap.times
ss = smap.sites
# fs = smap.frequencies
fs = smap.frequencies
fixedinds, vals = reference_indices(array, smap, refants)

if !(centroid_station isa Nothing)
Expand All @@ -212,9 +212,22 @@ function build_dist(dists::NamedTuple, smap::SiteLookup, array, refants, centroi

variateinds = setdiff(eachindex(ts), fixedinds)
dist = map(variateinds) do i
getproperty(dists, ss[i]).dist
s = ss[i]
sitedist(getproperty(dists, s), s, ts[i], fs[i], smap)
end
dist = Dists.product_distribution(dist)
length(fixedinds) == 0 && return dist
return PartiallyConditionedDist(dist, variateinds, fixedinds, vals)
end

function sitedist(dist::IIDSitePrior, site, time, freq, smap)
return dist.dist
end

function sitedist(dist::RWSitePrior, site, time, freq, smap)
if time > first(smap.times[smap.lookup[site]])
return dist.trans
else
return dist.dist0
end
end
43 changes: 42 additions & 1 deletion src/instrument/priors/independent.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
export IIDSitePrior
export IIDSitePrior, RWSitePrior

"""
AbstractSitePrior

An abstract type for site priors. This defines an abstract distribution for the prior of a single
site (antenna). Specific site priors should subtype this abstract type and then implement
```julia
sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, sitemap::SiteMap)
```
"""
abstract type AbstractSitePrior end

"""
sitedist(d::AbstractSitePrior, site::Symbol, time, frequency, smap)
Get the distribution for a specific site at a specific time and frequency,
"""
function sitedist end

segmentation(d::AbstractSitePrior) = getfield(d, :seg)

"""
Expand All @@ -24,3 +39,29 @@ struct IIDSitePrior{S <: Segmentation, D} <: AbstractSitePrior
seg::S
dist::D
end

"""
RWSitePrior(seg::Segmentation, dist0, transition)

Create a site prior that is essentially a random walk from segment to segment.
The `seg` argument is a segmentation object that defines how fine the time
segmentation is. The `dist0` argument is the distribution for the initial scan,
and the `transition` argument is the distribution for the transition between segments.

This means
x0 ~ dist0
xt = x0 + ϵt, where ϵt ~ transition


## Example

```julia
A = RWSitePrior(ScanSeg(), Normal(0, 1), Normal(0, 0.1))
```

"""
struct RWSitePrior{S <: Segmentation, D0, DT} <: AbstractSitePrior
seg::S
dist0::D0
trans::DT
end
12 changes: 8 additions & 4 deletions src/instrument/site_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,14 @@ sites(s::SiteLookup) = s.sites
frequencies(s::SiteLookup) = s.frequencies
lookup(s::SiteLookup) = s.lookup

EnzymeRules.inactive(::typeof(times), ::SiteLookup) = nothing
EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup) = nothing
EnzymeRules.inactive(::typeof(sites), ::SiteLookup) = nothing
EnzymeRules.inactive(::typeof(lookup), ::SiteLookup) = nothing
times(s::SiteLookup, site::Symbol) = view(times(s), lookup(s)[site])
frequencies(s::SiteLookup, site::Symbol) = view(frequencies(s), lookup(s)[site])
sites(s::SiteLookup, site::Symbol) = view(sites(s), lookup(s)[site])

EnzymeRules.inactive(::typeof(times), ::SiteLookup, ::Any...) = nothing
EnzymeRules.inactive(::typeof(frequencies), ::SiteLookup, ::Any...) = nothing
EnzymeRules.inactive(::typeof(sites), ::SiteLookup, ::Any...) = nothing
EnzymeRules.inactive(::typeof(lookup), ::SiteLookup, ::Any...) = nothing

function sitemap!(f, out::AbstractArray, gains::AbstractArray, slook::SiteLookup)
return map(lookup(slook)) do site
Expand Down
Loading