Skip to content

Commit ff698fa

Browse files
committed
Optimize PrepareDecoding: simplify Transition, batch CSFS with double
Transition.cpp: remove 7 dead methods, use closed-form matrix exponentials for SMC/SMC1 and a 3x3 block decomposition for CSC, and switch findIntervalForTime to binary search. Simplify getOmegas and the decoding quantity assembly. smcpp CSFS: add raw_sfs_batch(), which uses double instead of adouble — the original raw_sfs computed autodiff derivatives and then discarded them. Batch all 69 discretization intervals into one call instead of making 69 separate calls. Output matches within machine epsilon (max abs difference 2e-15).
1 parent 78018ec commit ff698fa

5 files changed

Lines changed: 171 additions & 240 deletions

File tree

src/3rd_party/smcpp.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,9 @@ std::vector<Matrix<T> > OnePopConditionedSFS<T>::compute(const PiecewiseConstant
13331333
return csfs;
13341334
}
13351335

1336+
// Explicit instantiation for double (used by raw_sfs_batch to avoid autodiff overhead)
1337+
template class OnePopConditionedSFS<double>;
1338+
13361339
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
13371340
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
13381341
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1374,3 +1377,13 @@ Matrix<double> raw_sfs(const std::vector<double>& a, const std::vector<double>&
13741377
Matrix<adouble> sfs = sfs_cython(n, pv, t1, t2, below_only);
13751378
return sfs.unaryExpr([](adouble x){return x.value();});
13761379
}
1380+
1381+
// Added for PrepareDecoding: batched conditioned SFS (CSFS) evaluated in double instead of adouble.
1382+
// The original raw_sfs evaluates with autodiff (adouble) and keeps only x.value(), discarding the derivatives.
1383+
// This version avoids that overhead: hidden_states carries every interval boundary, so one compute() call covers all intervals.
1384+
std::vector<Matrix<double>> raw_sfs_batch(const std::vector<double>& a, const std::vector<double>& s, const int n, const std::vector<double>& hidden_states) {
1385+
ParameterVector pv = make_params(a, s);  // same parameter construction as raw_sfs
1386+
OnePopConditionedSFS<double> csfs(n);  // double specialisation, explicitly instantiated earlier in this file
1387+
PiecewiseConstantRateFunction<double> eta(pv, hidden_states);  // rate function built with all boundaries at once
1388+
return csfs.compute(eta);  // vector of matrices; one per interval according to the header documentation
1389+
}

src/3rd_party/smcpp.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,16 @@ void smcpp_init_cache();
2929
Matrix<double> raw_sfs(const std::vector<double>& a, const std::vector<double>& s, int n, double t1, double t2,
3030
bool below_only = false);
3131

32+
/**
33+
 * Compute the raw SFS for every interval in a single call, using double rather than adouble (no autodiff overhead).
34+
*
35+
 * @param a normalised population sizes from the demographic history
36+
 * @param s discrete derivative of the times
37+
 * @param n sample-size argument forwarded to the OnePopConditionedSFS<double> constructor; the caller in PrepareDecoding.cpp passes samples - 2
38+
 * @param hidden_states vector of interval boundaries [t0, t1, t2, ..., tM] giving M intervals; the caller in PrepareDecoding.cpp passes discretization points divided by 2 * N0
39+
 * @return vector of M raw SFS matrices, in the order produced by OnePopConditionedSFS<double>::compute
40+
*/
41+
std::vector<Matrix<double>> raw_sfs_batch(const std::vector<double>& a, const std::vector<double>& s, int n,
42+
const std::vector<double>& hidden_states);
43+
3244
#endif // PREPAREDECODING_SMCPP_HPP

src/PrepareDecoding.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,24 @@ DecodingQuantities calculateCsfsAndDecodingQuantities(const Demography& demo, co
116116
std::vector<double> tos;
117117
std::vector<Matrix<double>> csfses;
118118

119-
// Must initialise smcpp cache once
120-
smcpp_init_cache();
121119
for (auto i = 0ul; i < arrayDisc.size() - 1; ++i) {
120+
froms.push_back(arrayDisc[i]);
121+
tos.push_back(arrayDisc[i + 1]);
122+
}
122123

123-
auto t0 = arrayDisc[i];
124-
auto t1 = arrayDisc[i + 1];
125-
126-
froms.push_back(t0);
127-
tos.push_back(t1);
128-
129-
csfses.emplace_back(raw_sfs(aVec, sVec, static_cast<int>(samples - 2u), t0 / (2. * N0), t1 / (2. * N0)) * theta);
130-
csfses.back()(0, 0) = 1.0 - csfses.back().sum();
124+
// Build hidden state boundaries for batched CSFS computation.
125+
// Uses double (not autodiff) and computes all intervals in one call.
126+
smcpp_init_cache();
127+
std::vector<double> hiddenStates;
128+
hiddenStates.reserve(arrayDisc.size());
129+
for (const auto& d : arrayDisc)
130+
hiddenStates.push_back(d / (2. * N0));
131+
132+
auto batchSfs = raw_sfs_batch(aVec, sVec, static_cast<int>(samples - 2u), hiddenStates);
133+
for (auto& sfs : batchSfs) {
134+
sfs *= theta;
135+
sfs(0, 0) = 1.0 - sfs.sum();
136+
csfses.push_back(std::move(sfs));
131137
}
132138

133139
auto csfs = CSFS::load(arrayTime, arraySize, mutRate, samples, froms, tos, csfses);

0 commit comments

Comments
 (0)