5 changes: 5 additions & 0 deletions .gitignore
@@ -138,3 +138,8 @@ pgatk/testdata/Meleagris_gallopavo*

# Internal working docs (implementation plans, scratch notes)
docs/plans/

# BioPython SeqIO.index_db SQLite indexes — built lazily on first use,
# rebuilt automatically when the source FASTA changes (mtime check).
*.fa.idx
*.fasta.idx
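
These .idx files are a Biopython artefact: SeqIO.index_db writes a SQLite index next to the FASTA as a side effect of indexing. A minimal sketch of where they come from (file name and record id below are illustrative):

from Bio import SeqIO

# Builds transcripts.fa.idx on first call; later opens reuse the SQLite
# file instead of re-scanning the FASTA.
idx = SeqIO.index_db("transcripts.fa.idx", ["transcripts.fa"], "fasta")
record = idx.get("ENST00000000001")  # hypothetical id: random access by key
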
9 changes: 7 additions & 2 deletions pgatk/commands/vcf_to_proteindb.py
@@ -38,14 +38,16 @@
help="enabling this option causes or variants to be parsed. By default only variants that have not failed any filters will be processed (FILTER column is PASS, None, .) or if the filters are subset of the accepted filters. (default is False)",
is_flag=True)
@click.option('--accepted_filters', help="Accepted filters for variant parsing")
@click.option('-w', '--workers', type=int, default=None,
help="Number of worker processes to fan out across (default: 1, sequential). Per-chromosome split.")
@click.pass_context
def vcf_to_proteindb(ctx, config_file, input_fasta, vcf, gene_annotations_gtf, translation_table,
mito_translation_table,
protein_prefix, report_ref_seq, output_proteindb, annotation_field_name,
af_field, af_threshold, transcript_str, biotype_str, exclude_biotypes,
include_biotypes, consequence_str, exclude_consequences,
skip_including_all_cds, include_consequences,
ignore_filters, accepted_filters):
ignore_filters, accepted_filters, workers):

config_data = load_config("ensembl_config", config_file)

@@ -105,5 +107,8 @@ def vcf_to_proteindb(ctx, config_file, input_fasta, vcf, gene_annotations_gtf, t
if accepted_filters is not None:
pipeline_arguments[EnsemblDataService.ACCEPTED_FILTERS] = accepted_filters

if workers is not None:
pipeline_arguments[EnsemblDataService.WORKERS] = workers

ensembl_data_service = EnsemblDataService(config_data, pipeline_arguments)
ensembl_data_service.vcf_to_proteindb(vcf, input_fasta, gene_annotations_gtf)
ensembl_data_service.vcf_to_proteindb(vcf, input_fasta, gene_annotations_gtf, workers=workers)
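
Usage note: with this wiring, an invocation along the lines of pgatk vcf-to-proteindb --vcf variants.vcf --input_fasta transcripts.fa --gene_annotations_gtf annotations.gtf --workers 4 (entry-point name and paths are illustrative) fans the run out across four processes, while omitting --workers keeps the sequential default.
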
185 changes: 178 additions & 7 deletions pgatk/ensembl/ensembl.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import logging
import os
import re
import sqlite3
from pathlib import Path
from typing import Any, Optional
@@ -45,6 +47,77 @@ def put(self, key: tuple, value: tuple) -> None:
_MISSING = object()


def _safe_chrom(chrom: str) -> str:
"""Sanitise a chromosome identifier into a filesystem-safe filename component."""
return re.sub(r'[^A-Za-z0-9._-]', '_', str(chrom))


def _ensure_fasta_index(input_fasta: str) -> str:
"""Return the path to a `.idx` SQLite index for input_fasta, building it
if absent or stale. Stale means: the .idx exists but the FASTA's mtime
is newer than the .idx's. The index is keyed by `EnsemblDataService.get_key`.
"""
idx_path = input_fasta + ".idx"
if os.path.exists(idx_path):
if os.path.getmtime(idx_path) >= os.path.getmtime(input_fasta):
return idx_path
try:
os.remove(idx_path)
except OSError:
pass
# Build the index. SeqIO.index_db materialises the SQLite file on disk
# as a side effect; we don't need to keep the returned handle here.
SeqIO.index_db(idx_path, [input_fasta], "fasta", key_function=EnsemblDataService.get_key)
return idx_path
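
A minimal sketch of the staleness contract (module path and file name assumed for illustration):

import os

from pgatk.ensembl.ensembl import _ensure_fasta_index  # assumed import path

idx = _ensure_fasta_index("transcripts.fa")  # first call: builds transcripts.fa.idx
idx = _ensure_fasta_index("transcripts.fa")  # cheap: mtime says the index is fresh
os.utime("transcripts.fa")                   # FASTA now newer than its index...
idx = _ensure_fasta_index("transcripts.fa")  # ...so this call rebuilds it
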


def _split_vcf_by_chrom(vcf_file: str, output_dir: str) -> dict[str, str]:
"""Stream `vcf_file` once, writing per-chromosome VCFs into `output_dir`.

Each output file is `<output_dir>/chunk_<safe_chrom>.vcf` and contains the
full VCF header followed by only the data lines for that chromosome.
Preserves line ordering within each chromosome.

Returns a mapping `{chrom: chunk_path}`. Constant memory — holds at most
one open file handle per chromosome seen so far.
"""
header: list[str] = []
handles: dict[str, Any] = {}
chunk_paths: dict[str, str] = {}
try:
with open(vcf_file, 'r', encoding='utf-8') as f:
for line in f:
if line.startswith('#'):
header.append(line)
continue
if not line.strip():
continue
chrom = line.split('\t', 1)[0]
handle = handles.get(chrom)
if handle is None:
chunk_path = os.path.join(output_dir, "chunk_" + _safe_chrom(chrom) + ".vcf")
handle = open(chunk_path, 'w', encoding='utf-8')
handle.writelines(header)
handles[chrom] = handle
chunk_paths[chrom] = chunk_path
handle.write(line)
finally:
for h in handles.values():
h.close()
return chunk_paths
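
A minimal usage sketch of the splitter's contract (input path assumed for illustration):

import tempfile

from pgatk.ensembl.ensembl import _split_vcf_by_chrom  # assumed import path

with tempfile.TemporaryDirectory() as tmpdir:
    chunks = _split_vcf_by_chrom("variants.vcf", tmpdir)
    # e.g. {'1': '<tmpdir>/chunk_1.vcf', 'X': '<tmpdir>/chunk_X.vcf'}; every
    # chunk repeats the full header and keeps its data lines in input order.
    for chrom in sorted(chunks):
        print(chrom, chunks[chrom])
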


def _vcf_to_proteindb_worker(default_params, pipeline_args, vcf_file,
input_fasta, gene_annotations_gtf, output_path):
"""Module-level worker for multiprocessing.Pool.starmap.

Re-constructs an EnsemblDataService in the worker process from the
pickled config dict + pipeline args, then runs the chunk pipeline.
"""
svc = EnsemblDataService(default_params, pipeline_args)
svc._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf, output_path)
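
The worker sits at module level because the 'spawn' start method pickles every task into a fresh interpreter: a top-level function plus plain-dict arguments pickles cleanly, whereas a bound method would drag the whole service instance (logger, open handles) through pickle. A self-contained sketch of the pattern, with hypothetical names:

import multiprocessing as mp

def _worker(config, chunk):
    # Rebuild any heavy state from plain, picklable inputs inside the child.
    return config["prefix"] + str(sum(chunk))

if __name__ == "__main__":
    tasks = [({"prefix": "chr1:"}, [1, 2]), ({"prefix": "chr2:"}, [3])]
    with mp.get_context("spawn").Pool(2) as pool:
        print(pool.starmap(_worker, tasks))  # ['chr1:3', 'chr2:3']
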


class EnsemblDataService(ParameterConfiguration):
CONFIG_KEY_VCF = "ensembl_translation"
INPUT_FASTA = "input_fasta"
@@ -73,6 +146,7 @@ class EnsemblDataService(ParameterConfiguration):
EXPRESSION_THRESH = "expression_thresh"
IGNORE_FILTERS = "ignore_filters"
ACCEPTED_FILTERS = "accepted_filters"
WORKERS = "workers"

def __init__(self, config_file: dict, pipeline_arguments: dict) -> None:
"""
@@ -129,6 +203,12 @@ def __init__(self, config_file: dict, pipeline_arguments: dict) -> None:
self._accepted_filters = self.get_multiple_options(
self.get_translation_properties(variable=self.ACCEPTED_FILTERS, default_value='PASS'))

raw_workers = self.get_translation_properties(variable=self.WORKERS, default_value=1)
try:
self._workers = max(1, int(raw_workers))
except (TypeError, ValueError):
self._workers = 1

def get_translation_properties(self, variable: str, default_value: Any) -> Any:
value_return = default_value
if variable in self.get_pipeline_parameters():
@@ -436,7 +516,7 @@ def vcf_from_file(vcf_file: str) -> tuple[list, pd.DataFrame]:

return metadata, vcf_df

def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf: str) -> str:
def _vcf_to_proteindb_chunk(self, vcf_file: str, input_fasta: str, gene_annotations_gtf: str, output_path: str) -> str:
"""
Generate proteins for variants by modifying sequences of affected transcripts.
In case of already annotated variants it only considers variants within
@@ -447,13 +527,17 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf
:param vcf_file:
:param input_fasta:
:param gene_annotations_gtf:
:param output_path: path for writing the output FASTA
:return:
"""
db = self.parse_gtf(gene_annotations_gtf, str(Path(gene_annotations_gtf).with_suffix('.db')))

transcripts_dict = SeqIO.index(input_fasta, "fasta", key_function=self.get_key)
idx_path = _ensure_fasta_index(input_fasta)
transcripts_dict = SeqIO.index_db(idx_path, [input_fasta], "fasta",
key_function=self.get_key)
# handle cases where the transcript has version in the GTF but not in the VCF
transcript_id_mapping = {k.split('.')[0]: k for k in transcripts_dict.keys()}
# Built lazily on the first KeyError to avoid iterating 207k keys up-front.
transcript_id_mapping: Optional[dict[str, str]] = None
feature_cache = _FeatureCache()
# Value is (ref_seq, desc) for a known transcript, or None for a transcript
# we've already looked up and confirmed isn't in the FASTA index (avoids re-trying
@@ -491,6 +575,12 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf
annotation_cols, vcf_file)
self.get_logger().debug(msg)

# Fall back to index 0 when no structured ##INFO header was found (e.g.
# VCFs annotated by annoate_vcf with transcriptOverlaps, which carry
# plain transcript IDs without a pipe-delimited FORMAT declaration).
if transcript_index is None:
transcript_index = 0

else:
# in case the given VCF is not annotated, annotate it by identifying the overlapping transcripts
vcf_file = self.annoate_vcf(vcf_file, gene_annotations_gtf)
@@ -506,7 +596,7 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf

self._accepted_filters = [x.upper() for x in self._accepted_filters]

with open(self._proteindb_output, 'w', encoding='utf-8') as prots_fn:
with open(output_path, 'w', encoding='utf-8') as prots_fn:
for record in vcf_reader.itertuples(index=False, name='VCFRecord'):
trans = False
if [x for x in str(record.REF) if x not in 'ACGT']:
@@ -607,9 +697,16 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf
continue

try:
transcript_id_v = transcript_id_mapping[transcript_id]
except KeyError:
transcript_id_v = transcript_id
transcript_id_v = transcript_id_mapping[transcript_id] # type: ignore[index]
except (KeyError, TypeError):
if transcript_id_mapping is None:
transcript_id_mapping = {k.split('.')[0]: k for k in transcripts_dict.keys()}
try:
transcript_id_v = transcript_id_mapping[transcript_id]
except KeyError:
transcript_id_v = transcript_id
else:
transcript_id_v = transcript_id

cached_row = seq_cache.get(transcript_id_v, _MISSING)
if cached_row is _MISSING:
@@ -726,6 +823,80 @@ def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf
'\n'.join([x + ":" + str(invalid_records[x]) for x in invalid_records.keys()]))
self.get_logger().info(msg)

return output_path

def vcf_to_proteindb(self, vcf_file: str, input_fasta: str, gene_annotations_gtf: str, workers=None) -> str:
"""Generate proteins for variants by modifying sequences of affected transcripts.

If workers is None, falls back to self._workers (config) which defaults
to 1 (sequential, backward-compatible). Pass workers > 1 to fan out per
chromosome via multiprocessing.Pool.
:param vcf_file:
:param input_fasta:
:param gene_annotations_gtf:
:param workers: number of parallel worker processes (None => use config default)
:return: path to the output proteindb FASTA
"""
if workers is None:
workers = self._workers if self._workers else 1

# Fast path: sequential — single call, identical behaviour to the original implementation.
# For sequential runs we do NOT pre-annotate here; _vcf_to_proteindb_chunk handles that
# in its else-branch so that transcript_index=0 is set correctly for unannotated VCFs.
if workers <= 1:
return self._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf, self._proteindb_output)

# Parallel path: pre-annotate unannotated VCFs in the main process. This avoids each
# worker racing on the same bedtools-output bed file (which annoate_vcf writes to cwd)
# and amortises the bedtools intersect across workers.
if not self._annotation_field_name:
vcf_file = self.annoate_vcf(vcf_file, gene_annotations_gtf)
self._annotation_field_name = 'transcriptOverlaps'

# Build the FASTA index ONCE in the main process so all workers share it
# instead of each re-scanning the FASTA (~14 s × N workers wasted).
_ensure_fasta_index(input_fasta)

# Parallel — split into per-chrom temp VCFs, fan out to a Pool.
import multiprocessing as mp
import shutil
import tempfile

with tempfile.TemporaryDirectory(prefix='pgatk_v2p_') as tmpdir:
# Stream-split VCF by chromosome directly to per-chrom files (constant memory).
chunk_paths = _split_vcf_by_chrom(vcf_file, tmpdir)

if len(chunk_paths) <= 1:
# Only one chromosome — run sequentially on the original file.
return self._vcf_to_proteindb_chunk(vcf_file, input_fasta, gene_annotations_gtf,
self._proteindb_output)

tasks = []
for chrom in sorted(chunk_paths.keys()):
chunk_vcf = chunk_paths[chrom]
chunk_out = os.path.join(tmpdir, f"out_{_safe_chrom(chrom)}.fa")
pa = dict(self.get_pipeline_parameters())
pa[EnsemblDataService.PROTEIN_DB_OUTPUT] = chunk_out
# Force annotated mode for workers (annoate_vcf already ran in main):
pa[EnsemblDataService.ANNOTATION_FIELD_NAME] = self._annotation_field_name
tasks.append((self.get_default_parameters(), pa, chunk_vcf,
input_fasta, gene_annotations_gtf, chunk_out))

self.get_logger().info(
"vcf-to-proteindb: dispatching %d chromosome chunk(s) across %d worker(s)",
len(tasks), min(workers, len(tasks)))

with mp.get_context('spawn').Pool(min(workers, len(tasks))) as pool:
pool.starmap(_vcf_to_proteindb_worker, tasks)

# Concatenate the per-chunk FASTAs into the final output.
with open(self._proteindb_output, 'wb') as out:
for task in tasks:
chunk_out = task[5]
if os.path.exists(chunk_out):
with open(chunk_out, 'rb') as f:
shutil.copyfileobj(f, out)

return self._proteindb_output
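
Reduced to its shape, the parallel path is split, fan out, concatenate. A self-contained sketch with a stand-in for the per-chunk pipeline (all names hypothetical):

import multiprocessing as mp
import os
import shutil
import tempfile

def process_chunk(chunk_path, out_path):
    # Stand-in for _vcf_to_proteindb_worker: consume one chunk, emit one FASTA.
    with open(chunk_path, encoding="utf-8") as src, open(out_path, "w", encoding="utf-8") as dst:
        dst.write(src.read().upper())

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        chunks = {}
        for chrom, body in [("1", "acgt\n"), ("2", "ttaa\n")]:  # stand-in chunks
            path = os.path.join(tmpdir, f"chunk_{chrom}.vcf")
            with open(path, "w", encoding="utf-8") as fh:
                fh.write(body)
            chunks[chrom] = path
        tasks = [(chunks[c], os.path.join(tmpdir, f"out_{c}.fa")) for c in sorted(chunks)]
        with mp.get_context("spawn").Pool(min(2, len(tasks))) as pool:
            pool.starmap(process_chunk, tasks)
        with open(os.path.join(tmpdir, "final.fa"), "wb") as out:  # deterministic order
            for _, chunk_out in tasks:
                with open(chunk_out, "rb") as fh:
                    shutil.copyfileobj(fh, out)
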

@staticmethod
64 changes: 64 additions & 0 deletions pgatk/tests/test_vcf_to_proteindb_parallel.py
@@ -0,0 +1,64 @@
"""Equivalence test: vcf_to_proteindb with workers=2 produces the same sequence
content as workers=1 on a small fixture."""
import os
from pathlib import Path

from click.testing import CliRunner

from pgatk.cli import cli


def _sequence_set(fasta_path) -> set:
seqs = []
current = []
with open(fasta_path, 'r', encoding='utf-8') as f:
for line in f:
if line.startswith('>'):
if current:
seqs.append(''.join(current))
current = []
else:
current.append(line.strip())
if current:
seqs.append(''.join(current))
return set(seqs)


def test_parallel_matches_sequential(tmp_path):
# pgatk package root: .../pgatk/pgatk/
pkg_root = Path(__file__).resolve().parents[1]
testdata = pkg_root / 'testdata'
config = pkg_root / 'config' / 'ensembl_config.yaml'
out_seq = tmp_path / 'seq.fa'
out_par = tmp_path / 'par.fa'

common_args = [
'vcf-to-proteindb',
'--config_file', str(config),
'--vcf', str(testdata / 'test.vcf'),
'--input_fasta', str(testdata / 'test.fa'),
'--gene_annotations_gtf', str(testdata / 'test.gtf'),
'--protein_prefix', 'ensvar',
'--af_field', 'MAF',
'--annotation_field_name', 'CSQ',
'--biotype_str', 'feature_type',
'--include_biotypes', 'mRNA,ncRNA',
]

runner = CliRunner()
orig_dir = os.getcwd()
try:
# Run from the package root so the .db file is created next to the .gtf.
os.chdir(str(pkg_root))

r1 = runner.invoke(cli, common_args + ['--output_proteindb', str(out_seq)])
assert r1.exit_code == 0, r1.output + (str(r1.exception) if r1.exception else '')

r2 = runner.invoke(cli, common_args + ['--output_proteindb', str(out_par), '--workers', '2'])
assert r2.exit_code == 0, r2.output + (str(r2.exception) if r2.exception else '')

finally:
os.chdir(orig_dir)

# FASTA header order may differ across runs (per-worker ordering); compare
# sequence content as a set.
assert _sequence_set(out_seq) == _sequence_set(out_par)
