diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md
index cf5f86a..7443e53 100644
--- a/RELEASE_NOTES.md
+++ b/RELEASE_NOTES.md
@@ -2,6 +2,10 @@
## [Unreleased]
+### Added
+
+- Add left_multiplication and right_multiplication to ThreadingInstructions (#99)
+
### Changed
- Build wheels on macOS 14 for arm64 and macOS 15 for x86_64 (#108)
diff --git a/src/ThreadingInstructions.cpp b/src/ThreadingInstructions.cpp
index 121fde0..4ba5be4 100644
--- a/src/ThreadingInstructions.cpp
+++ b/src/ThreadingInstructions.cpp
@@ -15,7 +15,9 @@
// along with this program. If not, see .
#include "ThreadingInstructions.hpp"
+#include "GenotypeIterator.hpp"
+#include
#include
#include
#include
@@ -259,3 +261,169 @@ ThreadingInstructions ThreadingInstructions::sub_range(const int range_start, co
std::move(range_positions)
};
}
+
+std::vector ThreadingInstructions::left_multiply(const std::vector& x, bool diploid, bool normalize) {
+ // Left-multiplication of the genotype matrix by a vector of doubles
+
+ // Check input vector lengths are correct
+ if (diploid) {
+ if (x.size() != num_samples / 2) {
+ std::ostringstream oss;
+ oss << "Input vector must have length " << num_samples / 2 << ".";
+ throw std::runtime_error(oss.str());
+ }
+ } else {
+ if (x.size() != num_samples) {
+ std::ostringstream oss;
+ oss << "Input vector must have length " << num_samples << ".";
+ throw std::runtime_error(oss.str());
+ }
+ }
+
+ // Initialize genotype traversal
+ GenotypeIterator gi = GenotypeIterator(*this);
+ std::size_t site_counter = 0;
+ std::vector out(num_sites);
+
+ while (gi.has_next_genotype()) {
+ // Fetch the next genotype
+ const std::vector& g = gi.next_genotype();
+
+ // Initialize the next entry
+ double entry = 0.0;
+
+ if (normalize) {
+ // If we want to normalize, we need the mean and standard deviation of g.
+ double ac = 0.0;
+ for (auto a : g) {
+ ac += a;
+ }
+ if (diploid) {
+ // We do the diploid standard deviation by hand
+ double mu = 2.0 * ac / num_samples;
+ double sample_var = 0.0;
+ for (std::size_t i=0; i < x.size(); i++) {
+ int h = g[2 * i] + g[2 * i + 1];
+ double d = h - mu;
+ sample_var += d * d;
+ }
+ sample_var /= (num_samples / 2);
+
+ double std = std::sqrt(sample_var);
+ for (std::size_t i=0; i < x.size(); i++) {
+ int h = g[2 * i] + g[2 * i + 1];
+ double w = x[i];
+ entry += w * (h - mu) / std;
+ }
+ } else {
+ double mu = ac / num_samples;
+ double std = std::sqrt(mu * (1 - mu));
+ for (std::size_t i=0; i < g.size(); i++) {
+ double w = x[i];
+ entry += w * (g[i] - mu) / std;
+ }
+ }
+ } else {
+ for (std::size_t i=0; i < g.size(); i++) {
+ double w = diploid ? x[i / 2] : x[i];
+ entry += w * g[i];
+ }
+ }
+ out[site_counter] = entry;
+ site_counter++;
+ }
+ return out;
+}
+
+std::vector ThreadingInstructions::right_multiply(const std::vector& x, bool diploid, bool normalize) {
+ // Right-multiplication of the genotype matrix by a vector of doubles
+
+ // Check input vector lengths are correct
+ if (x.size() != num_sites) {
+ std::ostringstream oss;
+ oss << "Input vector must have length " << num_samples / 2 << ".";
+ throw std::runtime_error(oss.str());
+ }
+
+ GenotypeIterator gi = GenotypeIterator(*this);
+ std::size_t site_counter = 0;
+ if (diploid) {
+ // Initialize output
+ std::vector out(num_samples / 2, 0.0);
+ if (normalize) {
+ while (gi.has_next_genotype()) {
+ // Fetch the next genotype
+ const std::vector& g = gi.next_genotype();
+
+ // If we want to normalize, we need the mean and standard deviation of g.
+ double ac = 0.0;
+ for (auto a : g) {
+ ac += a;
+ }
+
+ // We do the diploid standard deviation by hand
+ const double mu = 2.0 * ac / num_samples;
+ double sample_var = 0.0;
+ for (std::size_t i=0; i < out.size(); i++) {
+ int h = g[2 * i] + g[2 * i + 1];
+ double d = h - mu;
+ sample_var += d * d;
+ }
+ sample_var /= (num_samples / 2);
+ const double std = std::sqrt(sample_var);
+
+ const double w = x[site_counter] / std;
+ for (std::size_t i=0; i < out.size(); i++) {
+ const int h = g[2 * i] + g[2 * i + 1];
+ out[i] += w * (h - mu);
+ }
+ site_counter++;
+ }
+ } else {
+ while (gi.has_next_genotype()) {
+ // Fetch the next genotype
+ const std::vector& g = gi.next_genotype();
+ const double w = x[site_counter];
+ for (std::size_t i=0; i < out.size(); i++) {
+ const int h = g[2 * i] + g[2 * i + 1];
+ out[i] += w * h;
+ }
+ site_counter++;
+ }
+ }
+ return out;
+ } else {
+ // Initialize output
+ std::vector out(num_samples, 0.0);
+ if (normalize) {
+ while (gi.has_next_genotype()) {
+ // Fetch the next genotype
+ const std::vector& g = gi.next_genotype();
+ double ac = 0.0;
+ for (auto a : g) {
+ ac += a;
+ }
+
+ // Normalization constants
+ double mu = ac / num_samples;
+ double std = std::sqrt(mu * (1 - mu));
+ const double w = x[site_counter] / std;
+ for (std::size_t i=0; i < out.size(); i++) {
+ out[i] += w * (g[i] - mu);
+ }
+ site_counter++;
+ }
+ } else {
+ while (gi.has_next_genotype()) {
+ // Fetch the next genotype
+ const std::vector& g = gi.next_genotype();
+ const double w = x[site_counter];
+ for (std::size_t i=0; i < out.size(); i++) {
+ out[i] += w * g[i];
+ }
+ site_counter++;
+ }
+ }
+ return out;
+ }
+}
diff --git a/src/ThreadingInstructions.hpp b/src/ThreadingInstructions.hpp
index ae0136b..f404be2 100644
--- a/src/ThreadingInstructions.hpp
+++ b/src/ThreadingInstructions.hpp
@@ -85,6 +85,10 @@ class ThreadingInstructions {
ThreadingInstructions sub_range(const int range_start, const int range_end) const;
+ // Common operations
+ std::vector left_multiply(const std::vector& x, bool diploid=false, bool normalize=false);
+ std::vector right_multiply(const std::vector& x, bool diploid=false, bool normalize=false);
+
public:
int start = 0;
int end = 0;
diff --git a/src/threads_arg_pybind.cpp b/src/threads_arg_pybind.cpp
index b8ca12d..c829a20 100644
--- a/src/threads_arg_pybind.cpp
+++ b/src/threads_arg_pybind.cpp
@@ -140,7 +140,9 @@ PYBIND11_MODULE(threads_arg_python_bindings, m) {
.def("sub_range", &ThreadingInstructions::sub_range)
.def(py::pickle(
&threading_instructions_get_state,
- &threading_instructions_set_state));
+ &threading_instructions_set_state))
+ .def("left_multiply", &ThreadingInstructions::left_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false)
+ .def("right_multiply", &ThreadingInstructions::right_multiply, py::arg("x"), py::arg("diploid") = false, py::arg("normalize") = false);
py::class_(m, "ConsistencyWrapper")
.def(py::init>&, const std::vector>&, const std::vector>&,
diff --git a/test/test_multiply.py b/test/test_multiply.py
new file mode 100644
index 0000000..9b2e1e5
--- /dev/null
+++ b/test/test_multiply.py
@@ -0,0 +1,117 @@
+# This file is part of the Threads software suite.
+# Copyright (C) 2024-2025 Threads Developers.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import numpy as np
+import pgenlib
+import pytest
+
+from threads_arg.serialization import load_instructions
+
+from snapshot_runners import (
+ TEST_DATA_DIR
+)
+
+def _col_normalize(x):
+ z = x.copy()
+ mu = z.mean(axis=0, keepdims=True)
+ std = z.std(axis=0, keepdims=True)
+ return (z - mu) / std
+
+def test_left_multiply():
+ # Read ground truth genotypes
+ pgen_path = str(TEST_DATA_DIR / "panel.pgen")
+ reader = pgenlib.PgenReader(str(pgen_path).encode())
+ expected_num_variants = reader.get_variant_ct()
+ num_samples = reader.get_raw_sample_ct()
+ expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32)
+ reader.read_alleles_range(0, expected_num_variants, expected_gt)
+ gt_matrix = expected_gt.transpose()
+ gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2]
+ gt_matrix_norm = _col_normalize(gt_matrix)
+ gt_matrix_dip_norm = _col_normalize(gt_matrix_dip)
+
+ # Read threading instructions
+ threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads")
+ instructions = load_instructions(threads_path)
+
+ # Random vector to multiply with
+ rng = np.random.default_rng(130222)
+ x_hap = rng.normal(0, 1, 2 * num_samples)
+ x_dip = rng.normal(0, 1, num_samples)
+
+ # Make sure length checks are performed
+ with pytest.raises(RuntimeError):
+ instructions.left_multiply(x_dip)
+ with pytest.raises(RuntimeError):
+ instructions.left_multiply(x_hap, diploid=True)
+
+ # Do normal left-multiplication
+ expected = x_hap @ gt_matrix
+ expected_norm = x_hap @ gt_matrix_norm
+ expected_dip = x_dip @ gt_matrix_dip
+ expected_dip_norm = x_dip @ gt_matrix_dip_norm
+
+ # Do threads left-multiplication and confirm results are correct
+ found = instructions.left_multiply(x_hap)
+ assert np.allclose(expected, found)
+ found_norm = instructions.left_multiply(x_hap, normalize=True)
+ assert np.allclose(expected_norm, found_norm)
+ found_dip = instructions.left_multiply(x_dip, diploid=True)
+ assert np.allclose(expected_dip, found_dip)
+ found_dip_norm = instructions.left_multiply(x_dip, normalize=True, diploid=True)
+ assert np.allclose(expected_dip_norm, found_dip_norm)
+
+def test_right_multiply():
+ # Read ground truth genotypes
+ pgen_path = str(TEST_DATA_DIR / "panel.pgen")
+ reader = pgenlib.PgenReader(str(pgen_path).encode())
+ expected_num_variants = reader.get_variant_ct()
+ num_samples = reader.get_raw_sample_ct()
+ expected_gt = np.empty((expected_num_variants, 2 * num_samples), dtype=np.int32)
+ reader.read_alleles_range(0, expected_num_variants, expected_gt)
+ gt_matrix = expected_gt.transpose()
+ gt_matrix_dip = gt_matrix[::2] + gt_matrix[1::2]
+ gt_matrix_norm = _col_normalize(gt_matrix)
+ gt_matrix_dip_norm = _col_normalize(gt_matrix_dip)
+
+ # Read threading instructions
+ threads_path = str(TEST_DATA_DIR / "expected_infer_snapshot.threads")
+ instructions = load_instructions(threads_path)
+
+ # Random vector to multiply with
+ rng = np.random.default_rng(130222)
+ x = rng.normal(0, 1, expected_num_variants)
+ x_wrong_length = rng.normal(0, 1, expected_num_variants + 1)
+
+ # Make sure length check is performed
+ with pytest.raises(RuntimeError):
+ instructions.left_multiply(x_wrong_length)
+
+ # Do normal right-multiplication
+ expected = gt_matrix @ x
+ expected_norm = gt_matrix_norm @ x
+ expected_dip = gt_matrix_dip @ x
+ expected_dip_norm = gt_matrix_dip_norm @ x
+
+ # Do threads right-multiplication and confirm results are correct
+ found = instructions.right_multiply(x)
+ assert np.allclose(expected, found)
+ found_norm = instructions.right_multiply(x, normalize=True)
+ assert np.allclose(expected_norm, found_norm)
+ found_dip = instructions.right_multiply(x, diploid=True)
+ assert np.allclose(expected_dip, found_dip)
+ found_dip_norm = instructions.right_multiply(x, normalize=True, diploid=True)
+ assert np.allclose(expected_dip_norm, found_dip_norm)