diff --git a/src/threads_arg/__main__.py b/src/threads_arg/__main__.py index a82b38b..9c712ae 100644 --- a/src/threads_arg/__main__.py +++ b/src/threads_arg/__main__.py @@ -41,7 +41,7 @@ def main(): @main.command() @click.option("--pgen", required=True, help="Path to input genotypes in pgen format") -@click.option("--map", required=True, help="Path to genotype map in SHAPEIT format") +@click.option("--map", help="Path to genotype map in SHAPEIT format") @click.option("--recombination_rate", default=1.3e-8, type=float, help="Genome-wide recombination rate. Ignored if a map is passed") @click.option("--demography", required=True, help="Path to input genotype") @click.option("--mode", required=True, type=click.Choice(['array', 'wgs']), default="wgs", help="Inference mode (wgs or array)") @@ -53,7 +53,7 @@ def main(): @click.option("--mutation_rate", required=True, type=float, default=1.4e-8, help="Genome-wide mutation rate") @click.option("--num_threads", type=int, default=1, help="Number of computational threads to request") @click.option("--region", help="Region of genome in chr:start-end format for which ARG is output. The full genotype is still used for inference") -@click.option("--max_sample_batch_size", help="Max number of LS processes run simultaneously per thread", default=None, type=int) +@click.option("--max_sample_batch_size", help="Max number of LS processes run simultaneously per thread", default=None, type=int) @click.option("--save_metadata", is_flag=True, default=False, help="If specified, the output will include sample/variant metadata (sample IDs, marker names, allele symbols, etc).") @click.option("--out") def infer(**kwargs): diff --git a/src/threads_arg/infer.py b/src/threads_arg/infer.py index de2d4f2..88faeac 100644 --- a/src/threads_arg/infer.py +++ b/src/threads_arg/infer.py @@ -81,7 +81,7 @@ def partial_viterbi(pgen, mode, num_samples_hap, physical_positions, genetic_pos else: raise RuntimeError - # Batching here saves a small amount of memory + # Batching here saves a small amount of memory num_samples = len(sample_batch) sample_indices = list(range(num_samples)) num_subsets = int(np.ceil(num_samples / max_sample_batch_size)) @@ -198,7 +198,7 @@ def threads_infer(pgen, map, recombination_rate, demography, mutation_rate, fit_ genetic_positions, physical_positions = make_recombination_from_map_and_pgen(map, pgen, chrom) else: logger.info(f"Using constant recombination rate of {recombination_rate}") - genetic_positions, physical_positions = make_constant_recombination_from_pgen(pgen, recombination_rate, chrom) + genetic_positions, physical_positions = make_constant_recombination_from_pgen(pgen, recombination_rate) # Load/set CHR, POS, ID, REF, ALT, QUAL, FILTER variant_metadata = read_variant_metadata(pgen) if save_metadata else None @@ -361,4 +361,3 @@ def matcher_callback(i, g, mask, matcher): # Save results logger.info(f"Done in (s): {time.time() - start_time}") -