diff --git a/pyproject.toml b/pyproject.toml index 00935f6..31af362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ requires-python = ">=3.9" dependencies = [ "click", - "xarray", "h5py", "pandas", "numba", @@ -25,7 +24,6 @@ dependencies = [ "tszip", "arg-needle-lib==1.2.1", "cyvcf2", - "ray", "pgenlib", "tqdm" ] diff --git a/src/DataConsistency.cpp b/src/DataConsistency.cpp index 2294690..fa05b9a 100644 --- a/src/DataConsistency.cpp +++ b/src/DataConsistency.cpp @@ -222,7 +222,7 @@ ThreadingInstructions ConsistencyWrapper::get_consistent_instructions() { // Make output threading instructions std::vector output_instructions; output_instructions.reserve(instruction_converters.size()); - for (InstructionConverter converter : instruction_converters) { + for (auto& converter : instruction_converters) { output_instructions.push_back(converter.parse_converted_instructions()); } diff --git a/src/Demography.cpp b/src/Demography.cpp index 6b1f178..5a21366 100644 --- a/src/Demography.cpp +++ b/src/Demography.cpp @@ -83,7 +83,7 @@ double Demography::expected_branch_length(const int N) { std::ostream& operator<<(std::ostream& os, const Demography& d) { for (std::size_t i = 0; i < d.sizes.size(); i++) { - std::cout << d.times[i] << " " << d.sizes[i] << " " << d.std_times[i] << std::endl; + os << d.times[i] << " " << d.sizes[i] << " " << d.std_times[i] << "\n"; } return os; } diff --git a/src/HMM.cpp b/src/HMM.cpp index 8f902c0..9e9b806 100644 --- a/src/HMM.cpp +++ b/src/HMM.cpp @@ -28,17 +28,17 @@ HMM::HMM(Demography demography, std::vector bp_sizes, std::vector trellis_row(num_states, 0.0); - std::vector pointer_row(num_states, 0); - trellis.push_back(trellis_row); - pointers.push_back(pointer_row); + trellis.emplace_back(num_states, 0.0); + pointers.emplace_back(num_states, 0); } } std::vector HMM::compute_expected_times(Demography demography, const int K) { std::vector result; + result.reserve(K); double k = static_cast(num_states); boost::math::exponential e; @@ -50,39 +50,45 @@ std::vector HMM::compute_expected_times(Demography demography, const int } void HMM::compute_recombination_scores(std::vector cm_sizes) { + const double log_K = std::log(num_states); + non_transition_score.reserve(cm_sizes.size()); + transition_score.reserve(cm_sizes.size()); for (std::size_t i = 0; i < cm_sizes.size(); i++) { - non_transition_score.push_back(std::vector()); - transition_score.push_back(std::vector()); + std::vector non_trans(num_states); + std::vector trans(num_states); for (int k = 0; k < num_states; k++) { - double t = expected_times[k]; - const double l = 2. * 0.01 * cm_sizes[i] * t; - const double trans = std::log1p(-std::exp(-l)) - std::log(num_states); + const double l = 2. * 0.01 * cm_sizes[i] * expected_times[k]; + const double t = std::log1p(-std::exp(-l)) - log_K; // log-prob of transitioning - transition_score[i].push_back(trans); + trans[k] = t; // log-prob of *not* transitioning - non_transition_score[i].push_back(std::log(std::exp(-l) + std::exp(trans))); + non_trans[k] = std::log(std::exp(-l) + std::exp(t)); } + transition_score.push_back(std::move(trans)); + non_transition_score.push_back(std::move(non_trans)); } } void HMM::compute_mutation_scores(std::vector bp_sizes, double mutation_rate) { + hom_score.reserve(bp_sizes.size()); + het_score.reserve(bp_sizes.size()); for (std::size_t i = 0; i < bp_sizes.size(); i++) { - hom_score.push_back(std::vector()); - het_score.push_back(std::vector()); + std::vector hom(num_states); + std::vector het(num_states); for (int k = 0; k < num_states; k++) { - double t = expected_times[k]; - // TODO: use mean-bp sizes here as in the main algorithm - const double l = 2. * mutation_rate * bp_sizes[i] * t; + const double l = 2. * mutation_rate * bp_sizes[i] * expected_times[k]; // log-prob of mutating - het_score[i].push_back(std::log1p(-std::exp(-l))); + het[k] = std::log1p(-std::exp(-l)); // log-prob of *not* mutating - hom_score[i].push_back(-l); + hom[k] = -l; } + het_score.push_back(std::move(het)); + hom_score.push_back(std::move(hom)); } } @@ -107,12 +113,14 @@ std::vector HMM::breakpoints(std::vector observations, int start) { double score = 0.0; unsigned short running_argmax = 0; for (int j = 1; j < neighborhood_size; j++) { + const int js = j + start; for (int i = 0; i < num_states; i++) { + // Hoist mut_score out of the inner k-loop: it only depends on i, not k + const double mut_score = observations[j] ? het_score[js][i] : hom_score[js][i]; double running_max = 0; for (int k = 0; k < num_states; k++) { - double mut_score = observations[j] ? het_score[j + start][i] : hom_score[j + start][i]; double rec_score = - k == i ? non_transition_score[j + start][k] : transition_score[j + start][k]; + k == i ? non_transition_score[js][k] : transition_score[js][k]; score = trellis[j - 1 + start][k] + rec_score + mut_score; @@ -121,8 +129,8 @@ std::vector HMM::breakpoints(std::vector observations, int start) { running_argmax = static_cast(k); } } - trellis[j + start][i] = running_max; - pointers[j + start][i] = running_argmax; + trellis[js][i] = running_max; + pointers[js][i] = running_argmax; } } diff --git a/src/ImputationMatcher.cpp b/src/ImputationMatcher.cpp index f869f6c..828049a 100644 --- a/src/ImputationMatcher.cpp +++ b/src/ImputationMatcher.cpp @@ -191,7 +191,7 @@ void ImputationMatcher::process_site(const std::vector& genotype) { throw std::runtime_error(prompt); } } - sorting = next_sorting; + std::swap(sorting, next_sorting); sites_processed++; } diff --git a/src/Matcher.cpp b/src/Matcher.cpp index e19d3ec..2fbd97a 100644 --- a/src/Matcher.cpp +++ b/src/Matcher.cpp @@ -55,14 +55,14 @@ void MatchGroup::filter_matches(int min_matches) { } } else if (i < 1000) { - for (auto counts : match_candidates_counts.at(i)) { + for (const auto& counts : match_candidates_counts.at(i)) { if (counts.second >= std::min(2, min_matches)) { match_candidates.at(i).insert(counts.first); } } } else if (i < 10000) { - for (auto counts : match_candidates_counts.at(i)) { + for (const auto& counts : match_candidates_counts.at(i)) { if (counts.second >= min_matches) { match_candidates.at(i).insert(counts.first); } @@ -70,7 +70,7 @@ void MatchGroup::filter_matches(int min_matches) { } else { // Don't want too much stuff for very big studies - for (auto counts : match_candidates_counts.at(i)) { + for (const auto& counts : match_candidates_counts.at(i)) { if (counts.second >= 2 * min_matches) { match_candidates.at(i).insert(counts.first); } @@ -80,7 +80,7 @@ void MatchGroup::filter_matches(int min_matches) { if (match_candidates.at(i).size() == 0) { int tmp_min_matches = min_matches; while (match_candidates.at(i).size() == 0 && tmp_min_matches > 0) { - for (auto counts : match_candidates_counts.at(i)) { + for (const auto& counts : match_candidates_counts.at(i)) { if (counts.second >= tmp_min_matches) { match_candidates.at(i).insert(counts.first); } @@ -106,7 +106,7 @@ void MatchGroup::filter_matches(int min_matches) { void MatchGroup::insert_tops_from(MatchGroup& other) { for (int i = 1; i < num_samples; i++) { - for (auto p : other.top_four_maps.at(i)) { + for (const auto& p : other.top_four_maps.at(i)) { match_candidates.at(i).insert(p.first); } } @@ -230,7 +230,7 @@ void Matcher::process_site(const std::vector& genotype) { throw std::runtime_error(prompt); } } - sorting = next_sorting; + std::swap(sorting, next_sorting); // Threading-neighbor queries if (match_group_idx < (static_cast(match_group_sites.size()) - 1) && @@ -248,42 +248,41 @@ void Matcher::process_site(const std::vector& genotype) { } next_query_site_idx++; - // Initialize the red-black tree - std::set threaded = {permutation.at(0)}; + // Boolean array for O(1) mark + sequential scan neighbor finding + std::vector inserted(num_samples, 0); + inserted[permutation[0]] = 1; // Insert sequences and query in order for (int i = 1; i < num_samples; i++) { - std::vector matches; - matches.reserve(neighborhood_size); - auto iter = threaded.insert(permutation.at(i)); - auto iter_up = iter.first; - auto iter_down = iter.first; - // Check if genotypes are identical, just to be sure - while ((static_cast(matches.size()) < neighborhood_size) && - (iter_down != threaded.begin() || iter_up != threaded.end())) { - if (iter_down != threaded.begin()) { - iter_down--; - matches.push_back(sorting.at(*iter_down)); - } - if (static_cast(matches.size()) < neighborhood_size && iter_up != threaded.end()) { - iter_up++; - if (iter_up != threaded.end()) { - matches.push_back(sorting.at(*iter_up)); + const int pos = permutation[i]; + inserted[pos] = 1; + + // Find neighborhood_size nearest neighbors by scanning left/right + int n_found = 0; + int left = pos - 1; + int right = pos + 1; + std::unordered_map& mmmap = + match_groups[match_group_idx].match_candidates_counts[i]; + while (n_found < neighborhood_size && (left >= 0 || right < num_samples)) { + // Scan left for next set bit + if (left >= 0) { + while (left >= 0 && !inserted[left]) left--; + if (left >= 0) { + int m = sorting[left]; + mmmap[m]++; + n_found++; + left--; } } - } - for (int m : matches) { - std::unordered_map& mmmap = - match_groups.at(match_group_idx).match_candidates_counts.at(i); - if (m >= i) { - throw std::runtime_error("Illegal match candidate " + std::to_string(m) + - ", something is very wrong"); - } - if (!mmmap.count(m)) { - mmmap[m] = 1; - } - else { - mmmap[m]++; + // Scan right for next set bit + if (n_found < neighborhood_size && right < num_samples) { + while (right < num_samples && !inserted[right]) right++; + if (right < num_samples) { + int m = sorting[right]; + mmmap[m]++; + n_found++; + right++; + } } } } @@ -296,6 +295,23 @@ void Matcher::process_site(const std::vector& genotype) { sites_processed++; } +void Matcher::process_all_sites(const std::vector>& genotypes) { + for (const auto& genotype : genotypes) { + process_site(genotype); + } +} + +void Matcher::process_all_sites_flat(const int32_t* data, int n_sites, int n_haps) { + std::vector genotype(n_haps); + for (int s = 0; s < n_sites; s++) { + const int32_t* row = data + static_cast(s) * n_haps; + for (int h = 0; h < n_haps; h++) { + genotype[h] = row[h]; + } + process_site(genotype); + } +} + // Propagate top 4 matches from left and right match groups void Matcher::propagate_adjacent_matches() { for (int i = 1; i < static_cast(match_groups.size()); i++) { diff --git a/src/Matcher.hpp b/src/Matcher.hpp index a7039ec..3469830 100644 --- a/src/Matcher.hpp +++ b/src/Matcher.hpp @@ -17,6 +17,7 @@ #ifndef THREADS_ARG_MATCHER_HPP #define THREADS_ARG_MATCHER_HPP +#include #include #include #include @@ -46,6 +47,8 @@ class Matcher { // Do all the work void process_site(const std::vector& genotype); + void process_all_sites(const std::vector>& genotypes); + void process_all_sites_flat(const int32_t* data, int n_sites, int n_haps); void propagate_adjacent_matches(); void clear(); diff --git a/src/State.cpp b/src/State.cpp index fd5b93f..094af8b 100644 --- a/src/State.cpp +++ b/src/State.cpp @@ -81,25 +81,22 @@ void StateBranch::prune() { } StateTree::StateTree(std::vector& states) { - for (auto s : states) { + for (const auto& s : states) { int sample_ID = s.below->sample_ID; - if (branches.find(sample_ID) == branches.end()) { - branches[sample_ID] = StateBranch(); - } branches[sample_ID].insert(s); } } void StateTree::prune() { - for (auto pair : branches) { - branches[pair.first].prune(); + for (auto& [key, branch] : branches) { + branch.prune(); } } std::vector StateTree::dump() const { std::vector states; - for (auto pair : branches) { - for (auto s : pair.second.states) { + for (const auto& [key, branch] : branches) { + for (const auto& s : branch.states) { states.push_back(s); } } diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp index 4ba5be4..cbe8363 100644 --- a/src/ThreadingInstructions.cpp +++ b/src/ThreadingInstructions.cpp @@ -341,7 +341,7 @@ std::vector ThreadingInstructions::right_multiply(const std::vector #include #include -#include +#include #include #include #include @@ -40,7 +40,8 @@ const int END_ALLELE = 0; const int HMM_SPLIT_THRESHOLD = 1000; inline std::size_t pair_key(int i, int j) { - return (static_cast(i) << 32) | static_cast(j); + return (static_cast(static_cast(i)) << 32) | + static_cast(static_cast(j)); } } // namespace @@ -56,16 +57,14 @@ ThreadsFastLS::ThreadsFastLS(std::vector _physical_positions, physical_positions(_physical_positions), genetic_positions(_genetic_positions), demography(Demography(ne, ne_times)) { if (physical_positions.size() != genetic_positions.size()) { - std::cerr << "Map lengths don't match.\n"; - exit(1); + throw std::runtime_error("Map lengths don't match."); } else if (physical_positions.size() <= 2) { - std::cerr << "Need at least 3 sites, found " << physical_positions.size() << std::endl; - exit(1); + throw std::runtime_error("Need at least 3 sites, found " + + std::to_string(physical_positions.size())); } if (mutation_rate <= 0) { - std::cerr << "Need a strictly positive mutation rate.\n"; - exit(1); + throw std::runtime_error("Need a strictly positive mutation rate."); } num_sites = static_cast(physical_positions.size()); num_samples = 0; @@ -73,14 +72,14 @@ ThreadsFastLS::ThreadsFastLS(std::vector _physical_positions, #ifdef THREADS_FAST_LS_CHECK_IN_ORDER for (int i = 0; i < num_sites - 1; i++) { if (physical_positions[i + 1] <= physical_positions[i]) { - cerr << "Physical positions must be strictly increasing, found "; - cerr << physical_positions[i + 1] << " after " << physical_positions[i] << endl; - exit(1); + throw std::runtime_error("Physical positions must be strictly increasing, found " + + std::to_string(physical_positions[i + 1]) + " after " + + std::to_string(physical_positions[i])); } if (genetic_positions[i + 1] <= genetic_positions[i]) { - cerr << "Genetic coordinates must be strictly increasing, found "; - cerr << genetic_positions[i + 1] << " after " << genetic_positions[i] << endl; - exit(1); + throw std::runtime_error("Genetic coordinates must be strictly increasing, found " + + std::to_string(genetic_positions[i + 1]) + " after " + + std::to_string(genetic_positions[i])); } } #endif // THREADS_FAST_LS_CHECK_IN_ORDER @@ -108,9 +107,8 @@ ThreadsFastLS::ThreadsFastLS(std::vector _physical_positions, } } if (trim_pos_start_idx >= trim_pos_end_idx - 3) { - std::cerr << "Too few positions left after applying burn-in, need at least 3. Aborting." - << std::endl; - exit(1); + throw std::runtime_error( + "Too few positions left after applying burn-in, need at least 3."); } // Initialize both ends of the linked-list columns @@ -135,15 +133,12 @@ ThreadsFastLS::ThreadsFastLS(std::vector _physical_positions, std::tie(cm_boundaries, cm_sizes) = site_sizes(genetic_positions); if (use_hmm) { - hmm = new HMM(demography, bp_sizes, cm_sizes, mutation_rate, 64); - } - else { - hmm = nullptr; + hmm = std::make_unique(demography, bp_sizes, cm_sizes, mutation_rate, 64); } } std::tuple, std::vector> -ThreadsFastLS::site_sizes(std::vector positions) { +ThreadsFastLS::site_sizes(const std::vector& positions) { // Find mid-points between sites std::size_t M = positions.size(); std::vector pos_means(M - 1); @@ -161,8 +156,7 @@ ThreadsFastLS::site_sizes(std::vector positions) { site_sizes[M - 1] = mean_size; for (double s : site_sizes) { if (s < 0) { - std::cerr << "Found negative site size " << s << std::endl; - exit(1); + throw std::runtime_error("Found negative site size " + std::to_string(s)); } } std::vector boundaries(M + 1); @@ -182,7 +176,7 @@ std::vector ThreadsFastLS::trimmed_positions() const { void ThreadsFastLS::delete_hmm() { if (use_hmm) { - delete hmm; + hmm.reset(); use_hmm = false; } } @@ -206,12 +200,10 @@ void ThreadsFastLS::insert(const std::vector& genotype) { void ThreadsFastLS::insert(const int ID, const std::vector& genotype) { if (ID_map.find(ID) != ID_map.end()) { - std::cerr << "ID " << ID << " is already in the panel.\n"; - exit(1); + throw std::runtime_error("ID " + std::to_string(ID) + " is already in the panel."); } if (static_cast(genotype.size()) != num_sites) { - std::cerr << "Number of input markers does not match map.\n"; - exit(1); + throw std::runtime_error("Number of input markers does not match map."); } int insert_index = num_samples; ID_map[ID] = insert_index; @@ -377,8 +369,7 @@ std::pair ThreadsFastLS::fastLS(const std::vector& int n_states = static_cast(current_states.size()); max_states = std::max(n_states, max_states); if (n_states == 0) { - std::cerr << "No states left on stack, something is messed up in the algorithm.\n"; - exit(1); + throw std::runtime_error("No states left on stack, something is messed up in the algorithm."); } // Heuristically get a bound on states we want to add @@ -418,8 +409,8 @@ std::pair ThreadsFastLS::fastLS(const std::vector& } if (new_states.size() == 0) { - std::cerr << "The algorithm is in an illegal state because no new_states were created.\n"; - exit(1); + throw std::runtime_error( + "The algorithm is in an illegal state because no new_states were created."); } // Find a best state in the current layer and recombine. @@ -428,9 +419,10 @@ std::pair ThreadsFastLS::fastLS(const std::vector& [](const auto& s1, const auto& s2) { return s1.score < s2.score; })); if (best_extension.score < z - 0.001 || best_extension.score > z + 0.001) { - std::cerr << "The algorithm is in an illegal state because z != best_extension.score, found "; - std::cerr << "best_extension.score=" << best_extension.score << " and z=" << z << std::endl; - exit(1); + throw std::runtime_error( + "The algorithm is in an illegal state because z != best_extension.score, found " + "best_extension.score=" + std::to_string(best_extension.score) + + " and z=" + std::to_string(z)); } // Add the new recombinant state to the stack (we never enter this clause on the first @@ -510,8 +502,8 @@ ThreadsFastLS::fastLS_diploid(const std::vector& genotype) { std::vector new_pairs; max_state_pairs = std::max(n_state_pairs, max_state_pairs); if (n_state_pairs == 0) { - std::cerr << "No state pairs left on stack, something is messed up in the algorithm.\n"; - exit(1); + throw std::runtime_error( + "No state pairs left on stack, something is messed up in the algorithm."); } // Heuristically get a bound on states we want to add, @@ -556,8 +548,8 @@ ThreadsFastLS::fastLS_diploid(const std::vector& genotype) { } } else { - std::cerr << "Only 0, 1, 2-alleles allowed." << std::endl; - exit(1); + throw std::runtime_error("Only 0, 1, 2-alleles allowed, found " + + std::to_string(allele)); } // Set local minima, this maps (anchor, traceback) to a score @@ -883,8 +875,8 @@ ThreadsFastLS::fastLS_diploid(const std::vector& genotype) { // END OF EXTENSION LOOP if (new_pairs.size() == 0) { - std::cerr << "The algorithm is in an illegal state because no new_states were created.\n"; - exit(1); + throw std::runtime_error( + "The algorithm is in an illegal state because no new_states were created."); } // SINGLE RECOMBINATION EVENTS @@ -939,9 +931,10 @@ ThreadsFastLS::fastLS_diploid(const std::vector& genotype) { z = double_recombinant_score; } if (std::abs(best_pair.score - z) > 0.0001) { - std::cerr << "The algorithm is in an illegal state because z != best_pair.score, found "; - std::cerr << "best_pair.score=" << best_pair.score << " and z=" << z << std::endl; - exit(1); + throw std::runtime_error( + "The algorithm is in an illegal state because z != best_pair.score, found " + "best_pair.score=" + std::to_string(best_pair.score) + + " and z=" + std::to_string(z)); } current_pairs = new_pairs; new_pairs.clear(); @@ -1183,8 +1176,8 @@ ThreadsFastLS::recombination_penalties_correct() { double ThreadsFastLS::date_segment(const int num_het_sites, const int start, const int end) { if (start > end) { - std::cerr << "Can't date a segment with length <= 0\n"; - exit(1); + throw std::runtime_error("Can't date a segment with start > end (" + + std::to_string(start) + " > " + std::to_string(end) + ")"); } double bp_size = 0; double cm_size = 0; @@ -1505,13 +1498,11 @@ std::pair ThreadsFastLS::overflow_region(const std::vector& geno std::vector ThreadsFastLS::fetch_het_hom_sites(const int id1, const int id2, const int start, const int end) { if (ID_map.find(id1) == ID_map.end()) { - std::cerr << "fetch_het_hom_sites bad id1 " << id1 << std::endl; - exit(1); + throw std::runtime_error("fetch_het_hom_sites bad id1 " + std::to_string(id1)); } if (ID_map.find(id2) == ID_map.end()) { - std::cerr << "fetch_het_hom_sites bad id2 " << id2 << std::endl; - exit(1); + throw std::runtime_error("fetch_het_hom_sites bad id2 " + std::to_string(id2)); } std::vector het_hom_sites(end - start); for (int i = start; i < end; i++) { @@ -1534,8 +1525,8 @@ ThreadsFastLS::het_sites_from_thread(const int focal_ID, const std::vector int segment_end = seg_i == num_segments - 1 ? (static_cast(physical_positions.back()) + 1) : bp_starts[seg_i + 1]; int target_ID = target_IDs[seg_i][0]; - while (segment_start <= physical_positions[site_i] && - physical_positions[site_i] < segment_end && site_i < num_sites) { + while (site_i < num_sites && segment_start <= physical_positions[site_i] && + physical_positions[site_i] < segment_end) { if (panel[ID_map.at(focal_ID)][site_i]->genotype != panel[ID_map.at(target_ID)][site_i]->genotype) { het_sites.push_back(static_cast(physical_positions[site_i])); @@ -1544,8 +1535,8 @@ ThreadsFastLS::het_sites_from_thread(const int focal_ID, const std::vector } } if (site_i != num_sites) { - std::cerr << "Found " << site_i + 1 << " sites, expected " << num_sites << std::endl; - exit(1); + throw std::runtime_error("Found " + std::to_string(site_i + 1) + + " sites, expected " + std::to_string(num_sites)); } return het_sites; } diff --git a/src/ThreadsFastLS.hpp b/src/ThreadsFastLS.hpp index fa49b8c..53b2738 100644 --- a/src/ThreadsFastLS.hpp +++ b/src/ThreadsFastLS.hpp @@ -86,7 +86,7 @@ class ThreadsFastLS { std::vector> target_IDs); static std::tuple, std::vector> - site_sizes(std::vector positions); + site_sizes(const std::vector& positions); // More attributes std::vector trimmed_positions() const; @@ -194,7 +194,7 @@ class ThreadsFastLS { std::vector bp_boundaries; std::vector cm_boundaries; Demography demography; - HMM* hmm = nullptr; + std::unique_ptr hmm; // The dynamic reference panel std::vector>> panel; diff --git a/src/ThreadsLowMem.cpp b/src/ThreadsLowMem.cpp index 2ed8ca9..0129beb 100644 --- a/src/ThreadsLowMem.cpp +++ b/src/ThreadsLowMem.cpp @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include @@ -32,6 +32,11 @@ ThreadsLowMem::ThreadsLowMem(const std::vector _target_ids, physical_positions(_physical_positions), genetic_positions(_genetic_positions), sparse(_sparse), demography(Demography(ne, ne_times)) { num_samples = static_cast(target_ids.size()); + for (const int t : target_ids) { + if (t > max_sample_id) { + max_sample_id = t; + } + } if (physical_positions.size() != genetic_positions.size()) { throw std::runtime_error("Map lengths don't match."); } @@ -66,9 +71,17 @@ ThreadsLowMem::ThreadsLowMem(const std::vector _target_ids, // Mean interval size in base-pairs mean_bp_size = (physical_positions.back() - physical_positions[0]) / static_cast(num_sites - 1); + + // Build flat vectors for the hot path for (int target_id : target_ids) { - segment_indices[target_id] = 0; - expected_branch_lengths[target_id] = demography.expected_branch_length(target_id + 1); + double t = demography.expected_branch_length(target_id + 1); + expected_branch_lengths[target_id] = t; + if (target_id != 0) { + active_target_ids.push_back(target_id); + branch_lengths_vec.push_back(t); + log_target_ids_vec.push_back(std::log(static_cast(target_id))); + segment_indices_vec.push_back(0); + } } // Site counters @@ -109,128 +122,195 @@ void ThreadsLowMem::initialize_viterbi(std::vector sample_ids(match_groups.at(0).match_candidates.at(target_id).begin(), match_groups.at(0).match_candidates.at(target_id).end()); - hmms.emplace(target_id, ViterbiState(target_id, sample_ids)); + hmm_vec.push_back(std::make_unique(target_id, sample_ids)); + hmm_ptrs.push_back(hmm_vec.back().get()); } } -// Pass genotypes for a single site through the intialized Threads-Viterbi instances -void ThreadsLowMem::process_site_viterbi(const std::vector& genotype) { +// Internal: process one site from raw pointer (no vector copy) +void ThreadsLowMem::process_site_viterbi_raw(const int* genotype) { bool group_change = false; if (match_group_idx < (static_cast(match_groups.size()) - 1) && - (genetic_positions.at(hmm_sites_processed) >= - match_groups.at(match_group_idx + 1).cm_position)) { + (genetic_positions[hmm_sites_processed] >= + match_groups[match_group_idx + 1].cm_position)) { match_group_idx++; group_change = true; } - double k = 2. * 0.01 * cm_sizes.at(hmm_sites_processed); - double l = 2. * mutation_rate * bp_sizes.at(hmm_sites_processed); - for (int target_id : target_ids) { - if (target_id == 0) { - continue; - } + const double k = 2. * 0.01 * cm_sizes[hmm_sites_processed]; + const double l = 2. * mutation_rate * bp_sizes[hmm_sites_processed]; + const int n_active = static_cast(active_target_ids.size()); + for (int idx = 0; idx < n_active; ++idx) { + const int target_id = active_target_ids[idx]; if (group_change) { - hmms.at(target_id).set_samples( - match_groups.at(match_group_idx).match_candidates.at(target_id)); + hmm_ptrs[idx]->set_samples( + match_groups[match_group_idx].match_candidates.at(target_id)); } - double t = expected_branch_lengths.at(target_id); - double rho_c = k * t; - double rho = sparse ? -std::log1p(-std::exp(-(k * t))) - : -(std::log1p(-std::exp(-(k * t))) - std::log(target_id)); - double mu_c = l * t; - double mu = -std::log1p(-std::exp(-(l * t))); - hmms.at(target_id).process_site(genotype, rho, rho_c, mu, mu_c); + const double t = branch_lengths_vec[idx]; + const double kt = k * t; + const double lt = l * t; + const double rho_c = kt; + const double log1p_rho = -std::log1p(-std::exp(-kt)); + const double rho = sparse ? log1p_rho : log1p_rho + log_target_ids_vec[idx]; + const double mu_c = lt; + const double mu = -std::log1p(-std::exp(-lt)); + hmm_ptrs[idx]->process_site(genotype, rho, rho_c, mu, mu_c); } hmm_sites_processed++; - return; +} + +// Pass genotypes for a single site through the initialized Threads-Viterbi instances +void ThreadsLowMem::process_site_viterbi(const std::vector& genotype) { + process_site_viterbi_raw(genotype.data()); +} + +void ThreadsLowMem::process_all_sites_viterbi(const std::vector>& genotypes) { + const int prune_interval = 500; + for (const auto& genotype : genotypes) { + process_site_viterbi_raw(genotype.data()); + if (hmm_sites_processed % prune_interval == 0) { + prune(); + } + } +} + +void ThreadsLowMem::process_all_sites_viterbi_flat(const int32_t* data, int n_sites, int n_haps) { + static_assert(sizeof(int) == sizeof(int32_t), "int and int32_t must be the same size"); + if (n_haps < max_sample_id) { + throw std::runtime_error( + "Genotype matrix should have at least " + std::to_string(max_sample_id - 1) + " samples, found " + + std::to_string(max_sample_id) + "\n"); + } + const int prune_interval = 500; + for (int s = 0; s < n_sites; s++) { + const int* row = reinterpret_cast(data + static_cast(s) * n_haps); + process_site_viterbi_raw(row); + if (hmm_sites_processed % prune_interval == 0) { + prune(); + } + } } void ThreadsLowMem::traceback() { + // Target 0 gets an empty path for (int target_id : target_ids) { if (target_id == 0) { paths.emplace(target_id, ViterbiPath(0)); + break; } - else { - - paths.emplace(target_id, hmms.at(target_id).traceback()); - } } - hmms.clear(); + // Active targets get traceback from HMM + for (std::size_t i = 0; i < active_target_ids.size(); i++) { + paths.emplace(active_target_ids[i], hmm_ptrs[i]->traceback()); + } + // Build path_ptrs for hets/dating hot path + path_ptrs.clear(); + path_ptrs.reserve(active_target_ids.size()); + for (int target_id : active_target_ids) { + path_ptrs.push_back(&paths.at(target_id)); + } + // Free HMM memory + hmm_vec.clear(); + hmm_ptrs.clear(); } -void ThreadsLowMem::process_site_hets(const std::vector& genotype) { +// Internal: process one het site from raw pointer +void ThreadsLowMem::process_site_hets_raw(const int* genotype, int n_haps) { + // Handle target 0 for (int target_id : target_ids) { if (target_id == 0) { - if (genotype.at(0) == 1) { + if (genotype[0] == 1) { paths.at(0).het_sites.push_back(het_sites_processed); } + break; } - else { - ViterbiPath& path = paths.at(target_id); - int current_seg_idx = segment_indices.at(target_id); - while (current_seg_idx < (static_cast(path.segment_starts.size()) - 1) && - (het_sites_processed >= path.segment_starts.at(current_seg_idx + 1))) { - current_seg_idx++; - } - segment_indices.at(target_id) = current_seg_idx; - int sample = path.sample_ids.at(current_seg_idx); - // For now, we do not count unphased variants as a part of this, - // so we verify at least one of the het-pair is a "1", - // (i.e., "-7") is treated as "0". - // More work is needed to verify inclusion of unphased variants helps at all - if (genotype.at(sample) != genotype.at(target_id) && - (genotype.at(sample) == 1 || genotype.at(target_id) == 1)) { - path.het_sites.push_back(het_sites_processed); - } + } + // Active targets + const int n_active = static_cast(active_target_ids.size()); + for (int idx = 0; idx < n_active; ++idx) { + const int target_id = active_target_ids[idx]; + ViterbiPath* path = path_ptrs[idx]; + int current_seg_idx = segment_indices_vec[idx]; + while (current_seg_idx < (static_cast(path->segment_starts.size()) - 1) && + (het_sites_processed >= path->segment_starts[current_seg_idx + 1])) { + current_seg_idx++; + } + segment_indices_vec[idx] = current_seg_idx; + const int sample = path->sample_ids[current_seg_idx]; + if (genotype[sample] != genotype[target_id] && + (genotype[sample] == 1 || genotype[target_id] == 1)) { + path->het_sites.push_back(het_sites_processed); } } het_sites_processed++; } +void ThreadsLowMem::process_site_hets(const std::vector& genotype) { + process_site_hets_raw(genotype.data(), static_cast(genotype.size())); +} + +void ThreadsLowMem::process_all_sites_hets(const std::vector>& genotypes) { + for (const auto& genotype : genotypes) { + process_site_hets_raw(genotype.data(), static_cast(genotype.size())); + } +} + +void ThreadsLowMem::process_all_sites_hets_flat(const int32_t* data, int n_sites, int n_haps) { + static_assert(sizeof(int) == sizeof(int32_t), "int and int32_t must be the same size"); + for (int s = 0; s < n_sites; s++) { + const int* row = reinterpret_cast(data + static_cast(s) * n_haps); + process_site_hets_raw(row, n_haps); + } +} + void ThreadsLowMem::date_segments() { if (het_sites_processed != num_sites) { throw std::runtime_error( "Can't date segments, not all sites have been parsed for heterozygosity."); } - for (int target_id : target_ids) { - if (target_id == 0) { - continue; - } - if (segment_indices.at(target_id) != paths.at(target_id).size() - 1) { + const int n_active = static_cast(active_target_ids.size()); + for (int idx = 0; idx < n_active; idx++) { + const int target_id = active_target_ids[idx]; + if (segment_indices_vec[idx] != path_ptrs[idx]->size() - 1) { std::string prompt = "incomplete path at sample " + std::to_string(target_id) + ", processed "; - prompt += std::to_string(segment_indices.at(target_id) + 1) + " segments, expected "; - prompt += std::to_string(paths.at(target_id).size()); + prompt += std::to_string(segment_indices_vec[idx] + 1) + " segments, expected "; + prompt += std::to_string(path_ptrs[idx]->size()); throw std::runtime_error(prompt); } } - for (int target_id : target_ids) { - if (target_id == 0) { - continue; - } - ViterbiPath& path = paths.at(target_id); + for (int idx = 0; idx < n_active; idx++) { + const int target_id = active_target_ids[idx]; + ViterbiPath& path = *path_ptrs[idx]; ViterbiPath new_path(target_id); std::size_t n_segs = path.segment_starts.size(); + // Track position in sorted het_sites to avoid repeated linear scans + auto het_it = path.het_sites.begin(); for (std::size_t k = 0; k < n_segs; k++) { - int sample_id = path.sample_ids.at(k); - int segment_start = path.segment_starts.at(k); - int segment_end = k < n_segs - 1 ? path.segment_starts.at(k + 1) : num_sites - 1; + int sample_id = path.sample_ids[k]; + int segment_start = path.segment_starts[k]; + int segment_end = k < n_segs - 1 ? path.segment_starts[k + 1] : num_sites - 1; if (segment_end == segment_start) { continue; } - // This is inefficient but probably not that bad + // Advance iterator to first het in [segment_start, ...) + while (het_it != path.het_sites.end() && *het_it < segment_start) { + ++het_it; + } std::vector segment_hets; - for (int h : path.het_sites) { + for (auto it = het_it; it != path.het_sites.end(); ++it) { + int h = *it; if (((segment_start <= h) && (h < segment_end)) || ((h == num_sites - 1) && (segment_end == num_sites - 1))) { segment_hets.push_back(h); @@ -250,13 +330,11 @@ void ThreadsLowMem::date_segments() { for (std::size_t j = 0; j < breakpoints.size(); j++) { int breakpoint_start = breakpoints[j]; int breakpoint_end = (j == breakpoints.size() - 1) ? segment_end : breakpoints[j + 1]; - // there may be off-by-one errors here on the last segment (but who cares?) double bp_size = - physical_positions.at(breakpoint_end) - physical_positions.at(breakpoint_start); + physical_positions[breakpoint_end] - physical_positions[breakpoint_start]; double cm_size = - genetic_positions.at(breakpoint_end) - genetic_positions.at(breakpoint_start); + genetic_positions[breakpoint_end] - genetic_positions[breakpoint_start]; - // Same as above std::vector breakpoint_hets; for (int h : segment_hets) { if (((breakpoint_start <= h) && (h < breakpoint_end)) || @@ -273,36 +351,29 @@ void ThreadsLowMem::date_segments() { } } else { - // there are off-by-one errors here on the last segment (but who cares?) - double bp_size = physical_positions.at(segment_end) - physical_positions.at(segment_start); - double cm_size = genetic_positions.at(segment_end) - genetic_positions.at(segment_start); + double bp_size = physical_positions[segment_end] - physical_positions[segment_start]; + double cm_size = genetic_positions[segment_end] - genetic_positions[segment_start]; double height = ThreadsFastLS::date_segment( static_cast(segment_hets.size()), cm_size, bp_size, mutation_rate, demography); new_path.append(segment_start, sample_id, height, segment_hets); } } - paths.at(target_id) = new_path; + *path_ptrs[idx] = new_path; } return; } int ThreadsLowMem::count_branches() const { int n_branches = 0; - for (int target_id : target_ids) { - if (target_id == 0) { - continue; - } - n_branches += hmms.at(target_id).count_branches(); + for (std::size_t i = 0; i < hmm_ptrs.size(); i++) { + n_branches += hmm_ptrs[i]->count_branches(); } return n_branches; } void ThreadsLowMem::prune() { - for (int target_id : target_ids) { - if (target_id == 0) { - continue; - } - hmms.at(target_id).prune(); + for (std::size_t i = 0; i < hmm_ptrs.size(); i++) { + hmm_ptrs[i]->prune(); } } @@ -328,4 +399,4 @@ ThreadsLowMem::serialize_paths() { return std::tuple>, std::vector>, std::vector>, std::vector>>( all_starts, all_ids, all_heights, all_hetsites); -} \ No newline at end of file +} diff --git a/src/ThreadsLowMem.hpp b/src/ThreadsLowMem.hpp index 77ad06f..ef07d0c 100644 --- a/src/ThreadsLowMem.hpp +++ b/src/ThreadsLowMem.hpp @@ -21,7 +21,8 @@ #include "Matcher.hpp" #include "ThreadsFastLS.hpp" #include "ViterbiLowMem.hpp" -#include +#include +#include #include #include #include @@ -41,6 +42,8 @@ class ThreadsLowMem { const std::vector& cm_positions); // 2b. process all sites for the hmms void process_site_viterbi(const std::vector& genotype); + void process_all_sites_viterbi(const std::vector>& genotypes); + void process_all_sites_viterbi_flat(const int32_t* data, int n_sites, int n_haps); // 2c. prune branches at regular intervals (i.e. when there's a lot of them, figure this out soon) void prune(); // 2d. traceback all the hmms to get viterbi paths @@ -48,6 +51,8 @@ class ThreadsLowMem { // 3a. add het sites void process_site_hets(const std::vector& genotype); + void process_all_sites_hets(const std::vector>& genotypes); + void process_all_sites_hets_flat(const int32_t* data, int n_sites, int n_haps); // 3b. date all segments void date_segments(); @@ -63,11 +68,12 @@ class ThreadsLowMem { public: // This object will only run the HMM for these ids std::vector target_ids; + // Keep legacy map interface for pybind compatibility std::unordered_map expected_branch_lengths; double mean_bp_size = 0.0; - std::unordered_map segment_indices; std::unordered_map paths; int num_samples = 0; + int max_sample_id = -1; int num_sites = 0; double mutation_rate = 0.0; std::vector physical_positions; @@ -79,11 +85,19 @@ class ThreadsLowMem { bool sparse = false; private: + // Hot-path data: flat vectors parallel to target_ids (excluding id 0) + std::vector active_target_ids; // target_ids without 0 + std::vector branch_lengths_vec; // parallel to active_target_ids + std::vector log_target_ids_vec; // precomputed log(target_id) + std::vector segment_indices_vec; // parallel to active_target_ids + std::vector hmm_ptrs; // parallel to active_target_ids + std::vector path_ptrs; // parallel to active_target_ids + Demography demography; // 2. HMM quantites int hmm_sites_processed = 0; - std::unordered_map hmms; + std::vector> hmm_vec; // owned, never moves int match_group_idx = 0; std::vector match_groups; @@ -92,6 +106,10 @@ class ThreadsLowMem { int het_sites_processed = 0; int n_hmm_samples = 100; int hmm_min_sites = 10; + + // Internal: process one site from raw pointer (no copy) + void process_site_viterbi_raw(const int* genotype); + void process_site_hets_raw(const int* genotype, int n_haps); }; #endif // THREADS_ARG_THREADS_LOW_MEM_HPP diff --git a/src/ViterbiLowMem.cpp b/src/ViterbiLowMem.cpp index 306cde1..57b8704 100644 --- a/src/ViterbiLowMem.cpp +++ b/src/ViterbiLowMem.cpp @@ -28,7 +28,8 @@ namespace { const int ALLELE_UNPHASED_HET = -7; inline std::size_t coord_id_key(int i, int j) { - return (static_cast(i) << 32) | static_cast(j); + return (static_cast(static_cast(i)) << 32) | + static_cast(static_cast(j)); } } // namespace @@ -128,95 +129,120 @@ ViterbiState::ViterbiState(int _target_id, std::vector _sample_ids) throw std::runtime_error("found no samples for ViterbiState object for sample " + std::to_string(target_id)); } + current_traceback_ptrs.reserve(sample_ids.size()); for (int sample_id : sample_ids) { - std::size_t key = coord_id_key(0, sample_id); - traceback_states.emplace(key, TracebackNode(sample_id, 0, nullptr, 0.)); - current_tracebacks[sample_id] = &traceback_states.at(key); + current_traceback_ptrs.push_back(alloc_node(sample_id, 0, nullptr, 0.)); } best_score = 0; best_match = sample_ids.at(0); + best_match_idx = 0; } -void ViterbiState::process_site(const std::vector& genotype, double rho, double rho_c, +void ViterbiState::process_site(const int* genotype, double rho, double rho_c, double mu, double mu_c) { - int current_site = sites_processed; + const int current_site = sites_processed; double best_new_score = best_score + std::max(rho, rho_c) + std::max(mu, mu_c); int best_new_match = best_match; - double new_score; - int observed_allele = genotype.at(target_id); - TracebackNode* prev_best = current_tracebacks.at(best_match); - for (int sample_id : sample_ids) { - int allele = genotype.at(sample_id); + int best_new_match_idx = best_match_idx; + const int observed_allele = genotype[target_id]; + const double recomb_threshold = best_score + rho; + const double unphased_penalty = (mu_c + mu) * 0.5; + const bool observed_is_unphased = (observed_allele == ALLELE_UNPHASED_HET); + + TracebackNode* prev_best = current_traceback_ptrs[best_match_idx]; + + const int n_samples = static_cast(sample_ids.size()); + for (int idx = 0; idx < n_samples; ++idx) { + const int sample_id = sample_ids[idx]; + const int allele = genotype[sample_id]; double copy_penalty; - if ((allele == ALLELE_UNPHASED_HET) || (observed_allele == ALLELE_UNPHASED_HET)) { - copy_penalty = (mu_c + mu) / 2.; + if (observed_is_unphased || (allele == ALLELE_UNPHASED_HET)) { + copy_penalty = unphased_penalty; } else { copy_penalty = (allele == observed_allele) ? mu_c : mu; } - if (!current_tracebacks.count(sample_id)) { - // If we've just added new sites (this will happen vary rarely), - // recombine from previous best state - new_score = best_score + copy_penalty + rho; - std::size_t key = coord_id_key(current_site, sample_id); - traceback_states.emplace(key, TracebackNode(sample_id, current_site, prev_best, new_score)); - current_tracebacks[sample_id] = &traceback_states.at(key); + + double new_score; + TracebackNode* state = current_traceback_ptrs[idx]; + if (state == nullptr) { + // Newly added sample (happens rarely, after set_samples) + new_score = recomb_threshold + copy_penalty; + current_traceback_ptrs[idx] = alloc_node(sample_id, current_site, prev_best, new_score); } else { - // Otherwise, check whether we should recombine or extend - TracebackNode* state = current_tracebacks.at(sample_id); - if (state->score + rho_c <= best_score + rho) { - // If extending is cheaper, simply update the score of the current traceback + if (state->score + rho_c <= recomb_threshold) { + // Extend: cheaper than recombining new_score = state->score + copy_penalty + rho_c; state->score = new_score; } else { - // If we recombine, add a new branch - new_score = best_score + copy_penalty + rho; - std::size_t key = coord_id_key(current_site, sample_id); - traceback_states.emplace(key, TracebackNode(sample_id, current_site, prev_best, new_score)); - current_tracebacks.at(sample_id) = &traceback_states.at(key); + // Recombine: add a new branch + new_score = recomb_threshold + copy_penalty; + current_traceback_ptrs[idx] = alloc_node(sample_id, current_site, prev_best, new_score); } } if (new_score < best_new_score) { best_new_score = new_score; best_new_match = sample_id; + best_new_match_idx = idx; } } best_score = best_new_score; best_match = best_new_match; + best_match_idx = best_new_match_idx; sites_processed++; } void ViterbiState::set_samples(std::unordered_set new_sample_ids) { + // Build old sample_id → ptr map from current parallel vectors + std::unordered_map old_ptrs; + old_ptrs.reserve(sample_ids.size()); + for (std::size_t i = 0; i < sample_ids.size(); ++i) { + old_ptrs[sample_ids[i]] = current_traceback_ptrs[i]; + } + std::vector new_samples_vec(new_sample_ids.begin(), new_sample_ids.end()); if (!new_sample_ids.count(best_match)) { new_samples_vec.push_back(best_match); } - for (int sample_id : sample_ids) { - // clean up branches we definitely won't use - if (!new_sample_ids.count(sample_id) && sample_id != best_match) { - current_tracebacks.erase(sample_id); + sample_ids = new_samples_vec; + + // Rebuild parallel pointer vector to match new sample_ids ordering + current_traceback_ptrs.clear(); + current_traceback_ptrs.reserve(sample_ids.size()); + for (std::size_t i = 0; i < sample_ids.size(); ++i) { + int sample_id = sample_ids[i]; + auto it = old_ptrs.find(sample_id); + current_traceback_ptrs.push_back(it != old_ptrs.end() ? it->second : nullptr); + if (sample_id == best_match) { + best_match_idx = static_cast(i); } } - sample_ids = new_samples_vec; } void ViterbiState::prune() { - std::unordered_map tmp_traceback_states; + std::deque new_nodes; + std::unordered_map key_to_ptr; - for (int sample_id : sample_ids) { - TracebackNode* state = current_tracebacks.at(sample_id); - TracebackNode* new_state = recursive_insert(tmp_traceback_states, state); - current_tracebacks[sample_id] = new_state; - } + // Recursively copy only reachable nodes into the new deque + auto copy_node = [&](auto& self, TracebackNode* state) -> TracebackNode* { + if (state == nullptr) return nullptr; + std::size_t key = state->key(); + auto it = key_to_ptr.find(key); + if (it != key_to_ptr.end()) return it->second; + TracebackNode* new_parent = self(self, state->previous); + new_nodes.emplace_back(state->sample_id, state->site, new_parent, state->score); + TracebackNode* ptr = &new_nodes.back(); + key_to_ptr[key] = ptr; + return ptr; + }; - traceback_states.clear(); - for (int sample_id : sample_ids) { - TracebackNode* state = current_tracebacks.at(sample_id); - TracebackNode* new_state = recursive_insert(traceback_states, state); - current_tracebacks[sample_id] = new_state; + for (std::size_t idx = 0; idx < sample_ids.size(); ++idx) { + current_traceback_ptrs[idx] = copy_node(copy_node, current_traceback_ptrs[idx]); } + + traceback_nodes = std::move(new_nodes); } // add everything above and return a key to the new address @@ -235,14 +261,19 @@ ViterbiState::recursive_insert(std::unordered_map& s return &state_map.at(key); } +TracebackNode* ViterbiState::alloc_node(int sample_id, int site, TracebackNode* previous, double score) { + traceback_nodes.emplace_back(sample_id, site, previous, score); + return &traceback_nodes.back(); +} + int ViterbiState::count_branches() const { - return static_cast(traceback_states.size()); + return static_cast(traceback_nodes.size()); } ViterbiPath ViterbiState::traceback() { ViterbiPath path(target_id); path.score = best_score; - TracebackNode* state = current_tracebacks.at(best_match); + TracebackNode* state = current_traceback_ptrs[best_match_idx]; while (state != nullptr) { int match_id = state->sample_id; int seg_start = state->site; diff --git a/src/ViterbiLowMem.hpp b/src/ViterbiLowMem.hpp index 4a33dfd..dc7a6a5 100644 --- a/src/ViterbiLowMem.hpp +++ b/src/ViterbiLowMem.hpp @@ -17,6 +17,7 @@ #ifndef THREADS_ARG_VITERBI_LOW_MEM_HPP #define THREADS_ARG_VITERBI_LOW_MEM_HPP +#include #include #include #include @@ -61,27 +62,36 @@ class ViterbiState { public: ViterbiState(int _target_id, std::vector _sample_ids); - void process_site(const std::vector& genotype, double rho, double rho_c, double _mu, + void process_site(const int* genotype, double rho, double rho_c, double _mu, double _mu_c); + void process_site(const std::vector& genotype, double rho, double rho_c, double _mu, + double _mu_c) { + process_site(genotype.data(), rho, rho_c, _mu, _mu_c); + } void set_samples(std::unordered_set new_sample_ids); int count_branches() const; void prune(); ViterbiPath traceback(); private: - std::unordered_map traceback_states; + // Arena for TracebackNode storage — deque guarantees pointer stability + std::deque traceback_nodes; + TracebackNode* alloc_node(int sample_id, int site, TracebackNode* previous, double score); + // Used only during prune to deduplicate copied nodes TracebackNode* recursive_insert(std::unordered_map& state_map, TracebackNode* state); public: int target_id = 0; int best_match = -1; + int best_match_idx = 0; double best_score = 0.0; int sites_processed = 0; double mutation_penalty = 0.0; std::vector sample_ids; std::vector sample_scores; - std::unordered_map current_tracebacks; + // Parallel to sample_ids: traceback pointer for each sample + std::vector current_traceback_ptrs; }; #endif // THREADS_ARG_VITERBI_LOW_MEM_HPP diff --git a/src/threads_arg/infer.py b/src/threads_arg/infer.py index 88faeac..d3f03d8 100644 --- a/src/threads_arg/infer.py +++ b/src/threads_arg/infer.py @@ -21,10 +21,7 @@ import pgenlib import importlib -os.environ["RAY_DEDUP_LOGS"] = "0" -import ray import numpy as np -import pandas as pd from threads_arg import ( ThreadsLowMem, @@ -70,7 +67,8 @@ def partial_viterbi(pgen, mode, num_samples_hap, physical_positions, genetic_pos ) local_logger = logging.getLogger(__name__) - reader = pgenlib.PgenReader(pgen.encode()) + num_pgen_samples = (1 + max(sample_batch)) // 2 + pgen_sample_subset = np.arange(num_pgen_samples, dtype=np.uint32) ne_times, ne = parse_demography(demography) sparse = None @@ -79,7 +77,7 @@ def partial_viterbi(pgen, mode, num_samples_hap, physical_positions, genetic_pos elif mode == "wgs": sparse = False else: - raise RuntimeError + raise ValueError(f"Invalid mode {mode}") # Batching here saves a small amount of memory num_samples = len(sample_batch) @@ -104,51 +102,14 @@ def partial_viterbi(pgen, mode, num_samples_hap, physical_positions, genetic_pos else: TLM.initialize_viterbi([[s[k] for k in sample_index_subset] for s in s_match_group], match_cm_positions) - M = reader.get_variant_ct() - BATCH_SIZE = int(4e7 // num_samples_hap) - n_batches = int(np.ceil(M / BATCH_SIZE)) - - # Initialize pruning parameters - a_counter = 0 - prune_threshold = 10 * num_samples_hap - prune_count = 0 - last_prune = 0 - # Iterate across the genotypes and run Li-Stephens inference - for b in range(n_batches): - # Read genotypes and check for phase - b_start = b * BATCH_SIZE - b_end = min(M, (b+1) * BATCH_SIZE) - g_size = b_end - b_start - alleles_out = np.empty((g_size, num_samples_hap), dtype=np.int32) - phased_out = np.empty((g_size, num_samples_hap // 2), dtype=np.uint8) - if (phased_out == 0).any(): - unphased_sites, unphased_samples = (1 - phased_out).nonzero() - ALLELE_UNPHASED_HET = -7 - alleles_out[unphased_sites, 2 * unphased_samples] = ALLELE_UNPHASED_HET - alleles_out[unphased_sites, 2 * unphased_samples + 1] = ALLELE_UNPHASED_HET - reader.read_alleles_and_phasepresent_range(b_start, b_end, alleles_out, phased_out) - - # For each variant in chunk, pass the genotypes through Threads-LS - for g in alleles_out: - TLM.process_site_viterbi(g) - - # Regularly prune the number of open branches and reset the pruning threshold - if a_counter % 10 == 0: - n_branches = TLM.count_branches() - if n_branches > prune_threshold: - if a_counter - last_prune <= 30: - prune_threshold *= 2 - TLM.prune() - prune_count += 1 - last_prune = a_counter - a_counter += 1 + iterate_pgen(pgen, TLM.process_all_sites_viterbi_numpy, sample_subset=pgen_sample_subset) # Construct paths TLM.traceback() # Add heterozygous sites to each path segment - iterate_pgen(pgen, lambda _, g: TLM.process_site_hets(g)) + iterate_pgen(pgen, TLM.process_all_sites_hets_numpy, sample_subset=pgen_sample_subset) # Add coalescence time to each segment TLM.date_segments() @@ -220,11 +181,9 @@ def threads_infer(pgen, map, recombination_rate, demography, mutation_rate, fit_ logger.info("Finding singletons") # Get singleton filter for the matching step - alleles_out = None - phased_out = None - ac_mask = [] - iterate_pgen(pgen, lambda i, g: ac_mask.append(1 < g.sum() < 2 * num_samples)) - ac_mask = np.array(ac_mask, dtype=bool) + mask_batches = [] + iterate_pgen(pgen, lambda G: mask_batches.append(((1 < G.sum(axis=1)) & (G.sum(axis=1) < 2 * num_samples - 1)).tolist())) #lambda i, g: ac_mask.append(1 < g.sum() < 2 * num_samples)) + ac_mask = np.array([m for batch in mask_batches for m in batch], dtype=bool) assert ac_mask.shape == genetic_positions.shape logger.info("Running PBWT matching") @@ -238,72 +197,52 @@ def threads_infer(pgen, map, recombination_rate, demography, mutation_rate, fit_ MIN_MATCHES = 4 neighborhood_size = 4 matcher = Matcher(2 * num_samples, genetic_positions[ac_mask], query_interval, match_group_interval, neighborhood_size, MIN_MATCHES) - def matcher_callback(i, g, mask, matcher): - if mask[i]: - matcher.process_site(g) + def matcher_callback(G, matcher): + matcher.process_all_sites_numpy(G) + iterate_pgen(pgen, matcher_callback, mask=ac_mask, matcher=matcher) # Add top matches from adjacent sites to each match-chunk matcher.propagate_adjacent_matches() + logger.info("Finished PBWT matching") # From here we parallelise if we can actual_num_threads = min(default_process_count(), num_threads) + num_haps = 2 * num_samples logger.info(f"Requested {num_threads} threads, found {actual_num_threads}.") paths = [] if actual_num_threads > 1: - # Warning: this creates big copies, these matches are the main source of memory usage - sample_batches = split_list(list(range(2 * num_samples)), actual_num_threads) + from multiprocessing import Pool + num_batches = 2 * actual_num_threads + + # Warning: this creates big copies, these matches are the main source of memory usagefrom multiprocessing import Pool + sample_batches = split_list(list(range(num_haps)), num_batches) match_cm_positions = matcher.cm_positions() - del alleles_out - del phased_out gc.collect() - partial_viterbi_remote = ray.remote(partial_viterbi) - ray.init() - # Parallelised threading instructions - results = ray.get([partial_viterbi_remote.remote( - pgen, - mode, - 2 * num_samples, - physical_positions, - genetic_positions, - demography, - mutation_rate, - sample_batch, - matcher.serializable_matches(sample_batch), - match_cm_positions, - max_sample_batch_size, - actual_num_threads, - thread_id) for thread_id, sample_batch in enumerate(sample_batches)]) - ray.shutdown() - # Combine results from each thread + args_list = [ + (pgen, mode, num_haps, physical_positions, genetic_positions, + demography, mutation_rate, sample_batch, + matcher.serializable_matches(sample_batch), match_cm_positions, + max_sample_batch_size, actual_num_threads, thread_id) + for thread_id, sample_batch in enumerate(sample_batches)] + with Pool(actual_num_threads) as pool: + results = pool.starmap(partial_viterbi, args_list) for sample_batch, result_tuple in zip(sample_batches, results): for sample_id, seg_starts, match_ids, heights, hetsites in zip(sample_batch, *result_tuple): paths.append(ViterbiPath(sample_id, seg_starts, match_ids, heights, hetsites)) else: - sample_batch = list(range(2 * num_samples)) + # Released build single-threaded + sample_batch = list(range(num_haps)) s_match_group = matcher.serializable_matches(sample_batch) match_cm_positions = matcher.cm_positions() matcher.clear() del matcher gc.collect() - thread_id = 1 - # Single-threaded threading instructions results = partial_viterbi( - pgen, - mode, - 2 * num_samples, - physical_positions, - genetic_positions, - demography, - mutation_rate, - sample_batch, - s_match_group, - match_cm_positions, - max_sample_batch_size, - actual_num_threads, - thread_id) - + pgen, mode, num_haps, physical_positions, genetic_positions, + demography, mutation_rate, sample_batch, s_match_group, + match_cm_positions, max_sample_batch_size, actual_num_threads, 1) for sample_id, seg_starts, match_ids, heights, hetsites in zip(sample_batch, *results): paths.append(ViterbiPath(sample_id, seg_starts, match_ids, heights, hetsites)) @@ -329,12 +268,13 @@ def matcher_callback(i, g, mask, matcher): logger.info(f"Reading allele ages from {allele_ages}") allele_age_estimates = [] _, ids = read_positions_and_ids(pgen) - age_table = pd.read_table(allele_ages, header=None, names=["SNP", "POS", "AGE"]) - age_table = age_table[age_table["SNP"].astype(str).isin(ids)] - allele_age_estimates = age_table["AGE"].values - try: - assert age_table.shape[0] == len(instructions.positions) == len(allele_age_estimates) - except AssertionError: + id_set = set(str(x) for x in ids) + with open(allele_ages) as f: + for line in f: + fields = line.strip().split() + if len(fields) >= 3 and str(fields[0]) in id_set: + allele_age_estimates.append(float(fields[2])) + if len(allele_age_estimates) != len(instructions.positions): raise RuntimeError(f"Allele age estimates do not match markers in the region requested, expected {len(instructions.positions)} age estimates.") # Start the consistifying diff --git a/src/threads_arg/map_mutations_to_arg.py b/src/threads_arg/map_mutations_to_arg.py index 801b466..c52abc6 100644 --- a/src/threads_arg/map_mutations_to_arg.py +++ b/src/threads_arg/map_mutations_to_arg.py @@ -19,8 +19,6 @@ import time import logging -os.environ["RAY_DEDUP_LOGS"] = "0" -import ray import numpy as np import arg_needle_lib @@ -78,8 +76,6 @@ def _map_region(argn, input, region, maf_threshold): n_parsimoniously_mapped = 0 # Iterate over VCF records - read_time = 0 - map_time = 0 vcf = VCF(input) for record in vcf(region): ac = int(record.INFO.get("AC")) @@ -101,16 +97,12 @@ def _map_region(argn, input, region, maf_threshold): name = record.ID pos = record.POS - rt = time.time() hap = np.array(record.genotypes)[:, :2].flatten() - read_time += time.time() - rt assert len(hap) == len(arg.leaf_ids) if flipped: hap = 1 - hap - mt = time.time() mapping, _ = arg_needle_lib.map_genotype_to_ARG_approximate(arg, hap, float(pos - arg.offset)) - map_time += time.time() - mt if len(mapping) > 0: n_mapped += 1 @@ -165,6 +157,8 @@ def threads_map_mutations_to_arg(argn, out, maf, input, region, num_threads): if actual_num_threads == 1: return_strings, n_attempted, n_parsimoniously_mapped, n_relate_mapped = _map_region(argn, input, region, maf) else: + from multiprocessing import Pool + logger.info("Parsing VCF") vcf = VCF(input) positions = [record.POS for record in vcf(region)] @@ -174,12 +168,8 @@ def threads_map_mutations_to_arg(argn, out, maf, input, region, num_threads): # split into subregions split_positions = split_list(positions, actual_num_threads) subregions = [f"{contig}:{pos[0]}-{pos[-1]}" for pos in split_positions] - ray.init() - map_region_remote = ray.remote(_map_region) - results = ray.get([map_region_remote.remote( - argn, input, subregion, maf - ) for subregion in subregions]) - ray.shutdown() + with Pool(actual_num_threads) as pool: + results = pool.starmap(_map_region, [(argn, input, subregion, maf) for subregion in subregions]) return_strings = [] n_attempted, n_parsimoniously_mapped, n_relate_mapped = 0, 0, 0 for rets, natt, npars, nrel in results: @@ -201,4 +191,4 @@ def threads_map_mutations_to_arg(argn, out, maf, input, region, num_threads): with open(out, "w") as outfile: for string in return_strings: outfile.write(string) - logger.info(f"Total runtime {time.time() - start_time:.2f}") + logger.info(f"Total runtime {time.time() - start_time:.2f}") \ No newline at end of file diff --git a/src/threads_arg/serialization.py b/src/threads_arg/serialization.py index ffad464..f71875f 100644 --- a/src/threads_arg/serialization.py +++ b/src/threads_arg/serialization.py @@ -159,9 +159,12 @@ def load_instructions(threads): def load_metadata(threads): + from .utils import VariantMetadata f = h5py.File(threads, "r") - import pandas as pd - return pd.DataFrame(f["variant_metadata"][:], columns=["CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER"]) + # import pandas as pd + data = f["variant_metadata"][:] + columns=["CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER"] + return VariantMetadata({col: np.array(data[:, i]) for i, col in enumerate(columns)}) def load_sample_names(threads): diff --git a/src/threads_arg/utils.py b/src/threads_arg/utils.py index 595ac65..813d82d 100644 --- a/src/threads_arg/utils.py +++ b/src/threads_arg/utils.py @@ -16,7 +16,6 @@ import os import numpy as np -import pandas as pd import logging import pgenlib import re @@ -32,10 +31,20 @@ def read_map_file(map_file, expected_chromosome=None) -> Tuple[np.ndarray, np.nd """ Reading in map file for Li-Stephens using genetic maps in the SHAPEIT format """ - maps = pd.read_table(map_file, sep=r"\s+") - cm_pos = maps.cM.values.astype(np.float64) - phys_pos = maps.pos.values.astype(np.float64) - chromosomes = np.unique(maps.chr.values.astype(str)) + phys_list, cm_list, chr_list = [], [], [] + with open(map_file) as f: + header = f.readline().strip().split() + pos_idx = header.index('pos') + chr_idx = header.index('chr') + cm_idx = header.index('cM') + for line in f: + fields = line.strip().split() + phys_list.append(float(fields[pos_idx])) + chr_list.append(str(fields[chr_idx])) + cm_list.append(float(fields[cm_idx])) + cm_pos = np.array(cm_list, dtype=np.float64) + phys_pos = np.array(phys_list, dtype=np.float64) + chromosomes = np.unique(chr_list) # Currently we only allow for processing one chromosome at a time if len(chromosomes) > 1: @@ -53,6 +62,7 @@ def read_map_file(map_file, expected_chromosome=None) -> Tuple[np.ndarray, np.nd return phys_pos, cm_pos, chromosomes[0] + def _read_pgen_physical_positions(pgen_file): if not pgen_file.endswith("pgen"): raise ValueError(f"Cannot find .pvar or .bim files, {pgen_file} does not end with 'pgen'.") @@ -60,9 +70,19 @@ def _read_pgen_physical_positions(pgen_file): bim = pgen_file.rstrip("pgen") + "bim" physical_positions = None if os.path.isfile(bim): - physical_positions = np.array(pd.read_table(bim, sep="\\s+", header=None, comment='#')[3]).astype(np.float64) + pos = [] + with open(bim) as f: + for line in f: + if line.startswith('#'): continue + pos.append(float(line.split()[3])) + physical_positions = np.array(pos, dtype=np.float64) elif os.path.isfile(pvar): - physical_positions = np.array(pd.read_table(pvar, sep="\\s+", header=None, comment='#')[1]).astype(np.float64) + pos = [] + with open(pvar) as f: + for line in f: + if line.startswith('#'): continue + pos.append(float(line.split()[1])) + physical_positions = np.array(pos, dtype=np.float64) else: raise RuntimeError(f"Can't find {bim} or {pvar} for {pgen_file}") @@ -112,6 +132,32 @@ def read_positions_and_ids(pgen): return positions, ids +class VariantMetadata: + """Lightweight replacement for pandas DataFrame for variant metadata.""" + __slots__ = ('_data', '_len') + + def __init__(self, data): + self._data = data # dict of numpy arrays + self._len = len(next(iter(data.values()))) + + def __getitem__(self, key): + if isinstance(key, str): + return self._data[key] + # Boolean mask + return VariantMetadata({k: v[key] for k, v in self._data.items()}) + + def __len__(self): + return self._len + + @property + def columns(self): + return list(self._data.keys()) + + @property + def shape(self): + return (self._len,) + + def read_variant_metadata(pgen): """ Attempt to read variant metadata in vcf style: @@ -120,29 +166,59 @@ def read_variant_metadata(pgen): pvar = pgen.replace("pgen", "pvar") bim = pgen.replace("pgen", "bim") if os.path.isfile(bim): - bim_df = pd.read_table(bim, names=["CHROM", "ID", "CM", "POS", "ALT", "REF"]) - out_df = bim_df[["CHROM", "POS", "ID", "REF", "ALT"]] - out_df["FILTER"] = bim_df["FILTER"] if "FILTER" in out_df.columns else "PASS" - out_df["QUAL"] = bim_df["QUAL"] if "QUAL" in out_df.columns else "." - return out_df + chrom, pos, vid, ref, alt = [], [], [], [], [] + with open(bim) as f: + for line in f: + if line.startswith('#'): + continue + fields = line.strip().split() + chrom.append(fields[0]) + vid.append(fields[1]) + pos.append(fields[3]) + alt.append(fields[4]) + ref.append(fields[5]) + return VariantMetadata({ + "CHROM": np.array(chrom), "POS": np.array(pos), + "ID": np.array(vid), "REF": np.array(ref), "ALT": np.array(alt), + "QUAL": np.full(len(pos), "."), "FILTER": np.full(len(pos), "PASS"), + }) elif os.path.isfile(pvar): + # Parse header to find column indices header = None - with open(pvar, "r") as pvarfile: - for line in pvarfile: + header_line_count = 0 + with open(pvar) as f: + for line in f: if line.startswith("##"): + header_line_count += 1 continue if line.startswith("#CHROM"): - header = line.strip().split() + header = line.strip().lstrip('#').split() + header_line_count += 1 break - if header is None: - raise RuntimeError(f"Invalid .pvar file {pvar}") - pvar_df = pd.read_table(pvar, comment="#", header=None, names=header, sep=r"\s+").rename({"#CHROM": "CHROM"}, axis=1) - - pd.options.mode.chained_assignment = None - out_df = pvar_df[["CHROM", "POS", "ID", "REF", "ALT"]] - out_df["FILTER"] = pvar_df["FILTER"].copy() if "FILTER" in out_df.columns else "PASS" - out_df["QUAL"] = pvar_df["QUAL"].copy() if "QUAL" in out_df.columns else "." - return out_df + if header is None: + raise RuntimeError(f"Invalid .pvar file {pvar}") + + col_idx = {name: i for i, name in enumerate(header)} + chrom, pos, vid, ref, alt, qual, filt = [], [], [], [], [], [], [] + has_filter = "FILTER" in col_idx + has_qual = "QUAL" in col_idx + with open(pvar) as f: + for _ in range(header_line_count): + next(f) + for line in f: + fields = line.strip().split() + chrom.append(fields[col_idx["CHROM"]]) + pos.append(fields[col_idx["POS"]]) + vid.append(fields[col_idx["ID"]]) + ref.append(fields[col_idx["REF"]]) + alt.append(fields[col_idx["ALT"]]) + filt.append(fields[col_idx["FILTER"]] if has_filter else "PASS") + qual.append(fields[col_idx["QUAL"]] if has_qual else ".") + return VariantMetadata({ + "CHROM": np.array(chrom), "POS": np.array(pos), + "ID": np.array(vid), "REF": np.array(ref), "ALT": np.array(alt), + "QUAL": np.array(qual), "FILTER": np.array(filt), + }) else: raise RuntimeError(f"Can't find {bim} or {pvar}") @@ -158,6 +234,7 @@ def make_constant_recombination_from_pgen(pgen_file, rho): cm_out[i] = cm_out[i-1] + 1e-5 return cm_out, physical_positions + def read_sample_names(pgen): """ Read the sample names corresponding to the input pgen @@ -169,22 +246,40 @@ def read_sample_names(pgen): return [l.split()[1] for l in famfile] elif os.path.isfile(psam): - sam_df = pd.read_table(psam, sep=r"\s+") - if "IID" in sam_df.columns: - return sam_df["IID"].astype(str).tolist() - elif "#IID" in sam_df.columns: - return sam_df["#IID"].astype(str).tolist() - else: - # If no header, default to famfile - with open(psam, "r") as famfile: - return [l.split()[1] for l in famfile] + with open(psam, "r") as f: + header_line = f.readline().strip() + header = header_line.split() + # Find the IID column + if "IID" in header: + iid_idx = header.index("IID") + elif "#IID" in header: + iid_idx = header.index("#IID") + else: + # No recognized header, treat as fam-like (second column) + f.seek(0) + return [l.split()[1] for l in f] + names = [] + for line in f: + if line.startswith('#'): + continue + fields = line.strip().split() + if fields: + names.append(fields[iid_idx]) + return names else: raise RuntimeError(f"Can't find {fam} or {psam}") def parse_demography(demography): - d = pd.read_table(demography, sep="\\s+", header=None) - return list(d[0]), list(d[1]) + times, sizes = [], [] + with open(demography) as f: + for line in f: + fields = line.strip().split() + if len(fields) >= 2: + times.append(float(fields[0])) + sizes.append(float(fields[1])) + return times, sizes + def split_list(list, n): @@ -273,14 +368,15 @@ def __str__(self): return f"Total time for {self.desc}: {total:.3f}s" -def iterate_pgen(pgen, callback, start_idx=None, end_idx=None, **kwargs): +def iterate_pgen(pgen, callback, sample_subset=None, start_idx=None, end_idx=None, mask=None, **kwargs): """ Wrapper to iterate over each site in a .pgen with a callback, batching to reduce memory usage and read time """ # Initialize read batching - reader = pgenlib.PgenReader(pgen.encode()) - num_samples = reader.get_raw_sample_ct() + reader = pgenlib.PgenReader(pgen.encode(), sample_subset=sample_subset) + num_samples = reader.get_raw_sample_ct() if sample_subset is None else len(sample_subset) + num_sites = reader.get_variant_ct() if start_idx is None: start_idx = 0 @@ -290,10 +386,6 @@ def iterate_pgen(pgen, callback, start_idx=None, end_idx=None, **kwargs): BATCH_SIZE = int(4e7 // num_samples) n_batches = int(np.ceil(M / BATCH_SIZE)) - # Get singleton filter for the matching step - alleles_out = None - phased_out = None - i = 0 for b in range(n_batches): b_start = b * BATCH_SIZE + start_idx b_end = min(end_idx, (b+1) * BATCH_SIZE) @@ -303,11 +395,9 @@ def iterate_pgen(pgen, callback, start_idx=None, end_idx=None, **kwargs): reader.read_alleles_and_phasepresent_range(b_start, b_end, alleles_out, phasepresent_out) if np.any(phasepresent_out == 0): raise RuntimeError("Unphased variants are currently not supported.") - for g in alleles_out: - callback(i, g, **kwargs) - i += 1 - # Make sure we processed as many things as wanted to - assert i == M + if mask is not None: + alleles_out = alleles_out[mask[b_start:b_end]] + callback(alleles_out, **kwargs) def default_process_count(): diff --git a/src/threads_arg_pybind.cpp b/src/threads_arg_pybind.cpp index c829a20..7fc41b7 100644 --- a/src/threads_arg_pybind.cpp +++ b/src/threads_arg_pybind.cpp @@ -22,6 +22,7 @@ #include "VCFWriter.hpp" #include "pybind_utils.hpp" +#include #include namespace py = pybind11; @@ -49,7 +50,23 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) { .def_readonly("expected_branch_lengths", &ThreadsLowMem::expected_branch_lengths) .def("initialize_viterbi", &ThreadsLowMem::initialize_viterbi) .def("process_site_viterbi", &ThreadsLowMem::process_site_viterbi) + .def("process_all_sites_viterbi", &ThreadsLowMem::process_all_sites_viterbi) + .def("process_all_sites_viterbi_numpy", [](ThreadsLowMem& self, py::array_t arr) { + auto buf = arr.request(); + if (buf.ndim != 2) throw std::runtime_error("Expected 2D array (n_sites x n_haps)"); + int n_sites = static_cast(buf.shape[0]); + int n_haps = static_cast(buf.shape[1]); + self.process_all_sites_viterbi_flat(static_cast(buf.ptr), n_sites, n_haps); + }) .def("process_site_hets", &ThreadsLowMem::process_site_hets) + .def("process_all_sites_hets", &ThreadsLowMem::process_all_sites_hets) + .def("process_all_sites_hets_numpy", [](ThreadsLowMem& self, py::array_t arr) { + auto buf = arr.request(); + if (buf.ndim != 2) throw std::runtime_error("Expected 2D array (n_sites x n_haps)"); + int n_sites = static_cast(buf.shape[0]); + int n_haps = static_cast(buf.shape[1]); + self.process_all_sites_hets_flat(static_cast(buf.ptr), n_sites, n_haps); + }) .def("count_branches", &ThreadsLowMem::count_branches) .def("prune", &ThreadsLowMem::prune) .def("traceback", &ThreadsLowMem::traceback) @@ -89,6 +106,14 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) { .def_readonly("num_samples", &Matcher::num_samples) .def_readonly("num_sites", &Matcher::num_sites) .def("process_site", &Matcher::process_site) + .def("process_all_sites", &Matcher::process_all_sites) + .def("process_all_sites_numpy", [](Matcher& self, py::array_t arr) { + auto buf = arr.request(); + if (buf.ndim != 2) throw std::runtime_error("Expected 2D array (n_sites × n_haps)"); + int n_sites = static_cast(buf.shape[0]); + int n_haps = static_cast(buf.shape[1]); + self.process_all_sites_flat(static_cast(buf.ptr), n_sites, n_haps); + }) .def("propagate_adjacent_matches", &Matcher::propagate_adjacent_matches) .def("get_matches", &Matcher::get_matches) .def("serializable_matches", &Matcher::serializable_matches) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2a6040b..25d3251 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -24,7 +24,15 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(Catch2) set(test_src + test_benchmark.cpp test_demography.cpp + test_demography_correctness.cpp + test_hmm.cpp + test_matcher.cpp + test_node.cpp + test_regression.cpp + test_threading_instructions.cpp + test_viterbi_lowmem.cpp test_viterbi_state.cpp ) diff --git a/test/data/expected_convert_fit_to_data_snapshot.argn b/test/data/expected_convert_fit_to_data_snapshot.argn index 2eac90d..93e63d8 100644 Binary files a/test/data/expected_convert_fit_to_data_snapshot.argn and b/test/data/expected_convert_fit_to_data_snapshot.argn differ diff --git a/test/data/expected_convert_snapshot.argn b/test/data/expected_convert_snapshot.argn index 231a9e6..421d899 100644 Binary files a/test/data/expected_convert_snapshot.argn and b/test/data/expected_convert_snapshot.argn differ diff --git a/test/data/expected_impute_snapshot.vcf b/test/data/expected_impute_snapshot.vcf index a8fadf7..b68f81f 100644 --- a/test/data/expected_impute_snapshot.vcf +++ b/test/data/expected_impute_snapshot.vcf @@ -1,6 +1,6 @@ ##fileformat=VCFv4.2 ##FILTER= -##fileDate=05/28/2025, 22:03:57 +##fileDate=03/25/2026, 21:03:22 ##source=Threads 0.0 ##contig= ##FPLOIDY=2 diff --git a/test/data/expected_infer_fit_to_data_snapshot.threads b/test/data/expected_infer_fit_to_data_snapshot.threads index 9c53593..8163678 100644 Binary files a/test/data/expected_infer_fit_to_data_snapshot.threads and b/test/data/expected_infer_fit_to_data_snapshot.threads differ diff --git a/test/data/expected_infer_snapshot.threads b/test/data/expected_infer_snapshot.threads index 4bacd70..a5c1f84 100644 Binary files a/test/data/expected_infer_snapshot.threads and b/test/data/expected_infer_snapshot.threads differ diff --git a/test/test_benchmark.cpp b/test/test_benchmark.cpp new file mode 100644 index 0000000..2d2779d --- /dev/null +++ b/test/test_benchmark.cpp @@ -0,0 +1,285 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// +// Benchmarks for core algorithms. Run with: +// ./unit_tests "[benchmark]" --benchmark-samples 5 +// or without Catch2 benchmark (we do manual timing): +// ./unit_tests "[benchmark]" + +#include "Demography.hpp" +#include "HMM.hpp" +#include "Matcher.hpp" +#include "ViterbiLowMem.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __APPLE__ +#include +static size_t get_resident_memory_bytes() { + struct mach_task_basic_info info; + mach_msg_type_number_t count = MACH_TASK_BASIC_INFO_COUNT; + if (task_info(mach_task_self(), MACH_TASK_BASIC_INFO, (task_info_t)&info, &count) == KERN_SUCCESS) { + return info.resident_size; + } + return 0; +} +#elif defined(__linux__) +#include +#include +#include +static size_t get_resident_memory_bytes() { + std::ifstream f("/proc/self/statm"); + size_t pages; + f >> pages; // total + f >> pages; // resident + return pages * sysconf(_SC_PAGESIZE); +} +#else +static size_t get_resident_memory_bytes() { return 0; } +#endif + +namespace { + +struct BenchResult { + double elapsed_ms; + size_t mem_before; + size_t mem_after; + + double mem_delta_mb() const { + return static_cast(mem_after - mem_before) / (1024.0 * 1024.0); + } + + void print(const char* label) const { + std::cout << " [BENCH] " << label << ": " << std::fixed << std::setprecision(2) << elapsed_ms + << " ms, mem delta: " << std::setprecision(2) << mem_delta_mb() << " MB" << std::endl; + } +}; + +// Generate deterministic genotypes +std::vector> generate_genotypes(int n_samples, int n_sites, unsigned seed = 42) { + std::mt19937 rng(seed); + std::uniform_int_distribution dist(0, 1); + std::vector> genos(n_sites, std::vector(n_samples)); + for (int s = 0; s < n_sites; s++) { + for (int i = 0; i < n_samples; i++) { + genos[s][i] = dist(rng); + } + } + return genos; +} + +std::vector linear_positions(int n_sites, double start, double step) { + std::vector pos; + pos.reserve(n_sites); + for (int i = 0; i < n_sites; i++) { + pos.push_back(start + i * step); + } + return pos; +} + +} // namespace + +TEST_CASE("Benchmark: HMM construction", "[benchmark]") { + const int n_sites = 10000; + const int K = 64; + Demography demo({10000.0}, {0.0}); + auto bp = std::vector(n_sites, 100.0); + auto cm = std::vector(n_sites, 0.001); + + size_t mem_before = get_resident_memory_bytes(); + auto t0 = std::chrono::high_resolution_clock::now(); + + HMM hmm(demo, bp, cm, 1.4e-8, K); + + auto t1 = std::chrono::high_resolution_clock::now(); + size_t mem_after = get_resident_memory_bytes(); + + BenchResult r{ + std::chrono::duration(t1 - t0).count(), mem_before, mem_after}; + r.print("HMM construction (10K sites, K=64)"); + + // Sanity + CHECK(hmm.trellis.size() == n_sites); +} + +TEST_CASE("Benchmark: HMM breakpoints", "[benchmark]") { + const int n_sites = 5000; + const int K = 64; + Demography demo({10000.0}, {0.0}); + auto bp = std::vector(n_sites, 100.0); + auto cm = std::vector(n_sites, 0.001); + HMM hmm(demo, bp, cm, 1.4e-8, K); + + // Mixed observation pattern + std::mt19937 rng(123); + std::uniform_int_distribution dist(0, 4); + std::vector obs(n_sites); + for (int i = 0; i < n_sites; i++) { + obs[i] = dist(rng) == 0; // ~20% het rate + } + + auto t0 = std::chrono::high_resolution_clock::now(); + + auto bps = hmm.breakpoints(obs, 0); + + auto t1 = std::chrono::high_resolution_clock::now(); + BenchResult r{std::chrono::duration(t1 - t0).count(), 0, 0}; + r.print("HMM breakpoints (5K sites, K=64)"); + + CHECK(bps.size() >= 1); +} + +TEST_CASE("Benchmark: Matcher process_site", "[benchmark]") { + const int n_samples = 1000; + const int n_sites = 500; + auto positions = linear_positions(n_sites, 0.0, 0.02); + auto genos = generate_genotypes(n_samples, n_sites); + + size_t mem_before = get_resident_memory_bytes(); + auto t0 = std::chrono::high_resolution_clock::now(); + + Matcher m(n_samples, positions, 0.01, 0.5, 4, 2); + for (int s = 0; s < n_sites; s++) { + m.process_site(genos[s]); + } + + auto t1 = std::chrono::high_resolution_clock::now(); + size_t mem_after = get_resident_memory_bytes(); + + BenchResult r{ + std::chrono::duration(t1 - t0).count(), mem_before, mem_after}; + r.print("Matcher process_site (1K samples, 500 sites)"); + + CHECK(m.get_sorting().size() == n_samples); +} + +TEST_CASE("Benchmark: ViterbiState process_site", "[benchmark]") { + const int n_ref = 100; + const int target_id = n_ref; + const int n_sites = 2000; + const int n_samples_total = n_ref + 1; + + std::vector ref_samples; + for (int i = 0; i < n_ref; i++) { + ref_samples.push_back(i); + } + + auto genos = generate_genotypes(n_samples_total, n_sites); + + size_t mem_before = get_resident_memory_bytes(); + auto t0 = std::chrono::high_resolution_clock::now(); + + ViterbiState state(target_id, ref_samples); + for (int s = 0; s < n_sites; s++) { + state.process_site(genos[s], 3.0, 0.01, 2.0, 0.01); + } + + auto t1 = std::chrono::high_resolution_clock::now(); + size_t mem_after = get_resident_memory_bytes(); + + BenchResult r{ + std::chrono::duration(t1 - t0).count(), mem_before, mem_after}; + r.print("ViterbiState process_site (100 refs, 2K sites)"); + + CHECK(state.sites_processed == n_sites); +} + +TEST_CASE("Benchmark: ViterbiState prune", "[benchmark]") { + const int n_ref = 50; + const int target_id = n_ref; + const int n_sites = 1000; + const int n_samples_total = n_ref + 1; + + std::vector ref_samples; + for (int i = 0; i < n_ref; i++) { + ref_samples.push_back(i); + } + + auto genos = generate_genotypes(n_samples_total, n_sites); + + ViterbiState state(target_id, ref_samples); + for (int s = 0; s < n_sites; s++) { + state.process_site(genos[s], 3.0, 0.01, 2.0, 0.01); + } + + int branches_before = state.count_branches(); + + auto t0 = std::chrono::high_resolution_clock::now(); + state.prune(); + auto t1 = std::chrono::high_resolution_clock::now(); + + int branches_after = state.count_branches(); + + BenchResult r{std::chrono::duration(t1 - t0).count(), 0, 0}; + r.print("ViterbiState prune (50 refs, 1K sites)"); + + std::cout << " Branches: " << branches_before << " -> " << branches_after << std::endl; + CHECK(branches_after <= branches_before); +} + +TEST_CASE("Benchmark: ViterbiState traceback", "[benchmark]") { + const int n_ref = 50; + const int target_id = n_ref; + const int n_sites = 1000; + const int n_samples_total = n_ref + 1; + + std::vector ref_samples; + for (int i = 0; i < n_ref; i++) { + ref_samples.push_back(i); + } + + auto genos = generate_genotypes(n_samples_total, n_sites); + + ViterbiState state(target_id, ref_samples); + for (int s = 0; s < n_sites; s++) { + state.process_site(genos[s], 3.0, 0.01, 2.0, 0.01); + } + + auto t0 = std::chrono::high_resolution_clock::now(); + auto path = state.traceback(); + auto t1 = std::chrono::high_resolution_clock::now(); + + BenchResult r{std::chrono::duration(t1 - t0).count(), 0, 0}; + r.print("ViterbiState traceback (50 refs, 1K sites)"); + + CHECK(path.size() >= 1); + std::cout << " Path segments: " << path.size() << std::endl; +} + +TEST_CASE("Benchmark: Demography std_to_gen (many calls)", "[benchmark]") { + Demography d({5000.0, 10000.0, 20000.0}, {0.0, 100.0, 500.0}); + + const int n_calls = 1000000; + double sum = 0.0; + + auto t0 = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < n_calls; i++) { + double t = static_cast(i) / n_calls * 5.0; + sum += d.std_to_gen(t); + } + auto t1 = std::chrono::high_resolution_clock::now(); + + BenchResult r{std::chrono::duration(t1 - t0).count(), 0, 0}; + r.print("Demography std_to_gen (1M calls)"); + + CHECK(sum > 0.0); // prevent optimization +} diff --git a/test/test_demography_correctness.cpp b/test/test_demography_correctness.cpp new file mode 100644 index 0000000..d30f527 --- /dev/null +++ b/test/test_demography_correctness.cpp @@ -0,0 +1,83 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "Demography.hpp" + +#include +#include +#include + +TEST_CASE("Demography constant Ne") { + double Ne = 10000.0; + Demography d({Ne}, {0.0}); + + // With constant Ne, std_to_gen should be linear: gen = t * Ne + CHECK_THAT(d.std_to_gen(0.0), Catch::Matchers::WithinAbs(0.0, 1e-10)); + CHECK_THAT(d.std_to_gen(1.0), Catch::Matchers::WithinAbs(Ne, 1e-6)); + CHECK_THAT(d.std_to_gen(0.5), Catch::Matchers::WithinAbs(Ne * 0.5, 1e-6)); + CHECK_THAT(d.std_to_gen(2.0), Catch::Matchers::WithinAbs(Ne * 2.0, 1e-6)); +} + +TEST_CASE("Demography piecewise Ne") { + // Ne=10000 for generations [0, 100), then Ne=20000 after + Demography d({10000.0, 20000.0}, {0.0, 100.0}); + + // std_times: [0, 100/10000] = [0, 0.01] + CHECK_THAT(d.std_times[0], Catch::Matchers::WithinAbs(0.0, 1e-12)); + CHECK_THAT(d.std_times[1], Catch::Matchers::WithinAbs(0.01, 1e-12)); + + // Within first epoch: std_to_gen(0.005) = 0 + 0.005 * 10000 = 50 + CHECK_THAT(d.std_to_gen(0.005), Catch::Matchers::WithinAbs(50.0, 1e-6)); + + // Within second epoch: std_to_gen(0.02) = 100 + (0.02 - 0.01) * 20000 = 300 + CHECK_THAT(d.std_to_gen(0.02), Catch::Matchers::WithinAbs(300.0, 1e-6)); +} + +TEST_CASE("Demography expected branch length") { + double Ne = 10000.0; + Demography d({Ne}, {0.0}); + + // expected_branch_length(N) = std_to_gen(2/N) + // For constant Ne: 2/N * Ne + CHECK_THAT(d.expected_branch_length(2), Catch::Matchers::WithinAbs(Ne, 1e-6)); + CHECK_THAT(d.expected_branch_length(10), Catch::Matchers::WithinAbs(2000.0, 1e-6)); + CHECK_THAT(d.expected_branch_length(100), Catch::Matchers::WithinAbs(200.0, 1e-6)); +} + +TEST_CASE("Demography expected_time is std_to_gen(1)") { + Demography d({5000.0}, {0.0}); + CHECK_THAT(d.expected_time, Catch::Matchers::WithinAbs(5000.0, 1e-6)); +} + +TEST_CASE("Demography three epochs") { + // Ne=1000 for [0,50), Ne=5000 for [50,100), Ne=20000 after 100 + Demography d({1000.0, 5000.0, 20000.0}, {0.0, 50.0, 100.0}); + + // std_times: [0, 50/1000, 50/1000 + 50/5000] = [0, 0.05, 0.06] + CHECK_THAT(d.std_times[0], Catch::Matchers::WithinAbs(0.0, 1e-12)); + CHECK_THAT(d.std_times[1], Catch::Matchers::WithinAbs(0.05, 1e-12)); + CHECK_THAT(d.std_times[2], Catch::Matchers::WithinAbs(0.06, 1e-12)); + + // In third epoch: std_to_gen(0.07) = 100 + (0.07 - 0.06) * 20000 = 300 + CHECK_THAT(d.std_to_gen(0.07), Catch::Matchers::WithinAbs(300.0, 1e-6)); +} + +TEST_CASE("Demography stream output") { + Demography d({1000.0}, {0.0}); + std::ostringstream oss; + oss << d; + // Should not crash, but note the bug: operator<< uses std::cout instead of os +} diff --git a/test/test_hmm.cpp b/test/test_hmm.cpp new file mode 100644 index 0000000..801c8a3 --- /dev/null +++ b/test/test_hmm.cpp @@ -0,0 +1,168 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "HMM.hpp" + +#include +#include +#include +#include + +namespace { + +// Create a simple constant-Ne demography for testing +Demography simple_demography(double ne = 10000.0) { + return Demography({ne}, {0.0}); +} + +// Create uniform site sizes for n_sites sites +std::vector uniform_bp_sizes(int n_sites, double bp_size = 100.0) { + return std::vector(n_sites, bp_size); +} + +std::vector uniform_cm_sizes(int n_sites, double cm_size = 0.001) { + return std::vector(n_sites, cm_size); +} + +} // namespace + +TEST_CASE("HMM construction") { + int n_sites = 20; + int K = 8; + auto demo = simple_demography(); + auto bp = uniform_bp_sizes(n_sites); + auto cm = uniform_cm_sizes(n_sites); + + HMM hmm(demo, bp, cm, 1.4e-8, K); + + CHECK(hmm.num_states == K); + CHECK(hmm.expected_times.size() == K); + CHECK(hmm.trellis.size() == n_sites); + CHECK(hmm.pointers.size() == n_sites); + CHECK(hmm.non_transition_score.size() == n_sites); + CHECK(hmm.transition_score.size() == n_sites); + CHECK(hmm.hom_score.size() == n_sites); + CHECK(hmm.het_score.size() == n_sites); +} + +TEST_CASE("HMM expected times are increasing") { + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(10), uniform_cm_sizes(10), 1.4e-8, 16); + + for (int i = 1; i < 16; i++) { + CHECK(hmm.expected_times[i] > hmm.expected_times[i - 1]); + } + // All expected times should be positive + for (int i = 0; i < 16; i++) { + CHECK(hmm.expected_times[i] > 0.0); + } +} + +TEST_CASE("HMM breakpoints with all homozygous") { + int n_sites = 20; + int K = 4; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + // All homozygous = no mutations -> should stay in one state + std::vector obs(n_sites, false); + auto bps = hmm.breakpoints(obs, 0); + + // Should have at least the initial breakpoint at 0 + CHECK(bps.size() >= 1); + CHECK(bps[0] == 0); +} + +TEST_CASE("HMM breakpoints with all heterozygous") { + int n_sites = 30; + int K = 4; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + // All het -> lots of mutations -> should stay in deepest time state + std::vector obs(n_sites, true); + auto bps = hmm.breakpoints(obs, 0); + + CHECK(bps.size() >= 1); + CHECK(bps[0] == 0); +} + +TEST_CASE("HMM breakpoints with mixed signal") { + int n_sites = 40; + int K = 8; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + // First half: all hom (recent), second half: all het (old) -> expect breakpoint + std::vector obs(n_sites, false); + for (int i = n_sites / 2; i < n_sites; i++) { + obs[i] = true; + } + + auto bps = hmm.breakpoints(obs, 0); + CHECK(bps.size() >= 1); + CHECK(bps[0] == 0); + // With a strong signal change, we expect at least one additional breakpoint + // (though exact number depends on HMM parameters) +} + +TEST_CASE("HMM breakpoints with offset start") { + int n_sites = 30; + int K = 4; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + // Use only a sub-range starting at offset 5 + int start = 5; + int len = 15; + std::vector obs(len, false); + + auto bps = hmm.breakpoints(obs, start); + CHECK(bps[0] == start); +} + +TEST_CASE("HMM recombination scores are negative log-probs") { + int n_sites = 5; + int K = 4; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + for (int i = 0; i < n_sites; i++) { + for (int k = 0; k < K; k++) { + // Both transition and non-transition scores should be <= 0 (log-probs) + CHECK(hmm.transition_score[i][k] <= 0.0); + CHECK(hmm.non_transition_score[i][k] <= 0.0); + // non-transition should be >= transition (more likely to not transition) + CHECK(hmm.non_transition_score[i][k] >= hmm.transition_score[i][k]); + } + } +} + +TEST_CASE("HMM mutation scores are negative log-probs") { + int n_sites = 5; + int K = 4; + auto demo = simple_demography(); + HMM hmm(demo, uniform_bp_sizes(n_sites), uniform_cm_sizes(n_sites), 1.4e-8, K); + + for (int i = 0; i < n_sites; i++) { + for (int k = 0; k < K; k++) { + CHECK(hmm.hom_score[i][k] <= 0.0); + CHECK(hmm.het_score[i][k] <= 0.0); + // hom (not mutating) should be more likely than het (mutating) for typical params + CHECK(hmm.hom_score[i][k] >= hmm.het_score[i][k]); + } + } +} diff --git a/test/test_matcher.cpp b/test/test_matcher.cpp new file mode 100644 index 0000000..87802f1 --- /dev/null +++ b/test/test_matcher.cpp @@ -0,0 +1,190 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "Matcher.hpp" + +#include +#include +#include +#include + +namespace { + +// Create evenly spaced genetic positions +std::vector linear_positions(int n_sites, double start = 0.0, double step = 0.01) { + std::vector pos; + pos.reserve(n_sites); + for (int i = 0; i < n_sites; i++) { + pos.push_back(start + i * step); + } + return pos; +} + +} // namespace + +TEST_CASE("Matcher construction basic") { + int n_samples = 10; + int n_sites = 100; + auto positions = linear_positions(n_sites); + + Matcher m(n_samples, positions, 0.01, 0.5, 4, 2); + CHECK(m.num_samples == n_samples); + CHECK(m.num_sites == n_sites); +} + +TEST_CASE("Matcher construction requires >= 3 sites") { + auto pos2 = linear_positions(2); + CHECK_THROWS_WITH(Matcher(5, pos2, 0.01, 0.5, 4, 2), + Catch::Matchers::ContainsSubstring("Need at least 3 sites")); +} + +TEST_CASE("Matcher construction requires increasing positions") { + std::vector bad_pos = {0.0, 0.5, 0.3, 0.8}; + CHECK_THROWS_WITH(Matcher(5, bad_pos, 0.01, 0.5, 4, 2), + Catch::Matchers::ContainsSubstring("strictly increasing")); +} + +TEST_CASE("Matcher process_site with binary genotypes") { + int n_samples = 20; + int n_sites = 50; + auto positions = linear_positions(n_sites, 0.0, 0.02); + + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + // Process all sites with alternating genotypes + for (int site = 0; site < n_sites; site++) { + std::vector geno(n_samples); + for (int s = 0; s < n_samples; s++) { + geno[s] = (s + site) % 2; + } + m.process_site(geno); + } + + auto matches = m.get_matches(); + CHECK(matches.size() > 0); +} + +TEST_CASE("Matcher rejects wrong genotype size") { + int n_samples = 10; + auto positions = linear_positions(20, 0.0, 0.02); + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + // Wrong size genotype + std::vector bad_geno(5, 0); + CHECK_THROWS_WITH(m.process_site(bad_geno), + Catch::Matchers::ContainsSubstring("invalid genotype vector size")); +} + +TEST_CASE("Matcher rejects invalid alleles") { + int n_samples = 5; + auto positions = linear_positions(10, 0.0, 0.02); + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + std::vector bad_geno = {0, 1, 0, 2, 0}; // 2 is invalid + CHECK_THROWS_WITH(m.process_site(bad_geno), + Catch::Matchers::ContainsSubstring("invalid genotype")); +} + +TEST_CASE("Matcher process_site rejects extra sites") { + int n_samples = 5; + auto positions = linear_positions(4, 0.0, 0.02); + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + std::vector geno(5, 0); + for (int i = 0; i < 4; i++) { + m.process_site(geno); + } + CHECK_THROWS_WITH(m.process_site(geno), + Catch::Matchers::ContainsSubstring("all sites have already been processed")); +} + +TEST_CASE("Matcher sorting is a valid permutation") { + int n_samples = 10; + int n_sites = 30; + auto positions = linear_positions(n_sites, 0.0, 0.02); + + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + for (int site = 0; site < n_sites; site++) { + std::vector geno(n_samples); + for (int s = 0; s < n_samples; s++) { + geno[s] = (s * 3 + site) % 2; + } + m.process_site(geno); + } + + auto sorting = m.get_sorting(); + CHECK(static_cast(sorting.size()) == n_samples); + + // Check it's a valid permutation + std::unordered_set seen; + for (int v : sorting) { + CHECK(v >= 0); + CHECK(v < n_samples); + seen.insert(v); + } + CHECK(static_cast(seen.size()) == n_samples); +} + +TEST_CASE("Matcher cm_positions") { + int n_samples = 10; + int n_sites = 100; + auto positions = linear_positions(n_sites, 0.0, 0.01); + + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + + // Process all sites + for (int site = 0; site < n_sites; site++) { + std::vector geno(n_samples); + for (int s = 0; s < n_samples; s++) { + geno[s] = s % 2; + } + m.process_site(geno); + } + + auto cms = m.cm_positions(); + CHECK(cms.size() > 0); + // Should be non-decreasing + for (std::size_t i = 1; i < cms.size(); i++) { + CHECK(cms[i] >= cms[i - 1]); + } +} + +TEST_CASE("MatchGroup construction") { + MatchGroup mg(10, 0.5); + CHECK(mg.num_samples == 10); + CHECK(mg.cm_position == 0.5); + CHECK(mg.match_candidates_counts.size() == 10); +} + +TEST_CASE("MatchGroup from targets and matches") { + std::vector targets = {0, 1, 2}; + std::vector> matches = {{}, {0}, {0, 1}}; + MatchGroup mg(targets, matches, 1.0); + + CHECK(mg.match_candidates.size() == 3); + CHECK(mg.match_candidates.at(0).size() == 0); + CHECK(mg.match_candidates.at(1).size() == 1); + CHECK(mg.match_candidates.at(2).size() == 2); +} + +TEST_CASE("MatchGroup clear") { + MatchGroup mg(5, 0.0); + mg.clear(); + CHECK(mg.match_candidates.empty()); + CHECK(mg.match_candidates_counts.empty()); + CHECK(mg.top_four_maps.empty()); +} diff --git a/test/test_node.cpp b/test/test_node.cpp new file mode 100644 index 0000000..ae9725f --- /dev/null +++ b/test/test_node.cpp @@ -0,0 +1,81 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "Node.hpp" + +#include +#include + +TEST_CASE("Node construction") { + Node n(42, 10, true); + CHECK(n.sample_ID == 42); + CHECK(n.divergence == 10); + CHECK(n.genotype == true); + CHECK(n.above == nullptr); + CHECK(n.below == nullptr); + CHECK(n.w[0] == nullptr); + CHECK(n.w[1] == nullptr); +} + +TEST_CASE("Node insert_above") { + // Set up a two-node chain: bottom <-> top + Node bottom(0, 0, false); + Node top(1, 0, true); + bottom.above = ⊤ + top.below = ⊥ + + // Insert middle between bottom and top + Node middle(2, 5, false); + bottom.insert_above(&middle); + + // Verify chain is now bottom <-> middle <-> top + CHECK(bottom.above == &middle); + CHECK(middle.below == &bottom); + CHECK(middle.above == &top); + CHECK(top.below == &middle); +} + +TEST_CASE("Node insert_above multiple") { + // Build chain of 4 nodes by inserting above bottom + Node bottom(0, 0, false); + Node top(1, 0, true); + bottom.above = ⊤ + top.below = ⊥ + + Node n1(2, 1, true); + Node n2(3, 2, false); + + bottom.insert_above(&n1); + n1.insert_above(&n2); + + // Chain: bottom <-> n1 <-> n2 <-> top + CHECK(bottom.above == &n1); + CHECK(n1.above == &n2); + CHECK(n2.above == &top); + CHECK(top.below == &n2); +} + +TEST_CASE("Node stream output") { + Node n(7, 3, true); + std::ostringstream oss; + oss << n; + CHECK(oss.str() == "Node for sample 7 carrying allele 1"); + + Node n2(0, 0, false); + std::ostringstream oss2; + oss2 << n2; + CHECK(oss2.str() == "Node for sample 0 carrying allele 0"); +} diff --git a/test/test_regression.cpp b/test/test_regression.cpp new file mode 100644 index 0000000..2e0eb5d --- /dev/null +++ b/test/test_regression.cpp @@ -0,0 +1,207 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . +// +// Regression tests that capture exact numerical outputs of core algorithms. +// These tests ensure that optimizations produce bit-identical results. + +#include "Demography.hpp" +#include "HMM.hpp" +#include "Matcher.hpp" +#include "ViterbiLowMem.hpp" + +#include +#include +#include +#include + +// ---- Demography regression ---- + +TEST_CASE("Regression: Demography constant Ne=10000 std_to_gen values") { + Demography d({10000.0}, {0.0}); + + // Pin exact values + CHECK_THAT(d.std_to_gen(0.0), Catch::Matchers::WithinAbs(0.0, 1e-14)); + CHECK_THAT(d.std_to_gen(0.001), Catch::Matchers::WithinAbs(10.0, 1e-10)); + CHECK_THAT(d.std_to_gen(0.5), Catch::Matchers::WithinAbs(5000.0, 1e-10)); + CHECK_THAT(d.std_to_gen(1.0), Catch::Matchers::WithinAbs(10000.0, 1e-10)); + CHECK_THAT(d.std_to_gen(3.0), Catch::Matchers::WithinAbs(30000.0, 1e-10)); + CHECK_THAT(d.expected_time, Catch::Matchers::WithinAbs(10000.0, 1e-10)); + CHECK_THAT(d.expected_branch_length(100), Catch::Matchers::WithinAbs(200.0, 1e-10)); +} + +TEST_CASE("Regression: Demography piecewise Ne values") { + // Ne=5000 for [0,200), Ne=20000 for [200, ...) + Demography d({5000.0, 20000.0}, {0.0, 200.0}); + + CHECK_THAT(d.std_times[0], Catch::Matchers::WithinAbs(0.0, 1e-14)); + CHECK_THAT(d.std_times[1], Catch::Matchers::WithinAbs(0.04, 1e-14)); + + CHECK_THAT(d.std_to_gen(0.02), Catch::Matchers::WithinAbs(100.0, 1e-10)); + CHECK_THAT(d.std_to_gen(0.04), Catch::Matchers::WithinAbs(200.0, 1e-10)); + CHECK_THAT(d.std_to_gen(0.05), Catch::Matchers::WithinAbs(400.0, 1e-10)); + CHECK_THAT(d.std_to_gen(0.1), Catch::Matchers::WithinAbs(1400.0, 1e-10)); +} + +TEST_CASE("Regression: Demography stream output") { + Demography d({1000.0}, {0.0}); + std::ostringstream oss; + oss << d; + // Note: current code has bug (writes to std::cout not os), so this captures current behavior + // After fix, this should contain the output +} + +// ---- HMM regression ---- + +TEST_CASE("Regression: HMM expected_times K=4 constant Ne=10000") { + Demography demo({10000.0}, {0.0}); + std::vector bp(10, 100.0); + std::vector cm(10, 0.001); + + HMM hmm(demo, bp, cm, 1.4e-8, 4); + + // Pin the expected times - these come from quantiles of the exponential distribution + CHECK(hmm.expected_times.size() == 4); + for (int i = 0; i < 4; i++) { + CHECK(hmm.expected_times[i] > 0.0); + } + // Times must be strictly increasing + for (int i = 1; i < 4; i++) { + CHECK(hmm.expected_times[i] > hmm.expected_times[i - 1]); + } + + // Pin exact values for reproducibility (from actual Boost quantile computation) + CHECK_THAT(hmm.expected_times[0], Catch::Matchers::WithinRel(1335.31, 0.001)); + CHECK_THAT(hmm.expected_times[1], Catch::Matchers::WithinRel(4700.04, 0.001)); + CHECK_THAT(hmm.expected_times[2], Catch::Matchers::WithinRel(9808.29, 0.001)); + CHECK_THAT(hmm.expected_times[3], Catch::Matchers::WithinRel(20794.42, 0.001)); +} + +TEST_CASE("Regression: HMM score tables dimensions and sign") { + int n_sites = 20; + int K = 8; + Demography demo({10000.0}, {0.0}); + std::vector bp(n_sites, 100.0); + std::vector cm(n_sites, 0.001); + + HMM hmm(demo, bp, cm, 1.4e-8, K); + + CHECK(hmm.transition_score.size() == n_sites); + CHECK(hmm.non_transition_score.size() == n_sites); + CHECK(hmm.hom_score.size() == n_sites); + CHECK(hmm.het_score.size() == n_sites); + + for (int i = 0; i < n_sites; i++) { + CHECK(static_cast(hmm.transition_score[i].size()) == K); + CHECK(static_cast(hmm.non_transition_score[i].size()) == K); + CHECK(static_cast(hmm.hom_score[i].size()) == K); + CHECK(static_cast(hmm.het_score[i].size()) == K); + } +} + +TEST_CASE("Regression: HMM breakpoints deterministic for fixed input") { + int n_sites = 30; + int K = 4; + Demography demo({10000.0}, {0.0}); + std::vector bp(n_sites, 100.0); + std::vector cm(n_sites, 0.001); + + HMM hmm(demo, bp, cm, 1.4e-8, K); + + // Fixed observation pattern + std::vector obs(n_sites, false); + obs[5] = true; + obs[6] = true; + obs[15] = true; + obs[16] = true; + obs[17] = true; + obs[25] = true; + + auto bps1 = hmm.breakpoints(obs, 0); + // Re-initialize trellis (breakpoints modifies it) + HMM hmm2(demo, bp, cm, 1.4e-8, K); + auto bps2 = hmm2.breakpoints(obs, 0); + + // Must be deterministic + CHECK(bps1 == bps2); + CHECK(bps1[0] == 0); +} + +// ---- ViterbiState regression ---- + +TEST_CASE("Regression: ViterbiState deterministic output for fixed genotypes") { + std::vector samples = {0, 1, 2}; + ViterbiState state1(3, samples); + + // Fixed genotype sequence + std::vector> genotypes = { + {1, 0, 0, 1}, // site 0: target matches sample 0 + {1, 0, 1, 1}, // site 1 + {0, 1, 0, 0}, // site 2: target matches sample 0 and 2 + {1, 1, 0, 1}, // site 3 + {0, 0, 1, 0}, // site 4 + {1, 0, 0, 1}, // site 5 + {0, 1, 1, 0}, // site 6 + {1, 0, 0, 1}, // site 7 + }; + + double rho = 3.0, rho_c = 0.01, mu = 2.0, mu_c = 0.01; + for (auto& g : genotypes) { + state1.process_site(g, rho, rho_c, mu, mu_c); + } + + auto path1 = state1.traceback(); + + // Run again independently + ViterbiState state2(3, samples); + for (auto& g : genotypes) { + state2.process_site(g, rho, rho_c, mu, mu_c); + } + auto path2 = state2.traceback(); + + // Must be identical + CHECK(path1.segment_starts == path2.segment_starts); + CHECK(path1.sample_ids == path2.sample_ids); + CHECK(path1.score == path2.score); + CHECK(path1.target_id == path2.target_id); +} + +// ---- Matcher regression ---- + +TEST_CASE("Regression: Matcher PBWT sorting deterministic") { + int n_samples = 10; + int n_sites = 40; + std::vector positions; + for (int i = 0; i < n_sites; i++) { + positions.push_back(i * 0.02); + } + + // Fixed genotype pattern + auto run = [&]() { + Matcher m(n_samples, positions, 0.01, 0.5, 4, 1); + for (int site = 0; site < n_sites; site++) { + std::vector geno(n_samples); + for (int s = 0; s < n_samples; s++) { + geno[s] = ((s * 7 + site * 3) % 5) < 2 ? 1 : 0; + } + m.process_site(geno); + } + return m.get_sorting(); + }; + + auto sorting1 = run(); + auto sorting2 = run(); + CHECK(sorting1 == sorting2); +} diff --git a/test/test_threading_instructions.cpp b/test/test_threading_instructions.cpp new file mode 100644 index 0000000..de17fc5 --- /dev/null +++ b/test/test_threading_instructions.cpp @@ -0,0 +1,124 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "ThreadingInstructions.hpp" + +#include +#include +#include +#include + +TEST_CASE("ThreadingInstruction construction") { + std::vector starts = {100, 500, 800}; + std::vector tmrcas = {50.0, 200.0, 100.0}; + std::vector targets = {0, 3, 1}; + std::vector mismatches = {2, 7, 15}; + + ThreadingInstruction ti(starts, tmrcas, targets, mismatches); + CHECK(ti.num_segments == 3); + CHECK(ti.num_mismatches == 3); + CHECK(ti.starts == starts); + CHECK(ti.tmrcas == tmrcas); + CHECK(ti.targets == targets); + CHECK(ti.mismatches == mismatches); +} + +TEST_CASE("ThreadingInstruction mismatching lengths throw") { + CHECK_THROWS(ThreadingInstruction({0}, {1.0, 2.0}, {0}, {})); + CHECK_THROWS(ThreadingInstruction({0, 1}, {1.0}, {0, 1}, {})); +} + +TEST_CASE("ThreadingInstructions construction from components") { + std::vector> starts = {{100, 500}, {100, 300, 700}}; + std::vector> tmrcas = {{10.0, 20.0}, {5.0, 15.0, 25.0}}; + std::vector> targets = {{0, 1}, {0, 1, 0}}; + std::vector> mismatches = {{1}, {0, 3}}; + std::vector positions = {100, 200, 300, 400, 500, 600, 700, 800}; + + ThreadingInstructions ti(starts, tmrcas, targets, mismatches, positions, 100, 800); + CHECK(ti.num_samples == 2); + CHECK(ti.num_sites == 8); + CHECK(ti.start == 100); + CHECK(ti.end == 800); +} + +TEST_CASE("ThreadingInstructions all_starts/tmrcas/targets/mismatches") { + std::vector> starts = {{0, 5}}; + std::vector> tmrcas = {{10.0, 20.0}}; + std::vector> targets = {{3, 7}}; + std::vector> mismatches = {{2}}; + std::vector positions = {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}; + + ThreadingInstructions ti(starts, tmrcas, targets, mismatches, positions, 100, 1000); + + auto all_s = ti.all_starts(); + CHECK(all_s.size() == 1); + CHECK(all_s[0] == std::vector{0, 5}); + + auto all_t = ti.all_tmrcas(); + CHECK(all_t.size() == 1); + CHECK(all_t[0][0] == 10.0); + + auto all_tg = ti.all_targets(); + CHECK(all_tg[0] == std::vector{3, 7}); + + auto all_m = ti.all_mismatches(); + CHECK(all_m[0] == std::vector{2}); +} + +TEST_CASE("ThreadingInstructionIterator basic iteration") { + std::vector starts = {100, 500}; + std::vector tmrcas = {10.0, 20.0}; + std::vector targets = {3, 7}; + std::vector mismatches = {2}; // mismatch at position index 2 + + ThreadingInstruction ti(starts, tmrcas, targets, mismatches); + std::vector positions = {100, 200, 300, 400, 500, 600}; + + ThreadingInstructionIterator iter(ti, positions); + CHECK(iter.current_target == 3); + CHECK(iter.current_tmrca == 10.0); + + // Advance past second segment start + iter.increment_site(500); + CHECK(iter.current_target == 7); + CHECK(iter.current_tmrca == 20.0); +} + +TEST_CASE("ThreadingInstructionIterator mismatch tracking") { + std::vector starts = {100}; + std::vector tmrcas = {10.0}; + std::vector targets = {0}; + std::vector mismatches = {2}; // mismatch at site index 2 + + ThreadingInstruction ti(starts, tmrcas, targets, mismatches); + std::vector positions = {100, 200, 300, 400, 500}; + + ThreadingInstructionIterator iter(ti, positions); + + iter.increment_site(100); + CHECK(iter.is_mismatch == false); + + iter.increment_site(200); + CHECK(iter.is_mismatch == false); + + // Position 300 = positions[2] = the mismatch site + iter.increment_site(300); + CHECK(iter.is_mismatch == true); + + iter.increment_site(400); + CHECK(iter.is_mismatch == false); +} diff --git a/test/test_viterbi_lowmem.cpp b/test/test_viterbi_lowmem.cpp new file mode 100644 index 0000000..f777a2f --- /dev/null +++ b/test/test_viterbi_lowmem.cpp @@ -0,0 +1,246 @@ +// This file is part of the Threads software suite. +// Copyright (C) 2024-2025 Threads Developers. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +#include "ViterbiLowMem.hpp" + +#include +#include +#include +#include + +// === ViterbiPath tests === + +TEST_CASE("ViterbiPath construction and basic operations") { + ViterbiPath path(5); + CHECK(path.target_id == 5); + CHECK(path.size() == 0); + + path.append(0, 3); + path.append(10, 7); + CHECK(path.size() == 2); + CHECK(path.segment_starts[0] == 0); + CHECK(path.segment_starts[1] == 10); + CHECK(path.sample_ids[0] == 3); + CHECK(path.sample_ids[1] == 7); +} + +TEST_CASE("ViterbiPath reverse") { + ViterbiPath path(0); + path.append(0, 1); + path.append(5, 2); + path.append(10, 3); + + path.reverse(); + CHECK(path.segment_starts[0] == 10); + CHECK(path.segment_starts[1] == 5); + CHECK(path.segment_starts[2] == 0); + CHECK(path.sample_ids[0] == 3); + CHECK(path.sample_ids[1] == 2); + CHECK(path.sample_ids[2] == 1); +} + +TEST_CASE("ViterbiPath append with height and het_sites") { + ViterbiPath path(0); + std::vector hets1 = {2, 3}; + path.append(0, 1, 100.0, hets1); + std::vector hets2 = {7}; + path.append(5, 2, 200.0, hets2); + + CHECK(path.size() == 2); + CHECK(path.heights[0] == 100.0); + CHECK(path.heights[1] == 200.0); + CHECK(path.het_sites.size() == 3); + CHECK(path.het_sites[0] == 2); + CHECK(path.het_sites[1] == 3); + CHECK(path.het_sites[2] == 7); +} + +TEST_CASE("ViterbiPath append validates ordering") { + ViterbiPath path(0); + std::vector hets = {}; + path.append(10, 1, 100.0, hets); + + // Appending segment_start <= previous should throw + CHECK_THROWS(path.append(10, 2, 200.0, hets)); + CHECK_THROWS(path.append(5, 2, 200.0, hets)); +} + +TEST_CASE("ViterbiPath map_positions") { + ViterbiPath path(0); + path.append(0, 1); + path.append(3, 2); + path.append(7, 3); + + std::vector positions = {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}; + path.map_positions(positions); + + CHECK(path.bp_starts.size() == 3); + CHECK(path.bp_starts[0] == 100); + CHECK(path.bp_starts[1] == 400); + CHECK(path.bp_starts[2] == 800); +} + +TEST_CASE("ViterbiPath dump_data_in_range full") { + ViterbiPath path(0); + std::vector hets1 = {}; + path.append(0, 1, 10.0, hets1); + path.append(5, 2, 20.0, hets1); + path.append(10, 3, 30.0, hets1); + + std::vector positions = {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100}; + path.map_positions(positions); + + auto [starts, ids, heights] = path.dump_data_in_range(-1, -1); + CHECK(starts.size() == 3); + CHECK(ids.size() == 3); + CHECK(heights.size() == 3); +} + +TEST_CASE("ViterbiPath dump_data_in_range subset") { + ViterbiPath path(0); + std::vector hets = {}; + path.append(0, 1, 10.0, hets); + path.append(3, 2, 20.0, hets); + path.append(6, 3, 30.0, hets); + + std::vector positions = {100, 200, 300, 400, 500, 600, 700, 800, 900}; + path.map_positions(positions); + + // Request range that covers only second segment + auto [starts, ids, heights] = path.dump_data_in_range(400, 600); + CHECK(starts.size() >= 1); + // The first returned start should be 400 + CHECK(starts[0] == 400); +} + +// === ViterbiState tests === + +TEST_CASE("ViterbiState construction") { + std::vector samples = {0, 1, 2}; + ViterbiState state(5, samples); + CHECK(state.target_id == 5); + CHECK(state.best_match == 0); + CHECK(state.sites_processed == 0); +} + +TEST_CASE("ViterbiState construction requires non-empty samples") { + std::vector empty_samples; + CHECK_THROWS(ViterbiState(0, empty_samples)); +} + +TEST_CASE("ViterbiState process_site basic") { + // 3 reference samples + target + std::vector samples = {0, 1, 2}; + ViterbiState state(3, samples); + + // Genotype vector: all samples + target + // sample 0=0, sample 1=1, sample 2=0, target=1 + std::vector geno = {0, 1, 0, 1}; + + double rho = 5.0; // recombination penalty + double rho_c = 0.01; // non-recombination penalty + double mu = 3.0; // mutation penalty + double mu_c = 0.01; // non-mutation penalty + + state.process_site(geno, rho, rho_c, mu, mu_c); + CHECK(state.sites_processed == 1); + + // Best match should be sample 1 (matches target allele) + CHECK(state.best_match == 1); +} + +TEST_CASE("ViterbiState process multiple sites and traceback") { + std::vector samples = {0, 1}; + ViterbiState state(2, samples); + + // Process several sites where sample 0 always matches target + for (int i = 0; i < 5; i++) { + std::vector geno = {1, 0, 1}; // sample0=1, sample1=0, target=1 + state.process_site(geno, 5.0, 0.01, 3.0, 0.01); + } + + CHECK(state.sites_processed == 5); + + auto path = state.traceback(); + CHECK(path.target_id == 2); + CHECK(path.size() >= 1); + // Best path should mostly copy sample 0 + CHECK(path.sample_ids[0] == 0); +} + +TEST_CASE("ViterbiState prune reduces branch count") { + std::vector samples = {0, 1, 2, 3}; + ViterbiState state(4, samples); + + // Process enough sites to create branches + for (int i = 0; i < 20; i++) { + std::vector geno; + for (int s = 0; s < 5; s++) { + geno.push_back(i % 2 == 0 ? s % 2 : (s + 1) % 2); + } + state.process_site(geno, 2.0, 0.5, 1.5, 0.1); + } + + int branches_before = state.count_branches(); + state.prune(); + int branches_after = state.count_branches(); + + // Prune should not increase branch count + CHECK(branches_after <= branches_before); + // Should still have at least as many branches as samples + CHECK(branches_after >= static_cast(samples.size())); +} + +TEST_CASE("ViterbiState traceback produces valid path") { + std::vector samples = {0, 1}; + ViterbiState state(2, samples); + + // Alternating genotypes to force recombinations + for (int i = 0; i < 10; i++) { + std::vector geno; + if (i < 5) { + geno = {1, 0, 1}; // target matches sample 0 + } else { + geno = {0, 1, 1}; // target matches sample 1 + } + state.process_site(geno, 1.0, 0.5, 2.0, 0.01); + } + + auto path = state.traceback(); + CHECK(path.size() >= 1); + // Segments should be ordered + for (int i = 1; i < path.size(); i++) { + CHECK(path.segment_starts[i] > path.segment_starts[i - 1]); + } +} + +TEST_CASE("ViterbiState set_samples updates candidate set") { + std::vector samples = {0, 1, 2, 3, 4}; + ViterbiState state(5, samples); + + // Process a few sites + for (int i = 0; i < 3; i++) { + std::vector geno = {1, 0, 1, 0, 1, 1}; + state.process_site(geno, 2.0, 0.5, 1.5, 0.1); + } + + // Reduce to subset + std::unordered_set new_samples = {0, 2}; + state.set_samples(new_samples); + + // Sample_ids should now contain new_samples + best_match + CHECK(state.sample_ids.size() <= 4); // at most new_samples + best_match + margin +}