35 changes: 35 additions & 0 deletions scripts/slicing_tests.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
import os
import sys
import time

pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) # noqa
sys.path.insert(0, pkg_root) # noqa

from xsamtools import cram # noqa


cram_crai_pairs = [
('gs://lons-test/NWD938777.b38.irc.v1.cram', 'gs://lons-test/NWD938777.b38.irc.v1.cram.crai'),
('gs://lons-test/ce#5b.cram', 'gs://lons-test/ce#5b.cram.crai')
]

uris = []
for cram_gs_path, crai_gs_path in cram_crai_pairs:
for slicing_bool in [True]:
if cram_gs_path == 'gs://lons-test/ce#5b.cram':
regions = 'CHROMOSOME_I', 'CHROMOSOME_II', 'CHROMOSOME_I:100,CHROMOSOME_II', 'CHROMOSOME_IV', None
elif cram_gs_path == 'gs://lons-test/NWD938777.b38.irc.v1.cram':
regions = 'chr1', 'chr2', 'chr1:100,chr2', 'chr23', None
else:
print(f'Add the regions for this cram: {cram_gs_path}. Skipping... ')
regions = []

for region in regions:
print(f'Now running: {cram_gs_path} {crai_gs_path} {region} w/slicing={slicing_bool}')
start = time.time()
cram_output = cram.view(cram=cram_gs_path, crai=crai_gs_path, regions=region, cram_format=True,
slicing=slicing_bool)
end = time.time()
print(f'Timing was {(end - start) / 60} minutes.')
print('=' * 40)
4 changes: 2 additions & 2 deletions tests/test_cram.py
@@ -164,15 +164,15 @@ def cram_view_with_regions(self, cram_uri, crai_uri, regions):

def run_cram_view_api_with_regions(self, cram, crai):
for region in self.regions:
with self.subTest(f'xsamtools view cram file:// {region}'):
with self.subTest(f'xsamtools view cram {cram} {crai} {region}'):
stdout, stderr = self.cram_view_with_regions(cram, crai, regions=region)
self.assertEqual(stdout, self.regions[region]['expected_output'])

# these don't change the output with the normal samtools command?
# we still need to test them but...
# TODO: make a better test for these
for subregion in ['1', '10', '3-100']:
with self.subTest(f'xsamtools view cram file:// {region}:{subregion}'):
with self.subTest(f'xsamtools view cram {cram} {crai} {region}:{subregion}'):
stdout, stderr = self.cram_view_with_regions(cram, crai, regions=f'{region}:{subregion}')
self.assertEqual(stdout, self.regions[region]['expected_output'])

10 changes: 9 additions & 1 deletion xsamtools/cli/cram.py
@@ -18,6 +18,13 @@
required=False,
help="Input crai file. This can be a Google Storage file (e.g. gs://bucket/key) or a local file. "
"If not specified, one will be generated for you (this may take a long time)."),
"--no-slicing": dict(action='store_true',
dest='no_slicing',
required=False,
default=False,
help="When acting upon cloud cram files, this will download whole files rather than "
"attempting to only download the 'slices' of those files that corresponding to "
"the --regions argument (slicing saves I/O and therefore time)."),
# TODO: add an argument to intake a BED file.
"--regions": dict(type=str,
required=False,
@@ -37,4 +44,5 @@ def view(args: argparse.Namespace):
"""
A limited wrapper around "samtools view", but with functions to operate on google cloud bucket keys.
"""
cram.view(cram=args.cram, crai=args.crai, regions=args.regions, output=args.output, cram_format=args.C)
cram.view(cram=args.cram, crai=args.crai, regions=args.regions, output=args.output, cram_format=args.C,
slicing=not args.no_slicing)
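
A sketch of the new flag from the shell, assuming an "xsamtools cram view" entry point (the CLI dispatch and the other flag definitions sit outside this diff, so the exact spellings here are assumptions); bucket and key names are placeholders:

xsamtools cram view --cram gs://bucket/sample.cram --crai gs://bucket/sample.cram.crai \
    --regions chr1,chr2 -C --output out.cram --no-slicing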
186 changes: 177 additions & 9 deletions xsamtools/cram.py
@@ -4,6 +4,7 @@
CRAM/CRAI spec here:
http://samtools.github.io/hts-specs/CRAMv3.pdf
"""
import copy
import os
import datetime
import logging
@@ -12,16 +13,25 @@

from collections import namedtuple
from tempfile import TemporaryDirectory
from typing import Optional, Dict, Any, Union
from typing import Optional, Dict, Any, Union, Tuple, List
from urllib.request import urlretrieve
from terra_notebook_utils import xprofile

from xsamtools import gs_utils
from xsamtools.utils import run

try:
from gzip import BadGzipFile # Only on newer versions (py3.8+)
except ImportError:
BadGzipFile = OSError # type: ignore

CramLocation = namedtuple("CramLocation", "chr alignment_start alignment_span offset slice_offset slice_size")
log = logging.getLogger(__name__)


class SeqMapError(Exception):
pass

def read_fixed_length_cram_file_definition(fh: io.BytesIO) -> Dict[str, Union[int, str]]:
"""
This definition is always the first 26 bytes of a cram file.
@@ -97,6 +107,95 @@ def read_cram_container_header(fh: io.BytesIO) -> Dict[str, Any]:
"crc_hash": fh.read(4)
}

def read_seq_names_from_sam_header(fh: io.BytesIO, block_size: int) -> Tuple[Dict[bytes, int], int]:
"""
Every CRAM file contains a SAM header positioned after the first "container block header" which
is either raw or gzip compressed.

This function reads the SAM header and maps the human readable sequence names in it to their
numerical identifiers. The CRAM index file only points to blocks using the numerical identifiers,
so this allows us to take a user input of, for example, "chr1,chr2" and map which blocks contain
the numerical identifiers of those sequences in the CRAM index file.

    Note: @SQ (sequence) lines appear in order; each has a required SN (Sequence Name) value and
          optionally AN (Alternative Name) values. These are the human readable names. Numerical
          identifiers are assigned in order, starting from 1 (i.e. 1, 2, 3, ...), one per @SQ line.

See SAM spec for detailed format: http://samtools.github.io/hts-specs/SAMv1.pdf
"""
sequence = {}
total_seq_identifiers = 0
handle = gzip.GzipFile(fileobj=fh) if is_gzipped_block(fh, block_size) else fh
try:
for line in handle:
if b'@SQ' in line:
for tag in line.split(b'\t'):
if tag == b'@SQ':
total_seq_identifiers += 1
elif tag[:2] in [b'SN', b'AN']:
tag_value = tag[3:]
assert tag_value not in sequence
sequence[tag_value] = total_seq_identifiers
if fh.tell() > block_size:
return sequence, total_seq_identifiers
except BadGzipFile:
return sequence, total_seq_identifiers
return sequence, total_seq_identifiers # it's unlikely we get to here
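
# A sketch with hypothetical data, exercising the walk above:
#
#   toy = io.BytesIO(b'@SQ\tSN:chr1\tLN:248956422\n'
#                    b'@SQ\tSN:chr2\tAN:2\tLN:242193529\n')
#   read_seq_names_from_sam_header(toy, block_size=len(toy.getvalue()))
#   -> ({b'chr1': 1, b'chr2': 2, b'2': 2}, 2)
#
# Both the primary (SN) and alternative (AN) names map to the same 1-indexed identifier.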


def is_gzipped_block(fh: io.BytesIO, block_size: int) -> bool:
"""
Loops through a search space number of bytes equal to the block_size to see
if there is a gzip marker or raw b'@SQ' marker.

Will return True if a gzip marker b'\x1f\x8b' is found first.
Will return False if a raw b'@SQ' marker is found first.
Will error if neither is found.

Return True if a gzip marker is found and seek fh to the index that b'\x1f\x8b' begins at,
otherwise False and seek fh to the index that b'@SQ' begins at.
"""
c1 = fh.read(1)
bytes_read = 1
while bytes_read < block_size:
c2 = fh.read(1)
bytes_read += 1
if c1 == b'\x1f' and c2 == b'\x8b':
fh.seek(-2, 1)
return True
elif c1 == b'@' and c2 == b'S':
if fh.read(1) == b'Q':
fh.seek(-3, 1)
return False
else:
fh.seek(-1, 1) # go back one to undo the extra byte we just read
else:
c1 = copy.copy(c2)
    # we did not see a gzip marker or an @SQ flag within block_size bytes.
raise SeqMapError('SAM header block markers not found.')
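
# A sketch of the two markers the scan distinguishes (hypothetical buffers):
#
#   raw = io.BytesIO(b'\x00\x01@SQ\tSN:chr1\n')
#   zipped = io.BytesIO(b'\x00\x01' + gzip.compress(b'@SQ\tSN:chr1\n'))
#   is_gzipped_block(raw, block_size=16)     -> False, with raw seeked to the b'@SQ'
#   is_gzipped_block(zipped, block_size=16)  -> True, with zipped seeked to the b'\x1f\x8b'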

def get_seq_map(cram: str, crai_indices: List[CramLocation]) -> Dict[bytes, int]:
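    # crai_indices[2].offset is presumably a generous upper bound on the header region: the file
    # definition and SAM header end before the first data container, hence before the third crai entry.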
block_size = crai_indices[2].offset

# download the cram header contents
blob = gs_utils._blob_for_url(cram)
fh = io.BytesIO(blob.download_as_bytes(start=0, end=block_size, raw_download=False, checksum=None))

read_fixed_length_cram_file_definition(fh)
read_cram_container_header(fh)
# reading the above two should put us pretty close to the start of the SAM header
# TODO: Find out why this is not exactly the index location of the SAM header... sigh
seq_map, total_seq_identifiers = read_seq_names_from_sam_header(fh, block_size=block_size)
if total_seq_identifiers > len(crai_indices) or total_seq_identifiers == 0:
# TODO: crai_indices may have duplicates; make this exact (==) ^
        raise SeqMapError(f'Something went wrong reading the cram header (expected 0 < total_seq_identifiers <= len(crai_indices)).\n'
f'total_seq_identifiers: {total_seq_identifiers}\n'
f'seq_map: {seq_map}\n'
f'len(crai_indices): {len(crai_indices)}\n'
f'{crai_indices}\n')
return seq_map

def decode_int32(fh: io.BytesIO) -> int:
"""A CRAM defined 32-bit signed integer type."""
return int.from_bytes(fh.read(4), byteorder='little', signed=True)
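
# Worked examples of the little-endian signed decode above:
#   decode_int32(io.BytesIO(b'\x1a\x00\x00\x00')) == 26
#   decode_int32(io.BytesIO(b'\xff\xff\xff\xff')) == -1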
@@ -344,11 +443,11 @@ def decode_itf8_array(handle: io.BytesIO, size: Optional[int] = None):
size = decode_itf8(handle)
return [decode_itf8(handle) for _ in range(size)]
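
# ITF-8 stores values below 0x80 in a single byte, so a small array decodes as, e.g.:
#   decode_itf8_array(io.BytesIO(b'\x03\x01\x02\x7f')) -> [1, 2, 127]
# (the leading byte is the array size, decoded with the same ITF-8 routine)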

def get_crai_indices(crai):
def get_crai_indices(crai: str) -> List[CramLocation]:
crai_indices = []
with open(crai, "rb") as fh:
with gzip.GzipFile(fileobj=fh) as gzip_reader:
with io.TextIOWrapper(gzip_reader, encoding='ascii') as reader:
with io.TextIOWrapper(gzip_reader, encoding='ascii') as reader: # type: ignore
for line in reader:
crai_indices.append(CramLocation(*[int(d) for d in line.split("\t")]))
return crai_indices
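
# Each decompressed crai line holds six tab-separated integers. A single hypothetical line,
#   b'24\t100\t500\t9000\t367\t4096\n'
# parses to CramLocation(chr=24, alignment_start=100, alignment_span=500,
#                        offset=9000, slice_offset=367, slice_size=4096).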
@@ -359,7 +458,64 @@ def download_full_gs(gs_path: str, output_filename: str = None) -> str:
output_filename = output_filename if output_filename else os.path.abspath(os.path.basename(key_name))
blob = gs_utils._blob_for_url(gs_path)
blob.download_to_filename(output_filename)
log.debug(f'Entire file "{gs_path}" downloaded to: {output_filename}')
log.info(f'Entire file {gs_path} downloaded to: {output_filename}')
return output_filename

def ordered_slices_from_seq_identifiers(seq_identifiers: List[int],
crai_indices: List[CramLocation]) -> List[Tuple[int, Optional[int]]]:
"""
    crai_indices is a list of all blocks in a cram file. Each entry names the sequence its block
    holds ("chr") and records that block's start position in the file as an absolute index ("offset").

Given a list of these seq_identifiers, return a list of tuples representing absolute start and end indices for
blocks we're interested in.

Note: The first and last block are always included.
"""
slices: List[Tuple[int, Optional[int]]] = []
slice_start = 0
for crai_line in crai_indices:
if not slices or crai_line.chr in seq_identifiers or crai_line.chr == -1:
slices.append((slice_start, crai_line.offset))
slice_start = crai_line.offset

# always include the last block of the file, and make sure it always ends in "None"
slices.append((slice_start, None)) # "None" signals a read to the very end of the file
# TODO: Join like ends together to optimize
return slices
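
# A hand-trace with hypothetical offsets: given crai entries with (chr, offset) of
# (1, 600), (2, 4000) and (3, 9000), and seq_identifiers=[2], the loop appends (0, 600)
# for the first entry and (600, 4000) for the chr-2 match, and the final append adds
# (4000, None), yielding [(0, 600), (600, 4000), (4000, None)].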

def download_sliced_cram(cram_gs_path: str,
crai_local_path: str,
regions: str,
                         output_filename: Optional[str] = None) -> None:
"""
    Given a cram, crai, and a set of regions, download only the sections of that cram file that correspond
    to the first block, all relevant region blocks, and the final block, and write them into an output file.
"""
crai_indices = get_crai_indices(crai_local_path)
seq_map = get_seq_map(cram_gs_path, crai_indices)

sequence_integer_references = []
for region in regions.split(','):
seq_name = region.split(':')[0].encode('utf-8')
if seq_name in seq_map:
sequence_integer_references.append(seq_map[seq_name])

ordered_slices = ordered_slices_from_seq_identifiers(sequence_integer_references, crai_indices)
download_sliced_gs(cram_gs_path, ordered_slices, output_filename)
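
# Hypothetical usage (bucket and local paths are placeholders):
#   download_sliced_cram('gs://bucket/sample.cram', '/tmp/sample.crai',
#                        regions='chr1,chr2:100', output_filename='/tmp/sliced.cram')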

def download_sliced_gs(gs_path: str, ordered_slices: List[Tuple[int, Optional[int]]], output_filename: Optional[str] = None) -> str:
"""Given a gs:// path, download only referenced slices, and write them into output_filename."""
# TODO: use gs_chunked_io instead
output_filename = output_filename if output_filename else gs_path[len('gs://'):].split('/', 1)[-1]
blob = gs_utils._blob_for_url(gs_path)
with open(output_filename, "wb") as f:
for start, end in ordered_slices:
# TODO: google raises google.resumable_media.common.DataCorruption when checksumming (erroneously?)
new_string = blob.download_as_bytes(start=start, end=end, raw_download=False, checksum=None)
f.seek(start)
f.write(new_string)
log.info(f'Sliced file "{gs_path}" downloaded to: {output_filename}')
return output_filename
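
# Note: writing each slice at its original absolute offset (via f.seek) keeps the byte positions
# recorded in the crai valid in the sliced copy; the unwritten gaps between slices read back as zeros.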

def format_and_check_cram(cram: str) -> str:
Expand All @@ -381,15 +537,14 @@ def write_final_file_with_samtools(cram: str,
if crai:
crai_arg = f'-X {crai}'
else:
log.warning('No crai file present, this may take a while.')
crai_arg = ''

# we can get away with a simple split on spaces here because there's nothing complicated going on
cmd = f'samtools view {cram_format_arg} {cram} {crai_arg} {region_args}'.split()

log.info(f'Now running: {cmd}')
run(cmd, stdout=open(output, 'w'), check=True)
log.debug(f'Output CRAM successfully generated at: {output}')
log.info(f'Output CRAM successfully generated at: {output}')

def stage(uri: str, output: str) -> None:
"""
@@ -421,21 +576,34 @@ def view(cram: str,
crai: Optional[str],
regions: Optional[str],
output: Optional[str] = None,
cram_format: bool = True) -> str:
cram_format: bool = True,
slicing: bool = True) -> str:
"""A limited wrapper around "samtools view", but with functions to operate on google cloud bucket keys."""
output = output or timestamped_filename(cram_format)
output = output[len('file://'):] if output.startswith('file://') else output
assert ':' not in output, f'Unsupported schema for output: "{output}".\n' \
f'Only local file outputs are currently supported.'

with TemporaryDirectory() as staging_dir:
staged_cram = os.path.join(staging_dir, 'tmp.cram')
stage(uri=cram, output=staged_cram)
if crai:
staged_crai = os.path.join(staging_dir, 'tmp.crai')
stage(uri=crai, output=staged_crai)
else:
log.warning('No crai file specified, this may take a long long time.')
staged_crai = None

staged_cram = os.path.join(staging_dir, 'tmp.cram')
if cram.startswith('gs://') and regions and slicing and staged_crai:
# try:
download_sliced_cram(cram, staged_crai, regions, output_filename=staged_cram)
# except Exception as e:
# log.warning(f'Slicing failed with:'
# f'\n{e}\n'
# f'Now making attempt without slicing.')
# stage(uri=cram, output=staged_cram)
else:
stage(uri=cram, output=staged_cram)

write_final_file_with_samtools(staged_cram, staged_crai, regions, cram_format, output)

return output
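
# For reference, the scripts/slicing_tests.py harness above drives this API roughly as:
#
#   from xsamtools import cram
#   cram.view(cram='gs://lons-test/NWD938777.b38.irc.v1.cram',
#             crai='gs://lons-test/NWD938777.b38.irc.v1.cram.crai',
#             regions='chr1', cram_format=True, slicing=True)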