From 09879cd0f9e8daf0f2a8bd4018a91d99d39b9d77 Mon Sep 17 00:00:00 2001 From: Daniel Standage Date: Tue, 8 Aug 2017 01:12:11 -0700 Subject: [PATCH 1/5] Improve(?!?!) counting performance with threading [skip ci] --- kevlar/counting.py | 56 +++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/kevlar/counting.py b/kevlar/counting.py index 896975e..b7e3710 100644 --- a/kevlar/counting.py +++ b/kevlar/counting.py @@ -10,6 +10,7 @@ from __future__ import print_function import re import sys +import threading import khmer import kevlar @@ -25,7 +26,7 @@ class KevlarOutfileMismatchError(ValueError): def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, mask=None, maskmaxabund=1, numbands=None, band=None, - outfile=None, logfile=sys.stderr): + outfile=None, numthreads=1, logfile=sys.stderr): """ Compute k-mer abundances for the specified sequence input. @@ -40,29 +41,43 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, sketch = khmer.Counttable(ksize, memory / 4, 4) n, nkmers = 0, 0 for seqfile in seqfiles: - if mask: - if numbands: - nr, nk = sketch.consume_seqfile_banding_with_mask( - seqfile, numbands, band, mask - ) + parser = khmer.ReadParser(seqfile) + threads = list() + for _ in range(numthreads): + if mask: + if numbands: + thread = threading.Thread( + target=sketch.consume_seqfile_banding_with_mask_with_reads_parser, # noqa + args=(parser, numbands, band, mask), + ) + else: + thread = threading.Thread( + target=sketch.consume_seqfile_with_mask_with_reads_parser, # noqa + args=(parser, mask), + ) else: - nr, nk = sketch.consume_seqfile_with_mask(seqfile, mask) - else: - if numbands: - nr, nk = sketch.consume_seqfile_banding( - seqfile, numbands, band - ) - else: - nr, nk = sketch.consume_seqfile(seqfile) - n += nr - nkmers += nk + if numbands: + thread = threading.Thread( + target=sketch.consume_seqfile_banding_with_reads_parser, # noqa + args=(parser, numbands, band), + ) + else: + thread = threading.Thread( + target=sketch.consume_seqfile_with_reads_parser, + args=(parser), + ) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() message = 'done loading reads' if numbands: message += ' (band {:d}/{:d})'.format(band, numbands) fpr = kevlar.sketch.estimate_fpr(sketch) - message += '; {:d} reads processed'.format(n) - message += ', {:d} k-mers stored'.format(nkmers) + message += '; {:d} reads processed'.format(parser.num_reads) + message += ', {:d} distinct k-mers stored'.format(sketch.n_unique_kmers()) message += '; estimated false positive rate is {:1.3f}'.format(fpr) if fpr > maxfpr: message += ' (FPR too high, bailing out!!!)' @@ -80,7 +95,7 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, def load_samples(samplelists, ksize, memory, mask=None, memfraction=None, maxfpr=0.2, maxabund=1, numbands=None, band=None, - outfiles=None, logfile=sys.stderr): + outfiles=None, numthreads=1, logfile=sys.stderr): """ Load a group of related samples using a memory-efficient strategy. @@ -115,7 +130,8 @@ def load_samples(samplelists, ksize, memory, mask=None, memfraction=None, mymask = sketches[0] sketch = load_sample_seqfile( seqfiles, ksize, sketchmem, maxfpr=maxfpr, mask=mymask, - numbands=numbands, band=band, outfile=outfile, logfile=logfile + numbands=numbands, band=band, outfile=outfile, + numthreads=numthreads, logfile=logfile ) sketches.append(sketch) return sketches From d2c472968527dccedbd09a728b1db128ff6e5e83 Mon Sep 17 00:00:00 2001 From: Daniel Standage Date: Tue, 8 Aug 2017 10:00:18 -0700 Subject: [PATCH 2/5] Fix threading bug with comma --- kevlar/counting.py | 4 ++-- kevlar/tests/test_count.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kevlar/counting.py b/kevlar/counting.py index b7e3710..8ca9df2 100644 --- a/kevlar/counting.py +++ b/kevlar/counting.py @@ -64,7 +64,7 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, else: thread = threading.Thread( target=sketch.consume_seqfile_with_reads_parser, - args=(parser), + args=(parser, ), # Comma and space directly after "parser is critical" # noqa ) threads.append(thread) thread.start() @@ -74,7 +74,7 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, message = 'done loading reads' if numbands: - message += ' (band {:d}/{:d})'.format(band, numbands) + message += ' (band {:d}/{:d})'.format(band+1, numbands) fpr = kevlar.sketch.estimate_fpr(sketch) message += '; {:d} reads processed'.format(parser.num_reads) message += ', {:d} distinct k-mers stored'.format(sketch.n_unique_kmers()) diff --git a/kevlar/tests/test_count.py b/kevlar/tests/test_count.py index e55f420..5eed195 100644 --- a/kevlar/tests/test_count.py +++ b/kevlar/tests/test_count.py @@ -17,9 +17,9 @@ @pytest.mark.parametrize('numbands,band,kmers_stored', [ - (0, 0, 15600), - (2, 1, 7937), - (16, 7, 1068), + (0, 0, 947), + (2, 1, 500), + (16, 7, 68), ]) def test_count_simple(numbands, band, kmers_stored, capsys): with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \ @@ -40,4 +40,4 @@ def test_count_simple(numbands, band, kmers_stored, capsys): out, err = capsys.readouterr() assert '600 reads processed' in str(err) - assert '{:d} k-mers stored'.format(kmers_stored) in str(err) + assert '{:d} distinct k-mers stored'.format(kmers_stored) in str(err) From 28f6072d7589f9f7afd747674552dba2a314ecde Mon Sep 17 00:00:00 2001 From: Daniel Standage Date: Tue, 8 Aug 2017 10:19:20 -0700 Subject: [PATCH 3/5] Add threads to CLI, and add tests --- kevlar/cli/count.py | 3 +++ kevlar/cli/novel.py | 3 +++ kevlar/count.py | 5 +++-- kevlar/novel.py | 4 ++-- kevlar/tests/test_count.py | 20 ++++++++++++++++++++ 5 files changed, 31 insertions(+), 4 deletions(-) diff --git a/kevlar/cli/count.py b/kevlar/cli/count.py index b659ca8..9157a73 100644 --- a/kevlar/cli/count.py +++ b/kevlar/cli/count.py @@ -109,3 +109,6 @@ def subparser(subparsers): help='show this help message and exit') misc_args.add_argument('-k', '--ksize', type=int, default=31, metavar='K', help='k-mer size; default is 31') + misc_args.add_argument('-t', '--threads', type=int, default=1, metavar='T', + help='number of threads to use for file processing;' + ' default is 1') diff --git a/kevlar/cli/novel.py b/kevlar/cli/novel.py index bf73927..819241d 100644 --- a/kevlar/cli/novel.py +++ b/kevlar/cli/novel.py @@ -138,3 +138,6 @@ def subparser(subparsers): misc_args.add_argument('--abund-screen', type=int, default=None, metavar='INT', help='discard reads with any k-mers ' 'whose abundance is < INT') + misc_args.add_argument('-t', '--threads', type=int, default=1, metavar='T', + help='number of threads to use for file processing;' + ' default is 1') diff --git a/kevlar/count.py b/kevlar/count.py index fd71833..0f2ca29 100644 --- a/kevlar/count.py +++ b/kevlar/count.py @@ -35,7 +35,8 @@ def main(args): controls = kevlar.counting.load_samples( infilelists, args.ksize, args.memory, outfiles=outfiles, memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, - mask=None, numbands=args.num_bands, band=myband, logfile=args.logfile + mask=None, numbands=args.num_bands, band=myband, + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadctrl') numcontrols = len(controls) @@ -50,7 +51,7 @@ def main(args): infilelists, args.ksize, args.memory, outfiles=outfiles, memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max, mask=casemask, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadcase') numcases = len(cases) diff --git a/kevlar/novel.py b/kevlar/novel.py index c8a5ee0..9d5ac98 100644 --- a/kevlar/novel.py +++ b/kevlar/novel.py @@ -79,7 +79,7 @@ def main(args): controls = kevlar.counting.load_samples( args.control, args.ksize, args.memory, maxfpr=args.max_fpr, memfraction=None, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadctrl') message = 'Control samples loaded in {:.2f} sec'.format(elapsed) @@ -104,7 +104,7 @@ def main(args): cases = kevlar.counting.load_samples( args.case, args.ksize, args.memory, maxfpr=args.max_fpr, memfraction=None, numbands=args.num_bands, band=myband, - logfile=args.logfile + numthreads=args.threads, logfile=args.logfile ) elapsed = timer.stop('loadcases') print('[kevlar::novel] Case samples loaded in {:.2f} sec'.format(elapsed), diff --git a/kevlar/tests/test_count.py b/kevlar/tests/test_count.py index 5eed195..ad85897 100644 --- a/kevlar/tests/test_count.py +++ b/kevlar/tests/test_count.py @@ -41,3 +41,23 @@ def test_count_simple(numbands, band, kmers_stored, capsys): assert '600 reads processed' in str(err) assert '{:d} distinct k-mers stored'.format(kmers_stored) in str(err) + + +def test_count_threading(): + with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \ + NamedTemporaryFile(suffix='.counttable') as ctrl2out, \ + NamedTemporaryFile(suffix='.counttable') as caseout: + case = data_file('trio1/case1.fq') + ctrls = data_glob('trio1/ctrl[1,2].fq') + arglist = [ + 'count', + '--ksize', '19', '--memory', '500K', '--threads', '2', + '--case', caseout.name, case, + '--control', ctrl1out.name, ctrls[0], + '--control', ctrl2out.name, ctrls[1], + ] + args = kevlar.cli.parser().parse_args(arglist) + kevlar.count.main(args) + + # No checks, just doing a "smoke test" to make sure things don't explode + # when counting is done in "threaded" mode. From a0c25b82bdfdb07b2d82af045cc1c1164fbcbaab Mon Sep 17 00:00:00 2001 From: Daniel Standage Date: Mon, 18 Sep 2017 11:26:06 -0700 Subject: [PATCH 4/5] Threading using latest khmer consume_seqfile updates --- kevlar/counting.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/kevlar/counting.py b/kevlar/counting.py index ce33e4b..30ce224 100644 --- a/kevlar/counting.py +++ b/kevlar/counting.py @@ -46,24 +46,24 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2, if mask: if numbands: thread = threading.Thread( - target=sketch.consume_seqfile_banding_with_mask_with_reads_parser, # noqa - args=(parser, numbands, band, mask), + target=sketch.consume_seqfile_banding_with_mask, + args=(parser, numbands, band, mask, ), ) else: thread = threading.Thread( - target=sketch.consume_seqfile_with_mask_with_reads_parser, # noqa - args=(parser, mask), + target=sketch.consume_seqfile_with_mask, + args=(parser, mask, ), ) else: if numbands: thread = threading.Thread( - target=sketch.consume_seqfile_banding_with_reads_parser, # noqa - args=(parser, numbands, band), + target=sketch.consume_seqfile_banding, + args=(parser, numbands, band, ), ) else: thread = threading.Thread( - target=sketch.consume_seqfile_with_reads_parser, - args=(parser, ), # Comma and space directly after "parser is critical" # noqa + target=sketch.consume_seqfile, + args=(parser, ), ) threads.append(thread) thread.start() From 3095e317ed2f98945720635545c8ff3c9041d7c7 Mon Sep 17 00:00:00 2001 From: Daniel Standage Date: Mon, 18 Sep 2017 12:25:08 -0700 Subject: [PATCH 5/5] More tests --- kevlar/tests/test_count.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/kevlar/tests/test_count.py b/kevlar/tests/test_count.py index ad85897..bcd82b9 100644 --- a/kevlar/tests/test_count.py +++ b/kevlar/tests/test_count.py @@ -16,6 +16,37 @@ from kevlar.tests import data_file, data_glob +@pytest.fixture +def triomask(): + mask = khmer.Counttable(19, 1e4, 4) + mask.consume('TGAGGGGACTAGGTGATCAGGTGAGGGTTTCCCAGTTCCCGAAGATGACT') + mask.consume('GATCTTTCGCTCCCTGTCATCAAGGAGTGATACGCGAAGTGCGTCCCCTT') + mask.consume('GAAGTTTTGACAATTTACGTGAGCCCTACCTAGCGAAACAACAGAGAACC') + return mask + + +@pytest.mark.parametrize('mask,numbands,band', [ + (None, None, None), + (None, 9, 2), + (triomask, None, None), + (triomask, 23, 19), +]) +def test_load_threading(mask, numbands, band): + # Smoke test: make sure things don't explode when run in "threaded" mode. + infiles = data_glob('trio1/case1.fq') + sketch = kevlar.counting.load_sample_seqfile( + infiles, 19, 1e7, mask=mask, numbands=numbands, band=band, numthreads=2 + ) + + +def test_load_sketches(): + infiles = data_glob('test.counttable') + sketches = kevlar.counting.load_samples_sketchfiles(infiles, maxfpr=0.5) + for sketch in sketches: + assert sketch.get('CCTGATATCCGGAATCTTAGC') > 0 + assert sketch.get('GATTACA' * 3) == 0 + + @pytest.mark.parametrize('numbands,band,kmers_stored', [ (0, 0, 947), (2, 1, 500),