def split_infiles_outfiles(filelist):
    """
    Separate each sample's output file from its input files.

    Each element of `filelist` describes a single sample: the first item is
    the name of the output file, and all remaining items are the input
    sequence files for that sample.

    Returns a tuple `(outfiles, infilelists)` of two parallel lists, holding
    the output file of every sample and the input files of every sample.

    Raises `kevlar.counting.KevlarSampleIOError` if any sample does not name
    an output file followed by at least one input file.
    """
    outfiles, infilelists = list(), list()
    for flist in filelist:
        # Fewer than two items means there is nothing to count for this
        # sample; refuse it up front instead of producing an empty sketch.
        if len(flist) < 2:
            message = 'must specify an output file and at least one input file'
            raise kevlar.counting.KevlarSampleIOError(message)
        outfiles.append(flist[0])
        infilelists.append(flist[1:])
    return outfiles, infilelists
maxabund=args.ctrl_max, masks=None, - numbands=args.num_bands, band=args.band, logfile=args.logfile + outfiles, infilelists = split_infiles_outfiles(args.control) + o2, i2 = split_infiles_outfiles(args.case) + kevlar.counting.load_simplex_parallel( + infilelists + i2, args.ksize, args.memory, outfiles=outfiles + o2, + memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, + mask=None, numbands=args.num_bands, band=args.band, + logfile=args.logfile ) elapsed = timer.stop('loadctrl') numcontrols = len(controls) message = '{:d} samples loaded in {:.2f} sec'.format(numcontrols, elapsed) print('[kevlar::count]', message, file=args.logfile) - print('[kevlar::count] Loading case samples', file=args.logfile) - timer.start('loadcase') - cases = kevlar.counting.load_samples_with_dilution( - args.case, args.ksize, args.memory, memfraction=args.mem_frac, - maxfpr=args.max_fpr, maxabund=args.ctrl_max, masks=controls, - numbands=args.num_bands, band=args.band, logfile=args.logfile - ) - elapsed = timer.stop('loadcase') - numcases = len(cases) - message = '{:d} sample(s) loaded in {:.2f} sec'.format(numcases, elapsed) - print('[kevlar::count]', message, file=args.logfile) + #print('[kevlar::count] Loading case samples', file=args.logfile) + #timer.start('loadcase') + #outfiles, infilelists = split_infiles_outfiles(args.case) + #cases = kevlar.counting.load_samples( + # infilelists, args.ksize, args.memory, outfiles=outfiles, + # memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, + # mask=controls[0], numbands=args.num_bands, band=args.band, + # logfile=args.logfile + #) + #elapsed = timer.stop('loadcase') + #numcases = len(cases) + #message = '{:d} sample(s) loaded in {:.2f} sec'.format(numcases, elapsed) + #print('[kevlar::count]', message, file=args.logfile) total = timer.stop() message = 'Total time: {:.2f} seconds'.format(total) diff --git a/kevlar/counting.py b/kevlar/counting.py index 4cf164c..9ee8de3 100644 --- a/kevlar/counting.py +++ 
b/kevlar/counting.py @@ -8,6 +8,7 @@ # ----------------------------------------------------------------------------- from __future__ import print_function +import multiprocessing import re import sys @@ -19,38 +20,44 @@ class KevlarSampleIOError(ValueError): pass +class KevlarOutfileMismatchError(ValueError): + pass + + def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, - masks=None, maskmaxabund=1, numbands=None, band=None, + mask=None, numbands=None, band=None, outfile=None, logfile=sys.stderr): """ Compute k-mer abundances for the specified sequence input. Expected input is a list of one or more FASTA/FASTQ files corresponding to a single sample. A counttable is created and populated with abundances - of all k-mers observed in the input. + of all k-mers observed in the input. If `mask` is provided, only k-mers not + present in the mask will be loaded. """ message = 'loading sample from ' + ','.join(seqfiles) print('[kevlar::counting] ', message, file=logfile) sketch = khmer.Counttable(ksize, memory / 4, 4) n, nkmers = 0, 0 - for n, read in enumerate(kevlar.multi_file_iter_khmer(seqfiles), 1): - for subseq in kevlar.clean_subseqs(read.sequence, ksize): - for kmer in sketch.get_kmers(subseq): - if numbands: - khash = sketch.hash(kmer) - if khash & (numbands - 1) != band - 1: - continue - if masks: - for mask in masks: - if mask.get(kmer) > maskmaxabund: - break - else: - sketch.add(kmer) - nkmers += 1 - else: - sketch.add(kmer) - nkmers += 1 + + for seqfile in seqfiles: + if mask: + if numbands: + nr, nk = sketch.consume_seqfile_banding_with_mask( + seqfile, numbands, band, mask + ) + else: + nr, nk = sketch.consume_seqfile_with_mask(seqfile, mask) + else: + if numbands: + nr, nk = sketch.consume_seqfile_banding( + seqfile, numbands, band + ) + else: + nr, nk = sketch.consume_seqfile(seqfile) + n += nr + nkmers += nk message = 'done loading reads' if numbands: @@ -62,80 +69,93 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, if fpr > 
def load_samples(samplelists, ksize, memory, mask=None, memfraction=None,
                 maxfpr=0.2, maxabund=1, numbands=None, band=None,
                 outfiles=None, logfile=sys.stderr):
    """
    Load a group of related samples, one counttable per sample.

    By default, each sample is loaded into a dedicated counttable occupying
    `memory` bytes. Setting `memfraction` to a value greater than 0.0 and no
    greater than 1.0 activates "masked" mode.

    In masked mode, if `mask` is provided it serves as the mask for every
    sample. If it is not provided, the first sample is loaded normally, using
    the full `memory`, and then serves as the mask for all subsequent
    samples. Every masked sample is allocated only `memory * memfraction`
    bytes. When `memfraction` is not set, `mask` is not applied.

    Arguments:
    - samplelists: list of samples, each a list of sequence files
    - ksize: k-mer size
    - memory: bytes to allocate for an unmasked sample
    - mask: optional sketch used as the mask in masked mode
    - memfraction: optional fraction of `memory` given to masked samples
    - maxfpr: passed through to `load_sample_seqfile`
    - maxabund: accepted for interface compatibility; this function does not
      currently use it or pass it on
    - numbands, band: passed through to `load_sample_seqfile`
    - outfiles: optional list of output files, one per sample
    - logfile: stream to which progress messages are written

    Returns the list of loaded sketches, in the same order as `samplelists`.

    Raises `KevlarOutfileMismatchError` if `outfiles` is given and its length
    differs from the number of samples, and `ValueError` if `memfraction` is
    set to a value outside the supported range.
    """
    numsamples = len(samplelists)
    if outfiles is None:
        outfiles = [None] * numsamples
    if numsamples != len(outfiles):
        message = '# of samples ({:d}) '.format(numsamples)
        message += 'does not match # of outfiles ({:d})'.format(len(outfiles))
        raise KevlarOutfileMismatchError(message)
    # A fraction of zero (or less) would request an empty table, and a
    # fraction above one would allocate more than the stated budget.
    if memfraction is not None and not 0.0 < memfraction <= 1.0:
        message = 'memory fraction must be greater than 0.0 and at most 1.0'
        message += ' (got {})'.format(memfraction)
        raise ValueError(message)
    message = 'computing k-mer abundances for {:d} samples'.format(numsamples)
    print('[kevlar::counting] ', message, file=logfile)

    sketches = list()
    for seqfiles, outfile in zip(samplelists, outfiles):
        if mask and memfraction is not None:
            # Masked mode with a caller-supplied mask.
            mymask = mask
            sketchmem = memory * memfraction
        elif len(sketches) == 0 or memfraction is None:
            # Unmasked mode, or the first sample of masked mode (which
            # becomes the mask for the samples that follow).
            mymask = None
            sketchmem = memory
        else:
            # Masked mode without a supplied mask: use the first sample.
            mymask = sketches[0]
            sketchmem = memory * memfraction
        sketch = load_sample_seqfile(
            seqfiles, ksize, sketchmem, maxfpr=maxfpr, mask=mymask,
            numbands=numbands, band=band, outfile=outfile, logfile=logfile
        )
        sketches.append(sketch)
    return sketches
def load_simplex_parallel(samplelists, ksize, memory, mask=None,
                          memfraction=None, maxfpr=0.2, maxabund=1,
                          numbands=None, band=None, outfiles=None,
                          logfile=sys.stderr):
    """
    Compute k-mer abundances for several samples in parallel.

    One child process is started per sample, each running
    `load_sample_seqfile` and saving its sketch to the matching entry of
    `outfiles`. The sketches themselves are not returned to the caller, so an
    output file is required for every sample.

    If both `mask` and `memfraction` are provided, every sample is loaded
    against `mask` and allocated `memory * memfraction` bytes. Otherwise
    every sample is loaded unmasked and allocated `memory` bytes. Unlike
    `load_samples`, the first sample is never used as a mask for the others,
    since all samples are loaded at the same time.

    Arguments:
    - samplelists: list of samples, each a list of sequence files
    - ksize: k-mer size
    - memory: bytes to allocate for an unmasked sample
    - mask: optional sketch used as the mask for every sample
    - memfraction: optional fraction of `memory` given to masked samples
    - maxfpr: passed through to `load_sample_seqfile`
    - maxabund: accepted for interface compatibility; this function does not
      currently use it or pass it on
    - numbands, band: passed through to `load_sample_seqfile`
    - outfiles: list of output files, one per sample (required)
    - logfile: stream to which progress messages are written

    Raises `KevlarOutfileMismatchError` if `outfiles` is missing or its
    length differs from the number of samples, `ValueError` if `memfraction`
    is set to a value outside the supported range, and `SystemExit` if any
    child process finishes with a non-zero exit code.
    """
    numsamples = len(samplelists)
    # Count the outfiles defensively so that a missing list is reported as a
    # mismatch rather than failing while the message is being built.
    numoutfiles = 0 if outfiles is None else len(outfiles)
    if outfiles is None or numsamples != numoutfiles:
        message = '# of samples ({:d}) '.format(numsamples)
        message += 'does not match # of outfiles ({:d})'.format(numoutfiles)
        raise KevlarOutfileMismatchError(message)
    if memfraction is not None and not 0.0 < memfraction <= 1.0:
        message = 'memory fraction must be greater than 0.0 and at most 1.0'
        message += ' (got {})'.format(memfraction)
        raise ValueError(message)
    message = 'computing k-mer abundances for {:d} samples'.format(numsamples)
    print('[kevlar::counting] ', message, file=logfile)

    if mask and memfraction is not None:
        # NOTE(review): the mask is handed to each child process; confirm
        # that sketch objects survive this on platforms where child
        # processes are spawned rather than forked.
        mymask = mask
        sketchmem = memory * memfraction
    else:
        mymask = None
        sketchmem = memory

    procs = list()
    for seqfiles, outfile in zip(samplelists, outfiles):
        process = multiprocessing.Process(
            target=load_sample_seqfile,
            args=(seqfiles, ksize, sketchmem),
            kwargs={
                'maxfpr': maxfpr, 'mask': mymask, 'numbands': numbands,
                'band': band, 'outfile': outfile, 'logfile': logfile
            },
        )
        process.start()
        procs.append(process)

    # Wait for every child before reporting, so that no process is left
    # running when a failure is raised.
    failed = list()
    for process, outfile in zip(procs, outfiles):
        process.join()
        if process.exitcode != 0:
            failed.append(outfile)
    if failed:
        message = 'failed to load {:d} sample(s); '.format(len(failed))
        message += 'affected outfiles: ' + ', '.join(map(str, failed))
        raise SystemExit(message)
sketches.append(sketch) return sketches diff --git a/kevlar/novel.py b/kevlar/novel.py index 86b4df7..e82e64c 100644 --- a/kevlar/novel.py +++ b/kevlar/novel.py @@ -68,7 +68,8 @@ def main(args): else: controls = kevlar.counting.load_samples( args.control, args.ksize, args.memory, maxfpr=args.max_fpr, - numbands=args.num_bands, band=args.band, logfile=args.logfile + memfraction=None, numbands=args.num_bands, band=args.band, + logfile=args.logfile ) elapsed = timer.stop('loadctrl') message = 'Control samples loaded in {:.2f} sec'.format(elapsed) @@ -92,7 +93,8 @@ def main(args): else: cases = kevlar.counting.load_samples( args.case, args.ksize, args.memory, maxfpr=args.max_fpr, - numbands=args.num_bands, band=args.band, logfile=args.logfile + memfraction=None, numbands=args.num_bands, band=args.band, + logfile=args.logfile ) elapsed = timer.stop('loadcases') print('[kevlar::novel] Case samples loaded in {:.2f} sec'.format(elapsed), diff --git a/kevlar/tests/test_count.py b/kevlar/tests/test_count.py index 895b3eb..961cbc5 100644 --- a/kevlar/tests/test_count.py +++ b/kevlar/tests/test_count.py @@ -14,12 +14,13 @@ import screed import kevlar from kevlar.tests import data_file, data_glob +import sys @pytest.mark.parametrize('numbands,band,kmers_stored', [ (0, 0, 15600), - (2, 1, 7992), - (16, 7, 1218), + (2, 1, 7663), + (16, 7, 859), ]) def test_count_simple(numbands, band, kmers_stored, capsys): with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \ @@ -32,12 +33,14 @@ def test_count_simple(numbands, band, kmers_stored, capsys): '--case', caseout.name, case, '--control', ctrl1out.name, ctrls[0], '--control', ctrl2out.name, ctrls[1], - '--ksize', '25', '--memory', '5K', '--ctrl-max', '0', + '--ksize', '25', '--ctrl-max', '0', + '--mem-frac', '0.1', '--memory', '50K', '--num-bands', str(numbands), '--band', str(band), ] args = kevlar.cli.parser().parse_args(arglist) kevlar.count.main(args) out, err = capsys.readouterr() + print(err, file=sys.stderr) assert '600 
reads processed' in str(err) assert '{:d} k-mers stored'.format(kmers_stored) in str(err)