Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/3rd_party/smcpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,9 @@ std::vector<Matrix<T> > OnePopConditionedSFS<T>::compute(const PiecewiseConstant
return csfs;
}

// Explicit instantiation for double (used by raw_sfs_batch to avoid autodiff overhead)
template class OnePopConditionedSFS<double>;

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1374,3 +1377,13 @@ Matrix<double> raw_sfs(const std::vector<double>& a, const std::vector<double>&
Matrix<adouble> sfs = sfs_cython(n, pv, t1, t2, below_only);
return sfs.unaryExpr([](adouble x){return x.value();});
}

// Added for PrepareDecoding: batched CSFS using double instead of adouble.
// The original raw_sfs uses autodiff (adouble) then discards derivatives.
// This version avoids that overhead and batches all intervals in one call.
std::vector<Matrix<double>> raw_sfs_batch(const std::vector<double>& a, const std::vector<double>& s, const int n, const std::vector<double>& hidden_states) {
ParameterVector pv = make_params(a, s);
OnePopConditionedSFS<double> csfs(n);
PiecewiseConstantRateFunction<double> eta(pv, hidden_states);
return csfs.compute(eta);
}
12 changes: 12 additions & 0 deletions src/3rd_party/smcpp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,16 @@ void smcpp_init_cache();
Matrix<double> raw_sfs(const std::vector<double>& a, const std::vector<double>& s, int n, double t1, double t2,
bool below_only = false);

/**
* Compute raw SFS for all intervals at once, using double (no autodiff overhead).
*
* @param a normalised population sizes from the demographic history
* @param s discrete derivative of the times
* @param n number of samples
* @param hidden_states vector of interval boundaries [t0, t1, t2, ..., tM] giving M intervals
* @return vector of M raw SFS matrices
*/
std::vector<Matrix<double>> raw_sfs_batch(const std::vector<double>& a, const std::vector<double>& s, int n,
const std::vector<double>& hidden_states);

#endif // PREPAREDECODING_SMCPP_HPP
26 changes: 16 additions & 10 deletions src/PrepareDecoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,24 @@ DecodingQuantities calculateCsfsAndDecodingQuantities(const Demography& demo, co
std::vector<double> tos;
std::vector<Matrix<double>> csfses;

// Must initialise smcpp cache once
smcpp_init_cache();
for (auto i = 0ul; i < arrayDisc.size() - 1; ++i) {
froms.push_back(arrayDisc[i]);
tos.push_back(arrayDisc[i + 1]);
}

auto t0 = arrayDisc[i];
auto t1 = arrayDisc[i + 1];

froms.push_back(t0);
tos.push_back(t1);

csfses.emplace_back(raw_sfs(aVec, sVec, static_cast<int>(samples - 2u), t0 / (2. * N0), t1 / (2. * N0)) * theta);
csfses.back()(0, 0) = 1.0 - csfses.back().sum();
// Build hidden state boundaries for batched CSFS computation.
// Uses double (not autodiff) and computes all intervals in one call.
smcpp_init_cache();
std::vector<double> hiddenStates;
hiddenStates.reserve(arrayDisc.size());
for (const auto& d : arrayDisc)
hiddenStates.push_back(d / (2. * N0));

auto batchSfs = raw_sfs_batch(aVec, sVec, static_cast<int>(samples - 2u), hiddenStates);
for (auto& sfs : batchSfs) {
sfs *= theta;
sfs(0, 0) = 1.0 - sfs.sum();
csfses.push_back(std::move(sfs));
}

auto csfs = CSFS::load(arrayTime, arraySize, mutRate, samples, froms, tos, csfses);
Expand Down
Loading
Loading