Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 145 additions & 37 deletions src/Matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
MatchGroup::MatchGroup(int _num_samples, double _cm_position)
: num_samples(_num_samples), cm_position(_cm_position) {
for (int i = 0; i < num_samples; i++) {
// For each sample, we have a mapping match->score assigning a
// matching score to each candidate closest cousin
match_candidates_counts.push_back(std::unordered_map<int, int>());
match_candidates[i] = {};
}
}

Expand All @@ -44,6 +47,38 @@ MatchGroup::MatchGroup(const std::vector<int>& target_ids,
}
}

MatchGroupDifference::MatchGroupDifference(const MatchGroup& prev, const MatchGroup& next, const int _site)
: site(_site) {
// added = std::unordered_map<int, std::unordered_set<int>>();
// removed = std::unordered_map<int, std::unordered_set<int>>();
if (prev.cm_position >= next.cm_position) {
throw std::runtime_error("Match group position out of order");
}
if (prev.num_samples != next.num_samples) {
throw std::runtime_error("Incompatible match groups");
}
for (int i = 0; i < prev.num_samples; i++) {
// "added" is "next_group - prev_group"
std::unordered_set<int> added_i(next.match_candidates.at(i));
for (auto s : prev.match_candidates.at(i)) {
added_i.erase(s);
}

// "removed" is "prev_group - next_group"
std::unordered_set<int> removed_i(prev.match_candidates.at(i));
for (auto s : next.match_candidates.at(i)) {
removed_i.erase(s);
}

if (added_i.size() > 0) {
added[i] = added_i;
}
if (removed_i.size() > 0) {
removed[i] = removed_i;
}
}
}

void MatchGroup::filter_matches(int min_matches) {
// First set the candidates for this group
// For sequences of high index, we lower the number of sequences used to save time and memory
Expand All @@ -56,6 +91,8 @@ void MatchGroup::filter_matches(int min_matches) {
}
else if (i < 1000) {
for (auto counts : match_candidates_counts.at(i)) {
// TODO: should move this min(2, min_matches) to the "i<100" clause
// and have this identical with the 10,000 one
if (counts.second >= std::min(2, min_matches)) {
match_candidates.at(i).insert(counts.first);
}
Expand Down Expand Up @@ -180,11 +217,20 @@ Matcher::Matcher(int _n, const std::vector<double>& _genetic_positions, double _
std::cout << "Will use " << query_sites.size() << " query sites and " << match_group_sites.size()
<< " match_group_sites" << std::endl;

match_groups.reserve(match_group_sites.size());
for (int match_group_site : match_group_sites) {
match_groups.emplace_back(num_samples, genetic_positions[match_group_site]);
}
// once "current_group" has been constructed, we can process "prev_group"
// and compute the match_group_diff
current_group = MatchGroup(num_samples, genetic_positions[0]);
prev_group = MatchGroup(num_samples, genetic_positions[0] - 1);
prevprev_group = MatchGroup(num_samples, genetic_positions[0] - 2);

match_diffs.reserve(match_group_sites.size());

// match_groups.reserve(match_group_sites.size());
// for (int match_group_site : match_group_sites) {
// match_groups.emplace_back(num_samples, genetic_positions[match_group_site]);
// }

// PBWT quantities
sorting.reserve(num_samples);
next_sorting.reserve(num_samples);
permutation.reserve(num_samples);
Expand Down Expand Up @@ -235,8 +281,34 @@ void Matcher::process_site(const std::vector<int>& genotype) {
// Threading-neighbor queries
if (match_group_idx < (static_cast<int>(match_group_sites.size()) - 1) &&
(sites_processed >= match_group_sites.at(match_group_idx + 1))) {
// std::cout << "MATCH GROUP\n";
// std::cout << match_group_idx << "\n";
// Process all matches for this group
// std::cout << "filtering\n";
current_group.filter_matches(min_matches);
// int total_current = 0;
// for (int k = 0; k < num_samples; k++) {
// total_current += current_group.match_candidates.at(k).size();
// }
// std::cout << "TOTAL CURRENT: " << total_current << "\n";

// Share top matches for adjacent groups
if (match_group_idx > 0) {
// std::cout << "tops\n";
prev_group.insert_tops_from(current_group);
current_group.insert_tops_from(prev_group);
// std::cout << "matchdiff\n";
match_diffs.emplace_back(prevprev_group, prev_group, match_group_sites[match_group_idx - 1]);
}

// std::cout << "increment\n";
prevprev_group = prev_group;
prev_group = current_group;
match_group_idx++;
match_groups.at(match_group_idx - 1).filter_matches(min_matches);

// std::cout << "new matchgroup\n";
current_group = MatchGroup(num_samples, genetic_positions[match_group_sites[match_group_idx]]);
// std::cout << "done\n";
}

// If we've reached a query site, query
Expand Down Expand Up @@ -274,7 +346,7 @@ void Matcher::process_site(const std::vector<int>& genotype) {
}
for (int m : matches) {
std::unordered_map<int, int>& mmmap =
match_groups.at(match_group_idx).match_candidates_counts.at(i);
current_group.match_candidates_counts.at(i);
if (m >= i) {
throw std::runtime_error("Illegal match candidate " + std::to_string(m) +
", something is very wrong");
Expand All @@ -290,54 +362,90 @@ void Matcher::process_site(const std::vector<int>& genotype) {

// Special case for last query
if (next_query_site_idx == static_cast<int>(query_sites.size())) {
match_groups.at(match_group_sites.size() - 1).filter_matches(min_matches);
// match_groups.at(match_group_sites.size() - 1).filter_matches(min_matches);
current_group.filter_matches(min_matches);
prev_group.insert_tops_from(current_group);
current_group.insert_tops_from(prev_group);
match_diffs.emplace_back(prevprev_group, prev_group, match_group_sites[match_group_idx - 1]);
match_diffs.emplace_back(prev_group, current_group, match_group_sites[match_group_idx]);
}
}
sites_processed++;
}

// Propagate top 4 matches from left and right match groups
void Matcher::propagate_adjacent_matches() {
for (int i = 1; i < static_cast<int>(match_groups.size()); i++) {
MatchGroup& group = match_groups.at(i);
MatchGroup& prev = match_groups.at(i - 1);
group.insert_tops_from(prev);
prev.insert_tops_from(group);
}
// for (int i = 1; i < static_cast<int>(match_groups.size()); i++) {
// MatchGroup& group = match_groups.at(i);
// MatchGroup& prev = match_groups.at(i - 1);
// group.insert_tops_from(prev);
// prev.insert_tops_from(group);
// }
}

std::vector<MatchGroup> Matcher::get_matches() {
return match_groups;
}
// std::vector<MatchGroup> Matcher::get_matches() {
// return match_groups;
// }

// This returns a list (groups) of lists (targets) of sets (matches)
std::vector<std::vector<std::unordered_set<int>>>
Matcher::serializable_matches(std::vector<int>& target_ids) {
std::vector<std::vector<std::unordered_set<int>>> serialized_matches(match_groups.size());
int group_counter = 0;
for (MatchGroup& match_group : match_groups) {
std::vector<std::unordered_set<int>> current_group_matches(target_ids.size());
int match_counter = 0;
for (int target_id : target_ids) {
current_group_matches[match_counter] = std::move(match_group.match_candidates.at(target_id));
match_group.match_candidates.at(target_id).clear();
match_counter++;
// This returns a list (match-group sites) of lists (probands) of sets (matches)
// std::vector<std::vector<std::unordered_set<int>>>
// Matcher::serializable_matches(std::vector<int>& target_ids) {
// std::vector<std::vector<std::unordered_set<int>>> serialized_matches(match_groups.size());
// int group_counter = 0;
// for (MatchGroup& match_group : match_groups) {
// std::vector<std::unordered_set<int>> current_group_matches(target_ids.size());
// int match_counter = 0;
// for (int target_id : target_ids) {
// current_group_matches[match_counter] = std::move(match_group.match_candidates.at(target_id));
// match_group.match_candidates.at(target_id).clear();
// match_counter++;
// }
// serialized_matches[group_counter] = std::move(current_group_matches);
// group_counter++;
// }
// return serialized_matches;
// }

// This returns a list of uint-quadruples:
// sample_id: the sample this entry refers to
// target_id: the closest cousin candidate
// added/removed: 1/0 depending on type of entry
// cm_idx: index of position of change indexed into the genetic_positions vector
std::vector<std::vector<int>> Matcher::serializable_matches(std::vector<int>& sample_ids) {
std::vector<std::vector<int>> out;
int group_counter;
for (MatchGroupDifference& group_diff : match_diffs) {
int site = group_diff.site;
//
for (int sample_id : sample_ids) {
if (group_diff.added.find(sample_id) != group_diff.added.end()) {
for (auto target_id : group_diff.added[sample_id]) {
// int entry[4] = {sample_id, target_id, 1, site};
std::vector<int> entry = {sample_id, target_id, 1, site};
out.push_back(entry);
}
}
if (group_diff.removed.find(sample_id) != group_diff.removed.end()) {
for (auto target_id : group_diff.removed[sample_id]) {
// int entry[4] = {};
std::vector<int> entry = {sample_id, target_id, 0, site};
out.push_back(entry);
}
}
}
serialized_matches[group_counter] = std::move(current_group_matches);
group_counter++;
}
return serialized_matches;
return out;
}

void Matcher::clear() {
match_groups.clear();
}
// void Matcher::clear() {
// match_groups.clear();
// }

std::vector<double> Matcher::cm_positions() {
std::vector<double> cms;
cms.reserve(match_groups.size());
for (MatchGroup& match_group : match_groups) {
cms.push_back(match_group.cm_position);
cms.reserve(match_diffs.size());
for (MatchGroupDifference& match_diff : match_diffs) {
cms.push_back(genetic_positions[match_diff.site]);
}
return cms;
}
Expand Down
36 changes: 31 additions & 5 deletions src/Matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
/// for a certain interval, store the matches for all samples
class MatchGroup {
public:
MatchGroup() : num_samples(0), cm_position(0.0) {};
MatchGroup(int _num_samples, double cm_position);
MatchGroup(const std::vector<int>& target_ids,
const std::vector<std::unordered_set<int>>& matches, const double _cm_position);
Expand All @@ -39,6 +40,26 @@ class MatchGroup {
double cm_position = 0.0;
};

class MatchGroupDifference {
public:
MatchGroupDifference(const MatchGroup& prev, const MatchGroup& next, const int _site);

public:
std::unordered_map<int, std::unordered_set<int>> added;
std::unordered_map<int, std::unordered_set<int>> removed;
int site = 0;
};

class MatchGroupEntry {
public:
MatchGroupEntry(int _sample_id, int _target_id, int _added, int _site)
: sample_id(_sample_id), target_id(_target_id), added(_added), site(_site) {};
const int sample_id = 0;
const int target_id = 0;
const int added = 0;
const int site = 0;
};

class Matcher {
public:
Matcher(int _n, const std::vector<double>& _genetic_positions, double _query_interval_size,
Expand All @@ -47,11 +68,12 @@ class Matcher {
// Do all the work
void process_site(const std::vector<int>& genotype);
void propagate_adjacent_matches();
void clear();
// void clear();

std::vector<MatchGroup> get_matches();
std::vector<std::vector<std::unordered_set<int>>>
serializable_matches(std::vector<int>& target_ids);
// std::vector<MatchGroup> get_matches();
// std::vector<std::vector<std::unordered_set<int>>>
// serializable_matches(std::vector<int>& target_ids);
std::vector<std::vector<int>> serializable_matches(std::vector<int>& sample_ids);
std::vector<double> cm_positions();

std::vector<int> get_sorting();
Expand All @@ -66,14 +88,18 @@ class Matcher {
std::vector<int> query_sites;
std::vector<int> match_group_sites;
int num_sites = 0;
std::vector<MatchGroupDifference> match_diffs;
// matches in these groups are considered together in the hmm

private:
int min_matches = 0;
int sites_processed = 0;
int next_query_site_idx = 0;
int match_group_idx = 0;
std::vector<MatchGroup> match_groups;
// std::vector<MatchGroup> match_groups;
MatchGroup current_group;
MatchGroup prev_group;
MatchGroup prevprev_group;
std::vector<int> sorting;
std::vector<int> next_sorting;
std::vector<int> permutation;
Expand Down
Loading
Loading