From 54af29f106fe9a1c61c9c933efe8cd61c785ccb4 Mon Sep 17 00:00:00 2001 From: Paolo Cozzi Date: Tue, 25 Nov 2025 17:52:08 +0100 Subject: [PATCH] relax requirement for map file and fix function call Makes the map file optional and calculates genetic positions from the provided recombination rate if no map is provided. Also fix make_constant_recombination_from_pgen call related to #117 --- src/threads_arg/__main__.py | 4 ++-- src/threads_arg/infer.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/threads_arg/__main__.py b/src/threads_arg/__main__.py index a82b38b..9c712ae 100644 --- a/src/threads_arg/__main__.py +++ b/src/threads_arg/__main__.py @@ -41,7 +41,7 @@ def main(): @main.command() @click.option("--pgen", required=True, help="Path to input genotypes in pgen format") -@click.option("--map", required=True, help="Path to genotype map in SHAPEIT format") +@click.option("--map", help="Path to genotype map in SHAPEIT format") @click.option("--recombination_rate", default=1.3e-8, type=float, help="Genome-wide recombination rate. Ignored if a map is passed") @click.option("--demography", required=True, help="Path to input genotype") @click.option("--mode", required=True, type=click.Choice(['array', 'wgs']), default="wgs", help="Inference mode (wgs or array)") @@ -53,7 +53,7 @@ def main(): @click.option("--mutation_rate", required=True, type=float, default=1.4e-8, help="Genome-wide mutation rate") @click.option("--num_threads", type=int, default=1, help="Number of computational threads to request") @click.option("--region", help="Region of genome in chr:start-end format for which ARG is output. The full genotype is still used for inference") -@click.option("--max_sample_batch_size", help="Max number of LS processes run simultaneously per thread", default=None, type=int) +@click.option("--max_sample_batch_size", help="Max number of LS processes run simultaneously per thread", default=None, type=int) @click.option("--save_metadata", is_flag=True, default=False, help="If specified, the output will include sample/variant metadata (sample IDs, marker names, allele symbols, etc).") @click.option("--out") def infer(**kwargs): diff --git a/src/threads_arg/infer.py b/src/threads_arg/infer.py index de2d4f2..88faeac 100644 --- a/src/threads_arg/infer.py +++ b/src/threads_arg/infer.py @@ -81,7 +81,7 @@ def partial_viterbi(pgen, mode, num_samples_hap, physical_positions, genetic_pos else: raise RuntimeError - # Batching here saves a small amount of memory + # Batching here saves a small amount of memory num_samples = len(sample_batch) sample_indices = list(range(num_samples)) num_subsets = int(np.ceil(num_samples / max_sample_batch_size)) @@ -198,7 +198,7 @@ def threads_infer(pgen, map, recombination_rate, demography, mutation_rate, fit_ genetic_positions, physical_positions = make_recombination_from_map_and_pgen(map, pgen, chrom) else: logger.info(f"Using constant recombination rate of {recombination_rate}") - genetic_positions, physical_positions = make_constant_recombination_from_pgen(pgen, recombination_rate, chrom) + genetic_positions, physical_positions = make_constant_recombination_from_pgen(pgen, recombination_rate) # Load/set CHR, POS, ID, REF, ALT, QUAL, FILTER variant_metadata = read_variant_metadata(pgen) if save_metadata else None @@ -361,4 +361,3 @@ def matcher_callback(i, g, mask, matcher): # Save results logger.info(f"Done in (s): {time.time() - start_time}") -