Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions kevlar/cli/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def subparser(subparsers):

mem_desc = """\
Specify how much memory to allocate for the sketch data structures
used to store k-mer counts. The first control sample will be
allocated the full amount of specifed `--memory`, and all subsequent
samples will be allocated a fraction thereof.
used to store k-mer counts. If `--mem-frac` is not set, all samples will be
allocated `MEM` bytes. If `--mem-frac` is set, then the first control
sample will be allocated `MEM` bytes, and all other samples will be
allocated `MEM * F` bytes.
"""
mem_desc = textwrap.dedent(mem_desc)
memory_args = subparser.add_argument_group('Memory allocation', mem_desc)
Expand All @@ -72,9 +73,9 @@ def subparser(subparsers):
'the initial control sample; default is 1M'
)
memory_args.add_argument(
'-f', '--mem-frac', type=float, default=0.1, metavar='F',
'-f', '--mem-frac', type=float, default=None, metavar='F',
help='fraction of the total memory to allocate to subsequent samples; '
'default is 0.1'
'must be between 0.0 and 1.0'
)
memory_args.add_argument(
'--max-fpr', type=float, default=0.2, metavar='FPR',
Expand Down
41 changes: 26 additions & 15 deletions kevlar/count.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
import kevlar


def split_infiles_outfiles(filelist):
    """
    Separate the output file from the input files of each sample.

    Each entry of `filelist` is a sequence whose first element is the sample's
    output file and whose remaining elements are its input files. Returns a
    pair `(outfiles, infilelists)` of parallel lists, one item per sample.

    Raises `kevlar.counting.KevlarSampleIOError` (a `ValueError`) if an entry
    does not name an output file plus at least one input file.
    """
    outfiles, infilelists = list(), list()
    for flist in filelist:
        # An entry holding only an output file would otherwise yield a sample
        # with no input, and an empty entry would fail with a bare IndexError.
        if len(flist) < 2:
            message = 'must specify an output file and at least one input file'
            raise kevlar.counting.KevlarSampleIOError(message)
        outfiles.append(flist[0])
        infilelists.append(flist[1:])
    return outfiles, infilelists


def main(args):
if (args.num_bands is None) is not (args.band is None):
raise ValueError('Must specify --num-bands and --band together')
Expand All @@ -24,27 +30,32 @@ def main(args):

timer.start('loadctrl')
print('[kevlar::count] Loading control samples', file=args.logfile)
controls = kevlar.counting.load_samples_with_dilution(
args.control, args.ksize, args.memory, memfraction=args.mem_frac,
maxfpr=args.max_fpr, maxabund=args.ctrl_max, masks=None,
numbands=args.num_bands, band=args.band, logfile=args.logfile
outfiles, infilelists = split_infiles_outfiles(args.control)
o2, i2 = split_infiles_outfiles(args.case)
kevlar.counting.load_simplex_parallel(
infilelists + i2, args.ksize, args.memory, outfiles=outfiles + o2,
memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max,
mask=None, numbands=args.num_bands, band=args.band,
logfile=args.logfile
)
elapsed = timer.stop('loadctrl')
numcontrols = len(controls)
message = '{:d} samples loaded in {:.2f} sec'.format(numcontrols, elapsed)
print('[kevlar::count]', message, file=args.logfile)

print('[kevlar::count] Loading case samples', file=args.logfile)
timer.start('loadcase')
cases = kevlar.counting.load_samples_with_dilution(
args.case, args.ksize, args.memory, memfraction=args.mem_frac,
maxfpr=args.max_fpr, maxabund=args.ctrl_max, masks=controls,
numbands=args.num_bands, band=args.band, logfile=args.logfile
)
elapsed = timer.stop('loadcase')
numcases = len(cases)
message = '{:d} sample(s) loaded in {:.2f} sec'.format(numcases, elapsed)
print('[kevlar::count]', message, file=args.logfile)
#print('[kevlar::count] Loading case samples', file=args.logfile)
#timer.start('loadcase')
#outfiles, infilelists = split_infiles_outfiles(args.case)
#cases = kevlar.counting.load_samples(
# infilelists, args.ksize, args.memory, outfiles=outfiles,
# memfraction=args.mem_frac, maxfpr=args.max_fpr, maxabund=args.ctrl_max,
# mask=controls[0], numbands=args.num_bands, band=args.band,
# logfile=args.logfile
#)
#elapsed = timer.stop('loadcase')
#numcases = len(cases)
#message = '{:d} sample(s) loaded in {:.2f} sec'.format(numcases, elapsed)
#print('[kevlar::count]', message, file=args.logfile)

total = timer.stop()
message = 'Total time: {:.2f} seconds'.format(total)
Expand Down
161 changes: 90 additions & 71 deletions kevlar/counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# -----------------------------------------------------------------------------

from __future__ import print_function
import multiprocessing
import re
import sys

Expand All @@ -19,38 +20,44 @@ class KevlarSampleIOError(ValueError):
pass


class KevlarOutfileMismatchError(ValueError):
    """Raised when the number of output files and samples do not match."""


def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2,
masks=None, maskmaxabund=1, numbands=None, band=None,
mask=None, numbands=None, band=None,
outfile=None, logfile=sys.stderr):
"""
Compute k-mer abundances for the specified sequence input.

Expected input is a list of one or more FASTA/FASTQ files corresponding
to a single sample. A counttable is created and populated with abundances
of all k-mers observed in the input.
of all k-mers observed in the input. If `mask` is provided, only k-mers not
present in the mask will be loaded.
"""
message = 'loading sample from ' + ','.join(seqfiles)
print('[kevlar::counting] ', message, file=logfile)

sketch = khmer.Counttable(ksize, memory / 4, 4)
n, nkmers = 0, 0
for n, read in enumerate(kevlar.multi_file_iter_khmer(seqfiles), 1):
for subseq in kevlar.clean_subseqs(read.sequence, ksize):
for kmer in sketch.get_kmers(subseq):
if numbands:
khash = sketch.hash(kmer)
if khash & (numbands - 1) != band - 1:
continue
if masks:
for mask in masks:
if mask.get(kmer) > maskmaxabund:
break
else:
sketch.add(kmer)
nkmers += 1
else:
sketch.add(kmer)
nkmers += 1

for seqfile in seqfiles:
if mask:
if numbands:
nr, nk = sketch.consume_seqfile_banding_with_mask(
seqfile, numbands, band, mask
)
else:
nr, nk = sketch.consume_seqfile_with_mask(seqfile, mask)
else:
if numbands:
nr, nk = sketch.consume_seqfile_banding(
seqfile, numbands, band
)
else:
nr, nk = sketch.consume_seqfile(seqfile)
n += nr
nkmers += nk

message = 'done loading reads'
if numbands:
Expand All @@ -62,80 +69,93 @@ def load_sample_seqfile(seqfiles, ksize, memory, maxfpr=0.2,
if fpr > maxfpr:
message += ' (FPR too high, bailing out!!!)'
raise SystemExit(message)
else:
if outfile:
if not outfile.endswith(('.ct', '.counttable')):
outfile += '.counttable'
sketch.save(outfile)
message += '; saved to "{:s}"'.format(outfile)
print('[kevlar::counting] ', message, file=logfile)
if outfile:
if not outfile.endswith(('.ct', '.counttable')):
outfile += '.counttable'
sketch.save(outfile)
message += '; saved to "{:s}"'.format(outfile)
print('[kevlar::counting] ', message, file=logfile)

return sketch


def load_samples(samplelists, ksize, memory, mask=None, memfraction=None,
                 maxfpr=0.2, maxabund=1, numbands=None, band=None,
                 outfiles=None, logfile=sys.stderr):
    """
    Load a group of related samples, optionally in "masked" mode.

    `samplelists` holds one entry per sample, each a list of sequence files.
    By default, each sample is loaded into a dedicated counttable allocated
    `memory` bytes. Setting `memfraction` to a value between 0.0 and 1.0
    activates "masked" mode:

    - if `mask` is provided, every sample is loaded against it and is
      allocated `memory * memfraction` bytes;
    - otherwise the first sample is loaded normally with `memory` bytes and
      then serves as the mask for all subsequent samples, each of which is
      allocated `memory * memfraction` bytes.

    If `memfraction` is None no masking is done, even when `mask` is given.

    `outfiles`, if given, must contain one entry per sample; each is handed to
    `load_sample_seqfile` as the destination for that sample's sketch.

    Returns the list of sketches, in the same order as `samplelists`.

    Raises `KevlarOutfileMismatchError` if the number of output files does not
    match the number of samples, and `ValueError` if `memfraction` is outside
    the interval (0.0, 1.0].

    NOTE: `maxabund` is accepted as part of the interface but is not currently
    forwarded to `load_sample_seqfile`.
    """
    masked = memfraction is not None
    # Fail fast on a nonsensical fraction, before any expensive loading.
    if masked and not 0.0 < memfraction <= 1.0:
        message = 'memfraction must be between 0.0 and 1.0'
        message += ' (got {})'.format(memfraction)
        raise ValueError(message)

    numsamples = len(samplelists)
    if outfiles is None:
        outfiles = [None] * numsamples
    if numsamples != len(outfiles):
        message = '# of samples ({:d}) '.format(numsamples)
        message += 'does not match # of outfiles ({:d})'.format(len(outfiles))
        raise KevlarOutfileMismatchError(message)
    message = 'computing k-mer abundances for {:d} samples'.format(numsamples)
    print('[kevlar::counting] ', message, file=logfile)

    sketches = list()
    for seqfiles, outfile in zip(samplelists, outfiles):
        # Compare against None explicitly: the truth value of a sketch object
        # is not a reliable signal of whether a mask was supplied.
        if masked and mask is not None:
            mymask, sketchmem = mask, memory * memfraction
        elif masked and sketches:
            # No explicit mask: the first sample loaded acts as the mask.
            mymask, sketchmem = sketches[0], memory * memfraction
        else:
            mymask, sketchmem = None, memory
        sketch = load_sample_seqfile(
            seqfiles, ksize, sketchmem, maxfpr=maxfpr, mask=mymask,
            numbands=numbands, band=band, outfile=outfile, logfile=logfile
        )
        sketches.append(sketch)
    return sketches


def load_simplex_parallel(samplelists, ksize, memory, mask=None,
                          memfraction=None, maxfpr=0.2, maxabund=1,
                          numbands=None, band=None, outfiles=None,
                          logfile=sys.stderr):
    """
    Load each sample in its own process.

    One `multiprocessing.Process` running `load_sample_seqfile` is started per
    sample, and this function blocks until all of them have finished. The
    sketches are not returned to the caller; each worker is handed the
    corresponding entry of `outfiles` as its output destination, so an output
    file is required for every sample.

    Every sample is allocated the full `memory` bytes and is loaded without a
    mask.

    NOTE: `mask`, `memfraction`, and `maxabund` are accepted so the signature
    parallels `load_samples`, but they are currently ignored.

    Raises `KevlarOutfileMismatchError` if `outfiles` is missing or its length
    differs from the number of samples, and `SystemExit` if any worker process
    terminates with a non-zero exit status.
    """
    numsamples = len(samplelists)
    if outfiles is None or numsamples != len(outfiles):
        # Count a missing list as zero so that building the message cannot
        # itself fail with a TypeError.
        numoutfiles = 0 if outfiles is None else len(outfiles)
        message = '# of samples ({:d}) '.format(numsamples)
        message += 'does not match # of outfiles ({:d})'.format(numoutfiles)
        raise KevlarOutfileMismatchError(message)
    message = 'computing k-mer abundances for {:d} samples'.format(numsamples)
    print('[kevlar::counting] ', message, file=logfile)

    procs = list()
    for seqfiles, outfile in zip(samplelists, outfiles):
        process = multiprocessing.Process(
            target=load_sample_seqfile,
            args=(seqfiles, ksize, memory),
            kwargs={
                'maxfpr': maxfpr, 'mask': None, 'numbands': numbands,
                'band': band, 'outfile': outfile, 'logfile': logfile
            },
        )
        process.start()
        procs.append(process)

    # Wait for every worker before reporting, so no process is left running.
    numfailed = 0
    for process in procs:
        process.join()
        if process.exitcode != 0:
            numfailed += 1
    if numfailed > 0:
        # A worker that bails out (e.g. FPR too high) only terminates its own
        # process; propagate the failure instead of silently continuing.
        message = '{:d} of {:d} samples failed to load'.format(
            numfailed, numsamples
        )
        raise SystemExit(message)


def load_samples_sketchfiles(sketchfiles, maxfpr=0.2, logfile=sys.stderr):
Expand All @@ -150,7 +170,6 @@ def load_samples_sketchfiles(sketchfiles, maxfpr=0.2, logfile=sys.stderr):
if fpr > maxfpr:
message += ' (FPR too high, bailing out!!!)'
raise SystemExit(message)
else:
print(message, file=logfile)
print(message, file=logfile)
sketches.append(sketch)
return sketches
6 changes: 4 additions & 2 deletions kevlar/novel.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def main(args):
else:
controls = kevlar.counting.load_samples(
args.control, args.ksize, args.memory, maxfpr=args.max_fpr,
numbands=args.num_bands, band=args.band, logfile=args.logfile
memfraction=None, numbands=args.num_bands, band=args.band,
logfile=args.logfile
)
elapsed = timer.stop('loadctrl')
message = 'Control samples loaded in {:.2f} sec'.format(elapsed)
Expand All @@ -92,7 +93,8 @@ def main(args):
else:
cases = kevlar.counting.load_samples(
args.case, args.ksize, args.memory, maxfpr=args.max_fpr,
numbands=args.num_bands, band=args.band, logfile=args.logfile
memfraction=None, numbands=args.num_bands, band=args.band,
logfile=args.logfile
)
elapsed = timer.stop('loadcases')
print('[kevlar::novel] Case samples loaded in {:.2f} sec'.format(elapsed),
Expand Down
9 changes: 6 additions & 3 deletions kevlar/tests/test_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
import screed
import kevlar
from kevlar.tests import data_file, data_glob
import sys


@pytest.mark.parametrize('numbands,band,kmers_stored', [
(0, 0, 15600),
(2, 1, 7992),
(16, 7, 1218),
(2, 1, 7663),
(16, 7, 859),
])
def test_count_simple(numbands, band, kmers_stored, capsys):
with NamedTemporaryFile(suffix='.counttable') as ctrl1out, \
Expand All @@ -32,12 +33,14 @@ def test_count_simple(numbands, band, kmers_stored, capsys):
'--case', caseout.name, case,
'--control', ctrl1out.name, ctrls[0],
'--control', ctrl2out.name, ctrls[1],
'--ksize', '25', '--memory', '5K', '--ctrl-max', '0',
'--ksize', '25', '--ctrl-max', '0',
'--mem-frac', '0.1', '--memory', '50K',
'--num-bands', str(numbands), '--band', str(band),
]
args = kevlar.cli.parser().parse_args(arglist)
kevlar.count.main(args)
out, err = capsys.readouterr()

print(err, file=sys.stderr)
assert '600 reads processed' in str(err)
assert '{:d} k-mers stored'.format(kmers_stored) in str(err)