diff --git a/scripts/slicing_tests.py b/scripts/slicing_tests.py
new file mode 100755
index 0000000..53d43b2
--- /dev/null
+++ b/scripts/slicing_tests.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+"""Times xsamtools cram.view() against test crams in Google Storage, one region at a time."""
+import os
+import sys
+import time
+
+pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))  # noqa
+sys.path.insert(0, pkg_root)  # noqa
+
+from xsamtools import cram  # noqa
+
+
+# (cram, crai, regions to run); a region of None is passed through to cram.view() as regions=None
+test_crams = [
+    ('gs://lons-test/NWD938777.b38.irc.v1.cram',
+     'gs://lons-test/NWD938777.b38.irc.v1.cram.crai',
+     ('chr1', 'chr2', 'chr1:100,chr2', 'chr23', None)),
+    ('gs://lons-test/ce#5b.cram',
+     'gs://lons-test/ce#5b.cram.crai',
+     ('CHROMOSOME_I', 'CHROMOSOME_II', 'CHROMOSOME_I:100,CHROMOSOME_II', 'CHROMOSOME_IV', None)),
+]
+
+
+def main() -> None:
+    for cram_gs_path, crai_gs_path, regions in test_crams:
+        for slicing_bool in [True]:
+            for region in regions:
+                print(f'Now running: {cram_gs_path} {crai_gs_path} {region} w/slicing={slicing_bool}')
+                start = time.perf_counter()
+                cram.view(cram=cram_gs_path, crai=crai_gs_path, regions=region, cram_format=True,
+                          slicing=slicing_bool)
+                elapsed = time.perf_counter() - start
+                print(f'Timing was {elapsed / 60} minutes.')
+                print('=' * 40)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test_cram.py b/tests/test_cram.py
index 6b96f35..e262027 100755
--- a/tests/test_cram.py
+++ b/tests/test_cram.py
@@ -164,7 +164,7 @@ def cram_view_with_regions(self, cram_uri, crai_uri, regions):
 
     def run_cram_view_api_with_regions(self, cram, crai):
         for region in self.regions:
-            with self.subTest(f'xsamtools view cram file:// {region}'):
+            with self.subTest(f'xsamtools view cram {cram} {crai} {region}'):
                 stdout, stderr = self.cram_view_with_regions(cram, crai, regions=region)
                 self.assertEqual(stdout, self.regions[region]['expected_output'])
 
@@ -172,7 +172,7 @@ def run_cram_view_api_with_regions(self, cram, crai):
             # we still need to test them but...
             # TODO: make a better test for these
             for subregion in ['1', '10', '3-100']:
-                with self.subTest(f'xsamtools view cram file:// {region}:{subregion}'):
+                with self.subTest(f'xsamtools view cram {cram} {crai} {region}:{subregion}'):
                     stdout, stderr = self.cram_view_with_regions(cram, crai, regions=f'{region}:{subregion}')
                     self.assertEqual(stdout, self.regions[region]['expected_output'])
 
diff --git a/xsamtools/cli/cram.py b/xsamtools/cli/cram.py
index e756efa..3bac193 100755
--- a/xsamtools/cli/cram.py
+++ b/xsamtools/cli/cram.py
@@ -18,6 +18,13 @@
                    required=False,
                    help="Input crai file. This can be a Google Storage file (e.g. gs://bucket/key) or a local file. "
                         "If not specified, one will be generated for you (this may take a long time)."),
+    "--no-slicing": dict(action='store_true',
+                         dest='no_slicing',
+                         required=False,
+                         default=False,
+                         help="When acting upon cloud cram files, this will download whole files rather than "
+                              "attempting to only download the 'slices' of those files that correspond to "
+                              "the --regions argument (slicing saves I/O and therefore time)."),
     # TODO: add an argument to intake a BED file.
     "--regions": dict(type=str,
                       required=False,
@@ -37,4 +44,5 @@ def view(args: argparse.Namespace):
     """
     A limited wrapper around "samtools view", but with functions to operate on google cloud bucket keys.
     """
-    cram.view(cram=args.cram, crai=args.crai, regions=args.regions, output=args.output, cram_format=args.C)
+    cram.view(cram=args.cram, crai=args.crai, regions=args.regions, output=args.output, cram_format=args.C,
+              slicing=not args.no_slicing)
diff --git a/xsamtools/cram.py b/xsamtools/cram.py
index 89fca28..c55d66a 100755
--- a/xsamtools/cram.py
+++ b/xsamtools/cram.py
@@ -12,16 +12,25 @@
 
 from collections import namedtuple
 from tempfile import TemporaryDirectory
-from typing import Optional, Dict, Any, Union
+from typing import Optional, Dict, Any, Union, Tuple, List
 from urllib.request import urlretrieve
 from terra_notebook_utils import xprofile
 
 from xsamtools import gs_utils
 from xsamtools.utils import run
 
+try:
+    from gzip import BadGzipFile  # Only on newer versions (py3.8+)
+except ImportError:
+    BadGzipFile = OSError  # type: ignore
+
 CramLocation = namedtuple("CramLocation", "chr alignment_start alignment_span offset slice_offset slice_size")
 log = logging.getLogger(__name__)
 
+
+class SeqMapError(Exception):
+    """Raised when the sequence names in a cram's SAM header cannot be mapped to crai identifiers."""
+
 def read_fixed_length_cram_file_definition(fh: io.BytesIO) -> Dict[str, Union[int, str]]:
     """
     This definition is always the first 26 bytes of a cram file.
@@ -97,6 +106,95 @@ def read_cram_container_header(fh: io.BytesIO) -> Dict[str, Any]:
         "crc_hash": fh.read(4)
     }
 
+def read_seq_names_from_sam_header(fh, block_size: int) -> Tuple[Dict[bytes, int], int]:
+    """
+    Every CRAM file contains a SAM header positioned after the first "container block header" which
+    is either raw or gzip compressed.
+
+    This function reads the SAM header and maps the human readable sequence names in it to their
+    numerical identifiers. The CRAM index file only points to blocks using the numerical identifiers,
+    so this allows us to take a user input of, for example, "chr1,chr2" and map which blocks contain
+    the numerical identifiers of those sequences in the CRAM index file.
+
+    Note: Every @SQ (sequence) line is ordered and has a required SN (Sequence Name) value and possibly
+          an optional AN (Alternate Names) value, which is a comma-separated list. These are the human
+          readable names. The numerical identifier is the zero-based position of the @SQ line in the
+          header (0, 1, 2... etc.), matching the reference sequence ids used by CRAM and CRAI.
+
+    Returns the name -> identifier map along with the number of @SQ lines that were read.
+    Raises SeqMapError if one name is assigned to two different sequences.
+
+    See SAM spec for detailed format: http://samtools.github.io/hts-specs/SAMv1.pdf
+    """
+    sequence: Dict[bytes, int] = {}
+    total_seq_identifiers = 0
+    handle = gzip.GzipFile(fileobj=fh) if is_gzipped_block(fh, block_size) else fh
+    try:
+        for line in handle:
+            if b'@SQ' in line:
+                # strip the line ending so that it never ends up inside of the last tag's value
+                for tag in line.rstrip(b'\r\n').split(b'\t'):
+                    if tag == b'@SQ':
+                        total_seq_identifiers += 1
+                    elif tag[:3] in (b'SN:', b'AN:') and total_seq_identifiers:
+                        seq_identifier = total_seq_identifiers - 1  # identifiers are zero-based
+                        names = tag[3:].split(b',') if tag[:3] == b'AN:' else [tag[3:]]
+                        for name in names:
+                            if sequence.setdefault(name, seq_identifier) != seq_identifier:
+                                raise SeqMapError(f'Sequence name {name!r} is used by more than one @SQ line.')
+            if fh.tell() > block_size:
+                return sequence, total_seq_identifiers
+    except BadGzipFile:
+        # presumably raised once gzip reads past the compressed header into non-gzip data; keep what was parsed
+        return sequence, total_seq_identifiers
+    return sequence, total_seq_identifiers  # it's unlikely we get to here
+
+
+def is_gzipped_block(fh: io.BytesIO, block_size: int) -> bool:
+    """
+    Searches the next block_size bytes of fh to see whether a gzip marker (bytes 1f 8b) or a raw
+    b'@SQ' marker comes first.
+
+    Will return True if the gzip marker is found first, with fh seeked to the index that it begins at.
+    Will return False if the raw b'@SQ' marker is found first, with fh seeked to the index that it begins at.
+    Will raise SeqMapError if neither is found.
+    """
+    search_start = fh.tell()
+    search_space = fh.read(block_size)
+    gzip_index = search_space.find(b'\x1f\x8b')
+    raw_index = search_space.find(b'@SQ')
+    if gzip_index == -1 and raw_index == -1:
+        raise SeqMapError('SAM header block markers not found.')
+
+    is_gzipped = raw_index == -1 or -1 < gzip_index < raw_index
+    fh.seek(search_start + (gzip_index if is_gzipped else raw_index))
+    return is_gzipped
+
+def get_seq_map(cram: str, crai_indices: List[CramLocation]) -> Dict[bytes, int]:
+    """
+    Downloads the beginning of a gs:// cram file and returns a map of the sequence names in its SAM header
+    to the numerical sequence identifiers that the crai file refers to.
+
+    Raises SeqMapError if the crai file has no entries, or if no @SQ lines are found in the header.
+    """
+    if not crai_indices:
+        raise SeqMapError('The crai file does not contain any index entries.')
+    # read up to the offset of the third index entry, or of the last entry for shorter index files
+    block_size = crai_indices[min(2, len(crai_indices) - 1)].offset
+
+    # download the cram header contents
+    blob = gs_utils._blob_for_url(cram)
+    fh = io.BytesIO(blob.download_as_bytes(start=0, end=block_size, raw_download=False, checksum=None))
+
+    read_fixed_length_cram_file_definition(fh)
+    read_cram_container_header(fh)
+    # reading the above two should put us pretty close to the start of the SAM header
+    # TODO: Find out why this is not exactly the index location of the SAM header... sigh
+    seq_map, total_seq_identifiers = read_seq_names_from_sam_header(fh, block_size=block_size)
+    if total_seq_identifiers == 0:
+        raise SeqMapError(f'No @SQ lines were found while reading the cram header of: {cram}')
+    return seq_map
+
 def decode_int32(fh: io.BytesIO) -> int:
     """A CRAM defined 32-bit signed integer type."""
     return int.from_bytes(fh.read(4), byteorder='little', signed=True)
@@ -344,11 +442,11 @@ def decode_itf8_array(handle: io.BytesIO, size: Optional[int] = None):
         size = decode_itf8(handle)
     return [decode_itf8(handle) for _ in range(size)]
 
-def get_crai_indices(crai):
+def get_crai_indices(crai: str) -> List[CramLocation]:
     crai_indices = []
     with open(crai, "rb") as fh:
         with gzip.GzipFile(fileobj=fh) as gzip_reader:
-            with io.TextIOWrapper(gzip_reader, encoding='ascii') as reader:
+            with io.TextIOWrapper(gzip_reader, encoding='ascii') as reader:  # type: ignore
                 for line in reader:
                     crai_indices.append(CramLocation(*[int(d) for d in line.split("\t")]))
     return crai_indices
@@ -359,7 +457,93 @@ def download_full_gs(gs_path: str, output_filename: str = None) -> str:
     output_filename = output_filename if output_filename else os.path.abspath(os.path.basename(key_name))
     blob = gs_utils._blob_for_url(gs_path)
     blob.download_to_filename(output_filename)
-    log.debug(f'Entire file "{gs_path}" downloaded to: {output_filename}')
+    log.info(f'Entire file "{gs_path}" downloaded to: {output_filename}')
+    return output_filename
+
+def ordered_slices_from_seq_identifiers(seq_identifiers: List[int],
+                                        crai_indices: List[CramLocation]) -> List[Tuple[int, Optional[int]]]:
+    """
+    crai_indices is a list of all blocks in a cram file. Each block's contents are referenced by a seq_identifier
+    ("chr"), and contain that block's start position in the file as an absolute index ("offset").
+
+    Given a list of these seq_identifiers, return an ordered list of (start, end) byte ranges covering the blocks
+    we're interested in. A block starts at its own offset and ends where the next block starts, so "end" is
+    exclusive. An "end" of None signals a read to the very end of the file.
+
+    Note: The first block (everything before the first indexed block) and the last block (the final indexed
+          block through to the end of the file) are always included, as are blocks of unmapped reads (-1).
+          Ranges that touch one another are joined together.
+    """
+    offsets = sorted({crai_line.offset for crai_line in crai_indices})
+    if not offsets:
+        return [(0, None)]  # nothing is indexed, so the only safe slice is the whole file
+
+    wanted_identifiers = set(seq_identifiers)
+    wanted_offsets = {crai_line.offset for crai_line in crai_indices
+                      if crai_line.chr == -1 or crai_line.chr in wanted_identifiers}
+
+    byte_ranges: List[Tuple[int, Optional[int]]] = [(0, offsets[0])]
+    byte_ranges += [(start, end) for start, end in zip(offsets, offsets[1:]) if start in wanted_offsets]
+    byte_ranges.append((offsets[-1], None))
+
+    slices: List[Tuple[int, Optional[int]]] = []
+    for start, end in byte_ranges:
+        if slices and slices[-1][1] == start:
+            slices[-1] = (slices[-1][0], end)  # this range continues the previous one
+        else:
+            slices.append((start, end))
+    return slices
+
+def download_sliced_cram(cram_gs_path: str,
+                         crai_local_path: str,
+                         regions: str,
+                         output_filename: Optional[str] = None) -> str:
+    """
+    Given a cram, crai, and a set of regions, download only the sections of that cram file that correspond to
+    the first block, all relevant region blocks, & the final block, and write them into an output file.
+
+    Returns the path of the output file. Raises SeqMapError if the cram header cannot be mapped.
+    """
+    crai_indices = get_crai_indices(crai_local_path)
+    seq_map = get_seq_map(cram_gs_path, crai_indices)
+
+    sequence_integer_references = []
+    for region in regions.split(','):
+        # sequence names may themselves contain colons, so only split off a range when the full name is unknown
+        seq_name = region.encode('utf-8')
+        if seq_name not in seq_map:
+            seq_name = region.rsplit(':', 1)[0].encode('utf-8')
+        if seq_name in seq_map:
+            sequence_integer_references.append(seq_map[seq_name])
+        else:
+            log.warning(f'Region "{region}" does not match a sequence name in the header of: {cram_gs_path}')
+
+    ordered_slices = ordered_slices_from_seq_identifiers(sequence_integer_references, crai_indices)
+    return download_sliced_gs(cram_gs_path, ordered_slices, output_filename)
+
+def download_sliced_gs(gs_path: str,
+                       ordered_slices: List[Tuple[int, Optional[int]]],
+                       output_filename: Optional[str] = None) -> str:
+    """
+    Given a gs:// path, download only referenced slices, and write them into output_filename.
+
+    Each slice is a (start, end) byte range where "end" is exclusive, or None to read to the end of the file.
+    Every slice is written at its original position in the file, so skipped regions are left as gaps.
+    """
+    # TODO: use gs_chunked_io instead
+    output_filename = output_filename if output_filename else os.path.abspath(os.path.basename(gs_path))
+    blob = gs_utils._blob_for_url(gs_path)
+    with open(output_filename, "wb") as f:
+        for start, end in ordered_slices:
+            if end is not None and end <= start:
+                continue  # an empty range; there is nothing to download
+            # download_as_bytes() treats "end" as the last byte to fetch (inclusive)
+            last_byte = None if end is None else end - 1
+            # TODO: google raises google.resumable_media.common.DataCorruption when checksumming (erroneously?)
+            slice_bytes = blob.download_as_bytes(start=start, end=last_byte, raw_download=False, checksum=None)
+            f.seek(start)
+            f.write(slice_bytes)
+    log.info(f'Sliced file "{gs_path}" downloaded to: {output_filename}')
     return output_filename
 
 def format_and_check_cram(cram: str) -> str:
@@ -381,7 +565,6 @@ def write_final_file_with_samtools(cram: str,
     if crai:
         crai_arg = f'-X {crai}'
     else:
-        log.warning('No crai file present, this may take a while.')
         crai_arg = ''
 
     # we can get away with a simple split on spaces here because there's nothing complicated going on
@@ -389,7 +572,7 @@ def write_final_file_with_samtools(cram: str,
 
     log.info(f'Now running: {cmd}')
     run(cmd, stdout=open(output, 'w'), check=True)
-    log.debug(f'Output CRAM successfully generated at: {output}')
+    log.info(f'Output CRAM successfully generated at: {output}')
 
 def stage(uri: str, output: str) -> None:
     """
@@ -421,21 +604,40 @@ def view(cram: str,
          crai: Optional[str],
          regions: Optional[str],
          output: Optional[str] = None,
-         cram_format: bool = True) -> str:
+         cram_format: bool = True,
+         slicing: bool = True) -> str:
+    """
+    A limited wrapper around "samtools view", but with functions to operate on google cloud bucket keys.
+
+    When slicing is True, and a gs:// cram is given along with a crai and regions, only the parts of the cram
+    needed for those regions are downloaded; otherwise (or if the sequence names cannot be read from the cram
+    header) the whole cram is staged.
+    """
     output = output or timestamped_filename(cram_format)
     output = output[len('file://'):] if output.startswith('file://') else output
     assert ':' not in output, f'Unsupported schema for output: "{output}".\n' \
                               f'Only local file outputs are currently supported.'
 
     with TemporaryDirectory() as staging_dir:
-        staged_cram = os.path.join(staging_dir, 'tmp.cram')
-        stage(uri=cram, output=staged_cram)
         if crai:
             staged_crai = os.path.join(staging_dir, 'tmp.crai')
             stage(uri=crai, output=staged_crai)
         else:
+            log.warning('No crai file specified, this may take a long time.')
             staged_crai = None
 
+        staged_cram = os.path.join(staging_dir, 'tmp.cram')
+        if cram.startswith('gs://') and regions and slicing and staged_crai:
+            try:
+                download_sliced_cram(cram, staged_crai, regions, output_filename=staged_cram)
+            except SeqMapError as e:
+                log.warning(f'Slicing failed with:'
+                            f'\n{e}\n'
+                            f'Now making attempt without slicing.')
+                stage(uri=cram, output=staged_cram)
+        else:
+            stage(uri=cram, output=staged_cram)
+
         write_final_file_with_samtools(staged_cram, staged_crai, regions, cram_format, output)
 
     return output