diff --git a/.gitignore b/.gitignore index c9aa01f..2112a66 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,8 @@ pgatk/testdata/Meleagris_gallopavo* # Internal working docs (implementation plans, scratch notes) docs/plans/ + +# BioPython SeqIO.index_db SQLite indexes — built lazily on first use, +# rebuilt automatically when the source FASTA changes (mtime check). +*.fa.idx +*.fasta.idx diff --git a/pgatk/commands/vcf_to_proteindb.py b/pgatk/commands/vcf_to_proteindb.py index da2d474..7cb2faa 100644 --- a/pgatk/commands/vcf_to_proteindb.py +++ b/pgatk/commands/vcf_to_proteindb.py @@ -38,6 +38,8 @@ help="enabling this option causes or variants to be parsed. By default only variants that have not failed any filters will be processed (FILTER column is PASS, None, .) or if the filters are subset of the accepted filters. (default is False)", is_flag=True) @click.option('--accepted_filters', help="Accepted filters for variant parsing") +@click.option('-w', '--workers', type=int, default=None, + help="Number of worker processes to fan out across (default: 1, sequential). 
def _safe_chrom(chrom: str) -> str:
    """Sanitise a chromosome identifier into a filesystem-safe filename component.

    Every character outside ``[A-Za-z0-9._-]`` is replaced by ``_``.

    NOTE: the mapping is lossy. Distinct identifiers such as ``a:b`` and
    ``a_b`` produce the same result, so callers that need unique file names
    must disambiguate the result themselves.

    :param chrom: chromosome / contig identifier (coerced with ``str``)
    :return: the sanitised identifier
    """
    return re.sub(r'[^A-Za-z0-9._-]', '_', str(chrom))


def _ensure_fasta_index(input_fasta: str) -> str:
    """Return the path of the ``.idx`` SQLite index for ``input_fasta``,
    building it if it is absent or stale.

    Stale means: the ``.idx`` exists but the FASTA's mtime is newer than the
    index's. The index is keyed by ``EnsemblDataService.get_key``.

    :param input_fasta: path to the FASTA file (``str`` or ``os.PathLike``)
    :return: path of the index file (``<input_fasta>.idx``)
    :raises OSError: if a stale index exists but cannot be removed. Continuing
        would make ``SeqIO.index_db`` reload the stale index and serve
        sequences from outdated offsets, so the failure is not swallowed.
    """
    fasta_path = os.fspath(input_fasta)
    idx_path = fasta_path + ".idx"
    if os.path.exists(idx_path):
        if os.path.getmtime(idx_path) >= os.path.getmtime(fasta_path):
            return idx_path
        try:
            os.remove(idx_path)
        except FileNotFoundError:
            # Another process removed the stale index between the exists()
            # check and here; there is nothing left to delete.
            pass
    # Build the index. SeqIO.index_db materialises the SQLite file on disk as
    # a side effect; the returned handle is not needed, so close it right away
    # instead of leaking the SQLite connection.
    index = SeqIO.index_db(idx_path, [fasta_path], "fasta",
                           key_function=EnsemblDataService.get_key)
    index.close()
    return idx_path


def _split_vcf_by_chrom(vcf_file: str, output_dir: str,
                        max_open_handles: int = 256) -> dict[str, str]:
    """Stream ``vcf_file`` once, writing per-chromosome VCFs into ``output_dir``.

    Each output file is ``<output_dir>/chunk_<safe_chrom>.vcf`` and contains
    the header lines seen before the chromosome's first record, followed by
    only the data lines for that chromosome, in their original order.

    If two chromosome names sanitise to the same file name (or differ only by
    case, which collides on case-insensitive file systems), later ones get a
    ``__<n>`` suffix so that every chromosome owns exactly one file.

    At most ``max_open_handles`` chunk files are open at any time. When the
    limit is hit the least recently used handle is closed, and its file is
    reopened in append mode if that chromosome shows up again. This keeps the
    split working on inputs with more contigs than the process may hold open
    file descriptors.

    :param vcf_file: path to an uncompressed, UTF-8 encoded VCF
    :param output_dir: existing directory that receives the chunk files
    :param max_open_handles: upper bound on simultaneously open chunk files
    :return: mapping ``{chrom: chunk_path}``
    :raises ValueError: if ``max_open_handles`` is smaller than 1
    """
    if max_open_handles < 1:
        raise ValueError("max_open_handles must be >= 1")

    header: list[str] = []
    # Open handles ordered from least to most recently used (dicts preserve
    # insertion order, so pop + re-insert moves an entry to the "recent" end).
    handles: dict[str, Any] = {}
    chunk_paths: dict[str, str] = {}
    used_names: set[str] = set()
    current_chrom = None
    current_handle = None
    try:
        with open(vcf_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith('#'):
                    header.append(line)
                    continue
                if not line.strip():
                    continue
                chrom = line.split('\t', 1)[0]
                if chrom != current_chrom:
                    handle = handles.pop(chrom, None)
                    if handle is not None:
                        # Still open: just mark it as most recently used.
                        handles[chrom] = handle
                    else:
                        if len(handles) >= max_open_handles:
                            oldest = next(iter(handles))
                            handles.pop(oldest).close()
                        known_path = chunk_paths.get(chrom)
                        if known_path is None:
                            # First record of this chromosome: pick a file
                            # name nobody else owns, then emit the header.
                            base = _safe_chrom(chrom)
                            name = base
                            suffix = 1
                            while name.lower() in used_names:
                                suffix += 1
                                name = f"{base}__{suffix}"
                            used_names.add(name.lower())
                            known_path = os.path.join(output_dir, "chunk_" + name + ".vcf")
                            handle = open(known_path, 'w', encoding='utf-8')
                            # Register before writing so `finally` closes it
                            # even if the header write fails.
                            handles[chrom] = handle
                            chunk_paths[chrom] = known_path
                            handle.writelines(header)
                        else:
                            # Handle was evicted earlier; continue the file.
                            handle = open(known_path, 'a', encoding='utf-8')
                            handles[chrom] = handle
                    current_chrom = chrom
                    current_handle = handle
                current_handle.write(line)
    finally:
        for h in handles.values():
            h.close()
    return chunk_paths


def _vcf_to_proteindb_worker(default_params, pipeline_args, vcf_file,
                             input_fasta, gene_annotations_gtf, output_path):
    """Module-level worker for ``multiprocessing.Pool.starmap``.

    Re-constructs an ``EnsemblDataService`` in the worker process from the
    pickled config dict + pipeline args, then runs the chunk pipeline. It has
    to live at module level so that it can be pickled by reference.

    :param default_params: configuration dict passed to ``EnsemblDataService``
    :param pipeline_args: pipeline argument dict passed to ``EnsemblDataService``
    :param vcf_file: per-chromosome VCF chunk to process
    :param input_fasta: transcript FASTA
    :param gene_annotations_gtf: GTF annotation file
    :param output_path: FASTA file this worker writes its proteins to
    """
    svc = EnsemblDataService(default_params, pipeline_args)
    svc._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf, output_path)
In case of already annotated variants it only considers variants within @@ -447,13 +527,17 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf :param vcf_file: :param input_fasta: :param gene_annotations_gtf: + :param output_path: path for writing the output FASTA :return: """ db = self.parse_gtf(gene_annotations_gtf, str(Path(gene_annotations_gtf).with_suffix('.db'))) - transcripts_dict = SeqIO.index(input_fasta, "fasta", key_function=self.get_key) + idx_path = _ensure_fasta_index(input_fasta) + transcripts_dict = SeqIO.index_db(idx_path, [input_fasta], "fasta", + key_function=self.get_key) # handle cases where the transcript has version in the GTF but not in the VCF - transcript_id_mapping = {k.split('.')[0]: k for k in transcripts_dict.keys()} + # Built lazily on the first KeyError to avoid iterating 207k keys up-front. + transcript_id_mapping: Optional[dict[str, str]] = None feature_cache = _FeatureCache() # Value is (ref_seq, desc) for a known transcript, or None for a transcript # we've already looked up and confirmed isn't in the FASTA index (avoids re-trying @@ -491,6 +575,12 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf annotation_cols, vcf_file) self.get_logger().debug(msg) + # Fall back to index 0 when no structured ##INFO header was found (e.g. + # transcriptOverlaps-annotated VCFs produced by annoate_vcf, which write + # plain transcript IDs without a pipe-delimited FORMAT declaration). 
+ if transcript_index is None: + transcript_index = 0 + else: # in case the given VCF is not annotated, annotate it by identifying the overlapping transcripts vcf_file = self.annoate_vcf(vcf_file, gene_annotations_gtf) @@ -506,7 +596,7 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf self._accepted_filters = [x.upper() for x in self._accepted_filters] - with open(self._proteindb_output, 'w', encoding='utf-8') as prots_fn: + with open(output_path, 'w', encoding='utf-8') as prots_fn: for record in vcf_reader.itertuples(index=False, name='VCFRecord'): trans = False if [x for x in str(record.REF) if x not in 'ACGT']: @@ -607,9 +697,16 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf continue try: - transcript_id_v = transcript_id_mapping[transcript_id] - except KeyError: - transcript_id_v = transcript_id + transcript_id_v = transcript_id_mapping[transcript_id] # type: ignore[index] + except (KeyError, TypeError): + if transcript_id_mapping is None: + transcript_id_mapping = {k.split('.')[0]: k for k in transcripts_dict.keys()} + try: + transcript_id_v = transcript_id_mapping[transcript_id] + except KeyError: + transcript_id_v = transcript_id + else: + transcript_id_v = transcript_id cached_row = seq_cache.get(transcript_id_v, _MISSING) if cached_row is _MISSING: @@ -726,6 +823,80 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf '\n'.join([x + ":" + str(invalid_records[x]) for x in invalid_records.keys()])) self.get_logger().info(msg) + return output_path + + def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf: str, workers=None) -> str: + """Generate proteins for variants by modifying sequences of affected transcripts. + + If workers is None, falls back to self._workers (config) which defaults + to 1 (sequential, backward-compatible). Pass workers > 1 to fan out per + chromosome via multiprocessing.Pool. 
+ :param vcf_file: + :param input_fasta: + :param gene_annotations_gtf: + :param workers: number of parallel worker processes (None => use config default) + :return: path to the output proteindb FASTA + """ + if workers is None: + workers = self._workers if self._workers else 1 + + # Fast path: sequential — single call, identical behaviour to the original implementation. + # For sequential runs we do NOT pre-annotate here; _vcf_to_proteindb_chunk handles that + # in its else-branch so that transcript_index=0 is set correctly for unannotated VCFs. + if workers <= 1: + return self._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf, self._proteindb_output) + + # Parallel path: pre-annotate unannotated VCFs in the main process. This avoids each + # worker racing on the same bedtools-output bed file (which annoate_vcf writes to cwd) + # and amortises the bedtools intersect across workers. + if not self._annotation_field_name: + vcf_file = self.annoate_vcf(vcf_file, gene_annotations_gtf) + self._annotation_field_name = 'transcriptOverlaps' + + # Build the FASTA index ONCE in the main process so all workers share it + # instead of each re-scanning the FASTA (~14 s × N workers wasted). + _ensure_fasta_index(input_fasta) + + # Parallel — split into per-chrom temp VCFs, fan out to a Pool. + import multiprocessing as mp + import shutil + import tempfile + + with tempfile.TemporaryDirectory(prefix='pgatk_v2p_') as tmpdir: + # Stream-split VCF by chromosome directly to per-chrom files (constant memory). + chunk_paths = _split_vcf_by_chrom(vcf_file, tmpdir) + + if len(chunk_paths) <= 1: + # Only one chromosome — run sequentially on the original file. 
+ return self._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf, + self._proteindb_output) + + tasks = [] + for chrom in sorted(chunk_paths.keys()): + chunk_vcf = chunk_paths[chrom] + chunk_out = os.path.join(tmpdir, f"out_{_safe_chrom(chrom)}.fa") + pa = dict(self.get_pipeline_parameters()) + pa[EnsemblDataService.PROTEIN_DB_OUTPUT] = chunk_out + # Force annotated mode for workers (annoate_vcf already ran in main): + pa[EnsemblDataService.ANNOTATION_FIELD_NAME] = self._annotation_field_name + tasks.append((self.get_default_parameters(), pa, chunk_vcf, + input_fasta, gene_annotations_gtf, chunk_out)) + + self.get_logger().info( + "vcf-to-proteindb: dispatching %d chromosome chunk(s) across %d worker(s)", + len(tasks), min(workers, len(tasks))) + + with mp.get_context('spawn').Pool(min(workers, len(tasks))) as pool: + pool.starmap(_vcf_to_proteindb_worker, tasks) + + # Concatenate the per-chunk FASTAs into the final output. + with open(self._proteindb_output, 'wb') as out: + for task in tasks: + chunk_out = task[5] + if os.path.exists(chunk_out): + with open(chunk_out, 'rb') as f: + shutil.copyfileobj(f, out) + return self._proteindb_output @staticmethod diff --git a/pgatk/tests/test_vcf_to_proteindb_parallel.py b/pgatk/tests/test_vcf_to_proteindb_parallel.py new file mode 100644 index 0000000..c9183ff --- /dev/null +++ b/pgatk/tests/test_vcf_to_proteindb_parallel.py @@ -0,0 +1,64 @@ +"""Equivalence test: vcf_to_proteindb with workers=2 produces the same sequence +content as workers=1 on a small fixture.""" +import os +from pathlib import Path + +from click.testing import CliRunner + +from pgatk.cli import cli + + +def _sequence_set(fasta_path) -> set: + seqs = [] + current = [] + with open(fasta_path, 'r', encoding='utf-8') as f: + for line in f: + if line.startswith('>'): + if current: + seqs.append(''.join(current)) + current = [] + else: + current.append(line.strip()) + if current: + seqs.append(''.join(current)) + return set(seqs) + + +def 
test_parallel_matches_sequential(tmp_path): + # pgatk package root: .../pgatk/pgatk/ + pkg_root = Path(__file__).resolve().parents[1] + testdata = pkg_root / 'testdata' + config = pkg_root / 'config' / 'ensembl_config.yaml' + out_seq = tmp_path / 'seq.fa' + out_par = tmp_path / 'par.fa' + + common_args = [ + 'vcf-to-proteindb', + '--config_file', str(config), + '--vcf', str(testdata / 'test.vcf'), + '--input_fasta', str(testdata / 'test.fa'), + '--gene_annotations_gtf', str(testdata / 'test.gtf'), + '--protein_prefix', 'ensvar', + '--af_field', 'MAF', + '--annotation_field_name', 'CSQ', + '--biotype_str', 'feature_type', + '--include_biotypes', 'mRNA,ncRNA', + ] + + runner = CliRunner() + orig_dir = os.getcwd() + try: + # Run from the package root so the .db file is created next to the .gtf. + os.chdir(str(pkg_root)) + + r1 = runner.invoke(cli, common_args + ['--output_proteindb', str(out_seq)]) + assert r1.exit_code == 0, r1.output + (str(r1.exception) if r1.exception else '') + + r2 = runner.invoke(cli, common_args + ['--output_proteindb', str(out_par), '--workers', '2']) + assert r2.exit_code == 0, r2.output + (str(r2.exception) if r2.exception else '') + finally: + os.chdir(orig_dir) + + # FASTA header order may differ across runs (per-worker ordering); compare + # sequence content as a set. + assert _sequence_set(out_seq) == _sequence_set(out_par) diff --git a/scripts/benchmark_vcf_to_proteindb.py b/scripts/benchmark_vcf_to_proteindb.py index bf17700..8901565 100755 --- a/scripts/benchmark_vcf_to_proteindb.py +++ b/scripts/benchmark_vcf_to_proteindb.py @@ -1,18 +1,27 @@ #!/usr/bin/env python3 -"""One-shot benchmark for ``pgatk vcf-to-proteindb``. +"""One-shot benchmark / profiler for ``pgatk vcf-to-proteindb``. -Times one full run end-to-end. Compare two runs of this script (one on a -pre-fix build, one on a post-fix build) to measure the speedup achieved -by issue #99. Not invoked by CI. +Times one full run end-to-end. 
With ``--profile-out PATH``, also wraps the +run in cProfile and writes a ``.prof`` file; print the top-N hotspots with +``--print-top``. Not invoked by CI. -Usage: +Basic usage: python scripts/benchmark_vcf_to_proteindb.py \ --vcf /path/to/sample.vcf \ --fasta /path/to/transcripts.fa \ --gtf /path/to/annotation.gtf \ --output /tmp/benchmark_proteindb.fa + +Profile a chr22 run and print the 30 hottest functions: + python scripts/benchmark_vcf_to_proteindb.py \ + --vcf chr22.vcf --fasta tx.fa --gtf chr22.gtf \ + --output /tmp/chr22.fa \ + --profile-out /tmp/chr22.prof --print-top 30 """ import argparse +import cProfile +import io +import pstats import time from pathlib import Path @@ -32,6 +41,17 @@ def main() -> None: help="VCF INFO field that carries per-transcript annotation (default: CSQ). " "Use empty string to force re-annotation via bedtools intersect.", ) + parser.add_argument( + "--profile-out", + help="If set, wrap the run in cProfile and write the .prof file here.", + ) + parser.add_argument( + "--print-top", + type=int, + default=0, + help="With --profile-out, also print the top-N functions by cumulative time " + "(default: 0, no printing). 
Recommended N=30 for a first look.", + ) args = parser.parse_args() config_data = load_config("ensembl_config", None) @@ -41,10 +61,20 @@ def main() -> None: } svc = EnsemblDataService(config_data, pipeline_arguments) + profiler = cProfile.Profile() if args.profile_out else None start = time.perf_counter() - svc.vcf_to_proteindb(args.vcf, args.fasta, args.gtf) + if profiler is not None: + profiler.enable() + try: + svc.vcf_to_proteindb(args.vcf, args.fasta, args.gtf) + finally: + if profiler is not None: + profiler.disable() elapsed = time.perf_counter() - start + if profiler is not None: + profiler.dump_stats(args.profile_out) + output_path = Path(args.output) output_size = output_path.stat().st_size if output_path.exists() else 0 seq_count = 0 @@ -59,6 +89,21 @@ def main() -> None: print(f" VCF: {args.vcf}") print(f" Elapsed: {elapsed:.2f} s ({elapsed / 60:.2f} min)") print(f" Output: {args.output} ({output_size} bytes, {seq_count} sequences)") + if args.profile_out: + print(f" Profile: {args.profile_out}") + + if args.profile_out and args.print_top > 0: + print() + print(f"=== TOP {args.print_top} BY CUMULATIVE TIME ===") + buf = io.StringIO() + stats = pstats.Stats(args.profile_out, stream=buf) + stats.strip_dirs().sort_stats('cumulative').print_stats(args.print_top) + print(buf.getvalue()) + print(f"=== TOP {args.print_top} BY OWN (tottime) TIME ===") + buf = io.StringIO() + stats = pstats.Stats(args.profile_out, stream=buf) + stats.strip_dirs().sort_stats('tottime').print_stats(args.print_top) + print(buf.getvalue()) if __name__ == "__main__":