Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ requires-python = ">=3.9"

dependencies = [
"click",
"xarray",
"h5py",
"pandas",
"numba",
"numpy",
"tszip",
"arg-needle-lib==1.2.1",
"cyvcf2",
"ray",
"pgenlib",
"tqdm"
]
Expand Down
2 changes: 1 addition & 1 deletion src/DataConsistency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ ThreadingInstructions ConsistencyWrapper::get_consistent_instructions() {
// Make output threading instructions
std::vector<ThreadingInstruction> output_instructions;
output_instructions.reserve(instruction_converters.size());
for (InstructionConverter converter : instruction_converters) {
for (auto& converter : instruction_converters) {
output_instructions.push_back(converter.parse_converted_instructions());
}

Expand Down
2 changes: 1 addition & 1 deletion src/Demography.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ double Demography::expected_branch_length(const int N) {

/// Stream-insertion operator for Demography.
///
/// Writes one line per entry of `d.sizes` in the form
/// `times[i] sizes[i] std_times[i]` (space-separated, newline-terminated)
/// to `os`, and returns `os` so that insertions can be chained.
std::ostream& operator<<(std::ostream& os, const Demography& d) {
  // NOTE(review): times and std_times are indexed up to sizes.size(); this
  // assumes all three containers have the same length -- confirm.
  for (std::size_t i = 0; i < d.sizes.size(); i++) {
    // Write to the caller-supplied stream only (never std::cout), and use "\n"
    // instead of std::endl so the stream is not flushed after every row.
    os << d.times[i] << " " << d.sizes[i] << " " << d.std_times[i] << "\n";
  }
  return os;
}
54 changes: 31 additions & 23 deletions src/HMM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ HMM::HMM(Demography demography, std::vector<double> bp_sizes, std::vector<double
compute_recombination_scores(cm_sizes);
compute_mutation_scores(bp_sizes, mutation_rate);

// TODO Profile usage of std containers (ticket #25)
trellis.reserve(bp_sizes.size());
pointers.reserve(bp_sizes.size());
for (std::size_t i = 0; i < bp_sizes.size(); i++) {
std::vector<double> trellis_row(num_states, 0.0);
std::vector<unsigned short> pointer_row(num_states, 0);
trellis.push_back(trellis_row);
pointers.push_back(pointer_row);
trellis.emplace_back(num_states, 0.0);
pointers.emplace_back(num_states, 0);
}
}

std::vector<double> HMM::compute_expected_times(Demography demography, const int K) {
std::vector<double> result;
result.reserve(K);
double k = static_cast<double>(num_states);
boost::math::exponential e;

Expand All @@ -50,39 +50,45 @@ std::vector<double> HMM::compute_expected_times(Demography demography, const int
}

/// Fill `transition_score` and `non_transition_score` with one row per entry
/// of `cm_sizes`, each row holding `num_states` log-scores.
///
/// For site i and state k:
///   l            = 2 * 0.01 * cm_sizes[i] * expected_times[k]
///   transition   = log(1 - exp(-l)) - log(num_states)
///   non-trans.   = log(exp(-l) + exp(transition))
///
/// @param cm_sizes per-site sizes; presumably in centimorgans, with the 0.01
///                 factor converting to Morgans -- TODO confirm.
void HMM::compute_recombination_scores(std::vector<double> cm_sizes) {
  // log(num_states) does not depend on i or k, so compute it once.
  const double log_K = std::log(num_states);

  non_transition_score.reserve(cm_sizes.size());
  transition_score.reserve(cm_sizes.size());

  for (std::size_t i = 0; i < cm_sizes.size(); i++) {
    // Build each row at its final size, then move it into place; exactly one
    // row per site is appended to each score table.
    std::vector<double> non_trans(num_states);
    std::vector<double> trans(num_states);
    for (int k = 0; k < num_states; k++) {
      const double l = 2. * 0.01 * cm_sizes[i] * expected_times[k];
      const double t = std::log1p(-std::exp(-l)) - log_K;

      // log-prob of transitioning
      trans[k] = t;

      // log-prob of *not* transitioning
      non_trans[k] = std::log(std::exp(-l) + std::exp(t));
    }
    transition_score.push_back(std::move(trans));
    non_transition_score.push_back(std::move(non_trans));
  }
}

/// Fill `het_score` and `hom_score` with one row per entry of `bp_sizes`,
/// each row holding `num_states` log-scores.
///
/// For site i and state k:
///   l   = 2 * mutation_rate * bp_sizes[i] * expected_times[k]
///   het = log(1 - exp(-l))   (log-prob of mutating)
///   hom = -l                 (log-prob of *not* mutating)
///
/// @param bp_sizes      per-site sizes; presumably in base pairs -- TODO confirm.
/// @param mutation_rate rate multiplied into `l` for every site and state.
void HMM::compute_mutation_scores(std::vector<double> bp_sizes, double mutation_rate) {
  hom_score.reserve(bp_sizes.size());
  het_score.reserve(bp_sizes.size());

  for (std::size_t i = 0; i < bp_sizes.size(); i++) {
    // Build each row at its final size, then move it into place; exactly one
    // row per site is appended to each score table.
    std::vector<double> hom(num_states);
    std::vector<double> het(num_states);
    for (int k = 0; k < num_states; k++) {
      // TODO: use mean-bp sizes here as in the main algorithm
      const double l = 2. * mutation_rate * bp_sizes[i] * expected_times[k];

      // log-prob of mutating
      het[k] = std::log1p(-std::exp(-l));

      // log-prob of *not* mutating
      hom[k] = -l;
    }
    het_score.push_back(std::move(het));
    hom_score.push_back(std::move(hom));
  }
}

Expand All @@ -107,12 +113,14 @@ std::vector<int> HMM::breakpoints(std::vector<bool> observations, int start) {
double score = 0.0;
unsigned short running_argmax = 0;
for (int j = 1; j < neighborhood_size; j++) {
const int js = j + start;
for (int i = 0; i < num_states; i++) {
// Hoist mut_score out of the inner k-loop: it only depends on i, not k
const double mut_score = observations[j] ? het_score[js][i] : hom_score[js][i];
double running_max = 0;
for (int k = 0; k < num_states; k++) {
double mut_score = observations[j] ? het_score[j + start][i] : hom_score[j + start][i];
double rec_score =
k == i ? non_transition_score[j + start][k] : transition_score[j + start][k];
k == i ? non_transition_score[js][k] : transition_score[js][k];

score = trellis[j - 1 + start][k] + rec_score + mut_score;

Expand All @@ -121,8 +129,8 @@ std::vector<int> HMM::breakpoints(std::vector<bool> observations, int start) {
running_argmax = static_cast<unsigned short>(k);
}
}
trellis[j + start][i] = running_max;
pointers[j + start][i] = running_argmax;
trellis[js][i] = running_max;
pointers[js][i] = running_argmax;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/ImputationMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void ImputationMatcher::process_site(const std::vector<bool>& genotype) {
throw std::runtime_error(prompt);
}
}
sorting = next_sorting;
std::swap(sorting, next_sorting);
sites_processed++;
}

Expand Down
90 changes: 53 additions & 37 deletions src/Matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ void MatchGroup::filter_matches(int min_matches) {
}
}
else if (i < 1000) {
for (auto counts : match_candidates_counts.at(i)) {
for (const auto& counts : match_candidates_counts.at(i)) {
if (counts.second >= std::min(2, min_matches)) {
match_candidates.at(i).insert(counts.first);
}
}
}
else if (i < 10000) {
for (auto counts : match_candidates_counts.at(i)) {
for (const auto& counts : match_candidates_counts.at(i)) {
if (counts.second >= min_matches) {
match_candidates.at(i).insert(counts.first);
}
}
}
else {
// Don't want too much stuff for very big studies
for (auto counts : match_candidates_counts.at(i)) {
for (const auto& counts : match_candidates_counts.at(i)) {
if (counts.second >= 2 * min_matches) {
match_candidates.at(i).insert(counts.first);
}
Expand All @@ -80,7 +80,7 @@ void MatchGroup::filter_matches(int min_matches) {
if (match_candidates.at(i).size() == 0) {
int tmp_min_matches = min_matches;
while (match_candidates.at(i).size() == 0 && tmp_min_matches > 0) {
for (auto counts : match_candidates_counts.at(i)) {
for (const auto& counts : match_candidates_counts.at(i)) {
if (counts.second >= tmp_min_matches) {
match_candidates.at(i).insert(counts.first);
}
Expand All @@ -106,7 +106,7 @@ void MatchGroup::filter_matches(int min_matches) {

void MatchGroup::insert_tops_from(MatchGroup& other) {
for (int i = 1; i < num_samples; i++) {
for (auto p : other.top_four_maps.at(i)) {
for (const auto& p : other.top_four_maps.at(i)) {
match_candidates.at(i).insert(p.first);
}
}
Expand Down Expand Up @@ -230,7 +230,7 @@ void Matcher::process_site(const std::vector<int>& genotype) {
throw std::runtime_error(prompt);
}
}
sorting = next_sorting;
std::swap(sorting, next_sorting);

// Threading-neighbor queries
if (match_group_idx < (static_cast<int>(match_group_sites.size()) - 1) &&
Expand All @@ -248,42 +248,41 @@ void Matcher::process_site(const std::vector<int>& genotype) {
}
next_query_site_idx++;

// Initialize the red-black tree
std::set<int> threaded = {permutation.at(0)};
// Boolean array for O(1) mark + sequential scan neighbor finding
std::vector<char> inserted(num_samples, 0);
inserted[permutation[0]] = 1;

// Insert sequences and query in order
for (int i = 1; i < num_samples; i++) {
std::vector<int> matches;
matches.reserve(neighborhood_size);
auto iter = threaded.insert(permutation.at(i));
auto iter_up = iter.first;
auto iter_down = iter.first;
// Check if genotypes are identical, just to be sure
while ((static_cast<int>(matches.size()) < neighborhood_size) &&
(iter_down != threaded.begin() || iter_up != threaded.end())) {
if (iter_down != threaded.begin()) {
iter_down--;
matches.push_back(sorting.at(*iter_down));
}
if (static_cast<int>(matches.size()) < neighborhood_size && iter_up != threaded.end()) {
iter_up++;
if (iter_up != threaded.end()) {
matches.push_back(sorting.at(*iter_up));
const int pos = permutation[i];
inserted[pos] = 1;

// Find neighborhood_size nearest neighbors by scanning left/right
int n_found = 0;
int left = pos - 1;
int right = pos + 1;
std::unordered_map<int, int>& mmmap =
match_groups[match_group_idx].match_candidates_counts[i];
while (n_found < neighborhood_size && (left >= 0 || right < num_samples)) {
// Scan left for next set bit
if (left >= 0) {
while (left >= 0 && !inserted[left]) left--;
if (left >= 0) {
int m = sorting[left];
mmmap[m]++;
n_found++;
left--;
}
}
}
for (int m : matches) {
std::unordered_map<int, int>& mmmap =
match_groups.at(match_group_idx).match_candidates_counts.at(i);
if (m >= i) {
throw std::runtime_error("Illegal match candidate " + std::to_string(m) +
", something is very wrong");
}
if (!mmmap.count(m)) {
mmmap[m] = 1;
}
else {
mmmap[m]++;
// Scan right for next set bit
if (n_found < neighborhood_size && right < num_samples) {
while (right < num_samples && !inserted[right]) right++;
if (right < num_samples) {
int m = sorting[right];
mmmap[m]++;
n_found++;
right++;
}
}
}
}
Expand All @@ -296,6 +295,23 @@ void Matcher::process_site(const std::vector<int>& genotype) {
sites_processed++;
}

/// Feed every row of `genotypes` to process_site(), in order.
///
/// @param genotypes one inner vector per site; each is forwarded unchanged.
void Matcher::process_all_sites(const std::vector<std::vector<int>>& genotypes) {
  for (auto site_it = genotypes.begin(); site_it != genotypes.end(); ++site_it) {
    process_site(*site_it);
  }
}

/// Feed `n_sites` consecutive rows of a flat buffer to process_site().
///
/// Row `s` is read from `data[s * n_haps]` through `data[s * n_haps + n_haps - 1]`.
///
/// @param data    pointer to at least `n_sites * n_haps` values (not checked here).
/// @param n_sites number of rows to process.
/// @param n_haps  number of values per row.
void Matcher::process_all_sites_flat(const int32_t* data, int n_sites, int n_haps) {
  // One scratch row, allocated once and refilled for every site.
  std::vector<int> site_buffer(n_haps);
  for (int site = 0; site < n_sites; ++site) {
    // Widen the index before multiplying so the offset is computed in size_t.
    const int32_t* const row_begin = data + static_cast<std::size_t>(site) * n_haps;
    site_buffer.assign(row_begin, row_begin + n_haps);
    process_site(site_buffer);
  }
}

// Propagate top 4 matches from left and right match groups
void Matcher::propagate_adjacent_matches() {
for (int i = 1; i < static_cast<int>(match_groups.size()); i++) {
Expand Down
3 changes: 3 additions & 0 deletions src/Matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef THREADS_ARG_MATCHER_HPP
#define THREADS_ARG_MATCHER_HPP

#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -46,6 +47,8 @@ class Matcher {

// Do all the work
void process_site(const std::vector<int>& genotype);
void process_all_sites(const std::vector<std::vector<int>>& genotypes);
void process_all_sites_flat(const int32_t* data, int n_sites, int n_haps);
void propagate_adjacent_matches();
void clear();

Expand Down
13 changes: 5 additions & 8 deletions src/State.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,22 @@ void StateBranch::prune() {
}

/// Build the tree by inserting every state into the branch keyed by
/// `s.below->sample_ID`.
///
/// @param states states to distribute across `branches`; each is read once.
StateTree::StateTree(std::vector<State>& states) {
  // Single pass, by const reference: no per-element copy of State and no
  // nested re-iteration over `states`.
  for (const auto& s : states) {
    int sample_ID = s.below->sample_ID;
    // NOTE(review): relies on `branches` being a map-like container whose
    // operator[] default-constructs a missing StateBranch, which makes an
    // explicit find-then-assign unnecessary -- confirm against the header.
    branches[sample_ID].insert(s);
  }
}

void StateTree::prune() {
for (auto pair : branches) {
branches[pair.first].prune();
for (auto& [key, branch] : branches) {
branch.prune();
}
}

std::vector<State> StateTree::dump() const {
std::vector<State> states;
for (auto pair : branches) {
for (auto s : pair.second.states) {
for (const auto& [key, branch] : branches) {
for (const auto& s : branch.states) {
states.push_back(s);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ThreadingInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ std::vector<double> ThreadingInstructions::right_multiply(const std::vector<doub
// Check input vector lengths are correct
if (x.size() != num_sites) {
std::ostringstream oss;
oss << "Input vector must have length " << num_samples / 2 << ".";
oss << "Input vector must have length " << num_sites << ".";
throw std::runtime_error(oss.str());
}

Expand Down
Loading