diff --git a/kevlar/cli/count.py b/kevlar/cli/count.py index b659ca8..9157a73 100644 --- a/kevlar/cli/count.py +++ b/kevlar/cli/count.py @@ -109,3 +109,6 @@ def subparser(subparsers): help='show this help message and exit') misc_args.add_argument('-k', '--ksize', type=int, default=31, metavar='K', help='k-mer size; default is 31') + misc_args.add_argument('-t', '--threads', type=int, default=1, metavar='T', + help='number of threads to use for file processing;' + ' default is 1') diff --git a/kevlar/cli/novel.py b/kevlar/cli/novel.py index bf73927..819241d 100644 --- a/kevlar/cli/novel.py +++ b/kevlar/cli/novel.py @@ -138,3 +138,6 @@ def subparser(subparsers): misc_args.add_argument('--abund-screen', type=int, default=None, metavar='INT', help='discard reads with any k-mers ' 'whose abundance is < INT') + misc_args.add_argument('-t', '--threads', type=int, default=1, metavar='T', + help='number of threads to use for file processing;' + ' default is 1') diff --git a/kevlar/count.py b/kevlar/count.py index 58ad968..66ba9d8 100644 --- a/kevlar/count.py +++ b/kevlar/count.py @@ -34,7 +34,8 @@ def main(args): controls = kevlar.counting.load_samples( infilelists, args.ksize, args.memory, outfiles=outfiles, memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, - mask=None, numbands=args.num_bands, band=myband, logfile=args.logfile + mask=None, numbands=args.num_bands, band=myband, + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadctrl') numcontrols = len(controls) @@ -49,7 +50,7 @@ def main(args): infilelists, args.ksize, args.memory, outfiles=outfiles, memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, mask=casemask, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadcase') numcases = len(cases) diff --git a/kevlar/counting.py b/kevlar/counting.py index 37cf50b..30ce224 100644 --- a/kevlar/counting.py +++ b/kevlar/counting.py @@ -9,6 +9,7 @@ import re import sys +import threading import khmer import kevlar @@ -24,7 +25,7 @@ class KevlarOutfileMismatchError(ValueError): def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, mask=None, maskmaxabund=1, numbands=None, band=None, - outfile=None, logfile=sys.stderr): + outfile=None, numthreads=1, logfile=sys.stderr): """ Compute k-mer abundances for the specified sequence input. @@ -39,29 +40,43 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, sketch = khmer.Counttable(ksize, memory / 4, 4) n, nkmers = 0, 0 for seqfile in seqfiles: - if mask: - if numbands: - nr, nk = sketch.consume_seqfile_banding_with_mask( - seqfile, numbands, band, mask - ) + parser = khmer.ReadParser(seqfile) + threads = list() + for _ in range(numthreads): + if mask: + if numbands: + thread = threading.Thread( + target=sketch.consume_seqfile_banding_with_mask, + args=(parser, numbands, band, mask, ), + ) + else: + thread = threading.Thread( + target=sketch.consume_seqfile_with_mask, + args=(parser, mask, ), + ) else: - nr, nk = sketch.consume_seqfile_with_mask(seqfile, mask) - else: - if numbands: - nr, nk = sketch.consume_seqfile_banding( - seqfile, numbands, band - ) - else: - nr, nk = sketch.consume_seqfile(seqfile) - n += nr - nkmers += nk + if numbands: + thread = threading.Thread( + target=sketch.consume_seqfile_banding, + args=(parser, numbands, band, ), + ) + else: + thread = threading.Thread( + target=sketch.consume_seqfile, + args=(parser, ), + ) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() message = 'done loading reads' if numbands: - message += ' (band {:d}/{:d})'.format(band, numbands) + message += ' (band {:d}/{:d})'.format(band+1, numbands) fpr = kevlar.sketch.estimate_fpr(sketch) - message += '; {:d} reads processed'.format(n) - message += ', {:d} k-mers stored'.format(nkmers) + message += '; {:d} reads processed'.format(parser.num_reads) + message += ', {:d} distinct k-mers stored'.format(sketch.n_unique_kmers()) message += '; estimated false positive rate is {:1.3f}'.format(fpr) if fpr > maxfpr: message += ' (FPR too high, bailing out!!!)' @@ -79,7 +94,7 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, def load_samples(samplelists, ksize, memory, mask=None, memfraction=None, maxfpr=0.2, maxabund=1, numbands=None, band=None, - outfiles=None, logfile=sys.stderr): + outfiles=None, numthreads=1, logfile=sys.stderr): """ Load a group of related samples using a memory-efficient strategy. @@ -114,7 +129,8 @@ def load_samples(samplelists, ksize, memory, mask=None, memfraction=None, mymask = sketches[0] sketch = load_sample_seqfile( seqfiles, ksize, sketchmem, maxfpr=maxfpr, mask=mymask, - numbands=numbands, band=band, outfile=outfile, logfile=logfile + numbands=numbands, band=band, outfile=outfile, + numthreads=numthreads, logfile=logfile ) sketches.append(sketch) return sketches diff --git a/kevlar/novel.py b/kevlar/novel.py index 8ef3664..2cea653 100644 --- a/kevlar/novel.py +++ b/kevlar/novel.py @@ -78,7 +78,7 @@ def main(args): controls = kevlar.counting.load_samples( args.control, args.ksize, args.memory, maxfpr=args.max_fpr, memfraction=None, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadctrl') message = 'Control samples loaded in {:.2f} sec'.format(elapsed) @@ -103,7 +103,7 @@ def main(args): cases = kevlar.counting.load_samples( args.case, args.ksize, args.memory, maxfpr=args.max_fpr, memfraction=None, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadcases') print('[kevlar::novel] Case samples loaded in {:.2f} sec'.format(elapsed), diff --git a/kevlar/tests/test_count.py b/kevlar/tests/test_count.py index e55f420..bcd82b9 100644 --- a/kevlar/tests/test_count.py +++ b/kevlar/tests/test_count.py @@ -16,10 +16,41 @@ from kevlar.tests import data_file, data_glob +@pytest.fixture +def triomask(): + mask = khmer.Counttable(19, 1e4, 4) + mask.consume('TGAGGGGACTAGGTGATCAGGTGAGGGTTTCCCAGTTCCCGAAGATGACT') + mask.consume('GATCTTTCGCTCCCTGTCATCAAGGAGTGATACGCGAAGTGCGTCCCCTT') + mask.consume('GAAGTTTTGACAATTTACGTGAGCCCTACCTAGCGAAACAACAGAGAACC') + return mask + + +@pytest.mark.parametrize('mask,numbands,band', [ + (None, None, None), + (None, 9, 2), + (triomask, None, None), + (triomask, 23, 19), +]) +def test_load_threading(mask, numbands, band): + # Smoke test: make sure things don't explode when run in "threaded" mode. + infiles = data_glob('trio1/case1.fq') + sketch = kevlar.counting.load_sample_seqfile( + infiles, 19, 1e7, mask=mask, numbands=numbands, band=band, numthreads=2 + ) + + +def test_load_sketches(): + infiles = data_glob('test.counttable') + sketches = kevlar.counting.load_samples_sketchfiles(infiles, maxfpr=0.5) + for sketch in sketches: + assert sketch.get('CCTGATATCCGGAATCTTAGC') > 0 + assert sketch.get('GATTACA' * 3) == 0 + + @pytest.mark.parametrize('numbands,band,kmers_stored', [ - (0, 0, 15600), - (2, 1, 7937), - (16, 7, 1068), + (0, 0, 947), + (2, 1, 500), + (16, 7, 68), ]) def test_count_simple(numbands, band, kmers_stored, capsys): with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \ @@ -40,4 +71,24 @@ def test_count_simple(numbands, band, kmers_stored, capsys): out, err = capsys.readouterr() assert '600 reads processed' in str(err) - assert '{:d} k-mers stored'.format(kmers_stored) in str(err) + assert '{:d} distinct k-mers stored'.format(kmers_stored) in str(err) + + +def test_count_threading(): + with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \ + NamedTemporaryFile(suffix='.counttable') as ctrl2out, \ + NamedTemporaryFile(suffix='.counttable') as caseout: + case = data_file('trio1/case1.fq') + ctrls = data_glob('trio1/ctrl[1,2].fq') + arglist = [ + 'count', + '--ksize', '19', '--memory', '500K', '--threads', '2', + '--case', caseout.name, case, + '--control', ctrl1out.name, ctrls[0], + '--control', ctrl2out.name, ctrls[1], + ] + args = kevlar.cli.parser().parse_args(arglist) + kevlar.count.main(args) + + # No checks, just doing a "smoke test" to make sure things don't explode + # when counting is done in "threaded" mode.