Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,9 @@ running time of ``sampler`` is proportional to the length of
``id_list`` (to set up the priority queue) plus, if sampling is done
with replacement, the value of ``drop+take``, where the constant of
proportionality is about 1 microsecond.

## Testing

To run the test suite, just run:

python code/test_consistent_sampler.py
54 changes: 36 additions & 18 deletions code/consistent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,35 +373,42 @@ def next_ticket(ticket):
ticket.generation+1)
Copy link
Copy Markdown
Collaborator Author

@jonahkagan jonahkagan Feb 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing I'm not sure about is whether i need to copy this file to pkg as well in order to import it in arlo



def make_ticket_heap(id_list, seed):
"""Make a heap containing one ticket for each id in id_list.
def make_ticket_heap(id_list, max_tickets, seed):
"""Make a heap containing the smallest max_tickets tickets generated from
the ids in id_list.

Args:
id_list (iterable): a list or iterable with a list of distinct
hashable ids
max_tickets (int): the maximum number of tickets the heap can contain
seed (str): a string or any printable python object.

Returns:
a list that is a min-heap created by heapq with one ticket per id
in id_list. Ticket numbers are determined by the id and the seed.
By the heap property, the ticket_number at position i will be
less than or equal to the ticket_numbers at positions 2i+1
and 2i+2.
a list that is a min-heap created by heapq containing the smallest
max_tickets tickets generated from id_list. Ticket numbers are
determined by the id and the seed for each id in id_list. The
smallest ticket numbers are kept in the heap, and the rest are
discarded. By the heap property, the ticket_number at position i will
be less than or equal to the ticket_numbers at positions 2i+1 and
2i+2.

Example:
>>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 'xy()134!g2n')
>>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 3, 'xy()134!g2n')
>>> for ticket in heap:
... print(ticket)
Ticket(ticket_number='0.24866413894129579898796850445568128508290132707747976039848637531569373309555', id='cat', generation=1)
Ticket(ticket_number='0.33886035615681875183111698317327684455682722683976874746986356932751818935066', id='dog', generation=1)
Ticket(ticket_number='0.74685932088827950509145941729789143204056041958068799542050396198792954500593', id='fish', generation=1)
Ticket(ticket_number='0.49599842072022713663423753308080171636735689997237236247068925068573448764387', id='goat', generation=1)
"""

heap = []
seed_hash = sha256_hex(seed)
for id in id_list:
heapq.heappush(heap, first_ticket(id, seed, seed_hash))
tickets = (first_ticket(id, seed, seed_hash) for id in id_list)
heap = (
heapq.nsmallest(max_tickets, tickets)
if max_tickets < float('inf')
else list(tickets)
)
heapq.heapify(heap)
return heap


Expand Down Expand Up @@ -497,8 +504,11 @@ def sampler(id_list,
Args:
id_list (iterable): a list or iterable for a finite collection
of ids. Each id is typically a string, but may be a tuple
or other hashable object. It is checked that these ids
are distinct.
or other hashable object. When a list is given, it is checked
that these ids are distinct. When an iterator is given, the
caller is responsible for ensuring there are no duplicates (in
order to avoid the sampler having to load all of the ids into
memory at once).
seed (object): a python object with a string representation
with_replacement (bool): True if and only if sampling is with
replacement (defaults to False)
Expand Down Expand Up @@ -568,15 +578,23 @@ def sampler(id_list,
or USAGE_EXAMPLES.md
"""

assert len(id_list) == len(set(id_list)),\
"Input id_list to sampler contains duplicate ids: {}"\
.format(duplicates(id_list))
if not isinstance(id_list, collections.Iterator):
assert len(id_list) == len(set(id_list)),\
"Input id_list to sampler contains duplicate ids: {}"\
.format(duplicates(id_list))
assert type(with_replacement) is bool
output = output.lower()
assert output in {'id', 'tuple', 'ticket'}
assert type(digits) is int

heap = make_ticket_heap(id_list, seed)
# Generate the maximum number of tickets we want to populate the initial
# heap. When id_list is large, populating the initial ticket heap with a
# ticket for every id rapidly grows memory usage. When take is finite, we
# only need drop + take tickets in the initial heap to safely accommodate
# all possible samplings with replacement.
max_tickets = drop + take
heap = make_ticket_heap(id_list, max_tickets, seed)

count = 0
while len(heap) > 0:
ticket = draw_without_replacement(heap)
Expand Down
148 changes: 148 additions & 0 deletions code/test_consistent_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import unittest
import random
import itertools
from consistent_sampler import sampler


def ids(n):
id_list = [i for i in range(n)]
random.shuffle(id_list)
return id_list


class TestConsistentSampler(unittest.TestCase):
"""
Tests based on the paper describing consistent sampling: https://arxiv.org/abs/1808.10016
"""

def test_consistent_wrt_seed(self):
"""
Ensure the sample results are always the same for the same seed, and
always different for different seeds.
"""
for n in range(1, 10):
self.assertEqual(
list(sampler(ids(n), 12345)),
list(sampler(ids(n), 12345)),
)
self.assertNotEqual(
list(sampler(ids(n), 12345)),
list(sampler(ids(n), 12346)),
)

self.assertEqual(
list(sampler(ids(n), 12345, take=n, with_replacement=True)),
list(sampler(ids(n), 12345, take=n, with_replacement=True)),
)
self.assertNotEqual(
list(sampler(ids(n), 12345, take=n, with_replacement=True)),
list(sampler(ids(n), 12346, take=n, with_replacement=True)),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also be testing these things without replacement?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that's the assertions just above this. i think i tested with and without replacement for every test case

)

def test_consistent_wrt_sample_size(self):
"""
Ensure that when drawing a small sample and a larger from the same
list of ids, the small sample is a subset of the larger sample.

From the paper:
For any pool I and any seed u, we have that for any sample size s
and s' with s' ≥ s:
S(I, u, s) ⊆ S(I, u, s0)
so that a larger sample is just an extension of a smaller sample.
"""
for i in range(1, 10):
for j in range(1, i):
self.assertEqual(
list(sampler(ids(10), 12345))[:j],
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason we aren't passing ids(10) as a fixture? Not that it's a huge performance hit with a small number, but...

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uhh just didn't know how to do fixtures with unittest. plus it shuffles the list of ids every time it generates them so i think that's a reason to create it on the fly every time

list(sampler(ids(10), 12345))[:i][:j],
)

for i in range(1, 20):
for j in range(1, i):
self.assertEqual(
list(sampler(ids(10), 12345, take=j, with_replacement=True)),
list(sampler(ids(10), 12345, take=i, with_replacement=True))[:j],
)

def test_consistent_wrt_population(self):
"""
If we draw a sample from a pool and another sample from a subset of
that pool, the sample from the subset should equal exactly the items
from the sample from the larger pool that are members of the subset.

From the paper:
For any two nonempty sets J and K with J ⊆ K, we have
S(J , u) = S(K, u) ∩ J
where S ∩ J denotes the subsequence of sequence S obtained by
retaining only elements in J.
"""
n = 10
for i in range(1, n):
K = random.sample(ids(n), random.randint(1, i))
J = random.sample(K, random.randint(1, len(K)))
self.assertEqual(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just for readability, it might be nice to have a docstring explaining what we're doing here (also in the other tests)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes good point these make little sense without the explanatory math

list(sampler(J, 12345, output="id")),
[k for k in sampler(K, 12345, output="id") if k in J],
)
self.assertEqual(
list(
sampler(
J, 12345, take=len(J) * 2, with_replacement=True, output="id"
)
),
list(
itertools.islice(
(
k
for k in sampler(
K, 12345, with_replacement=True, output="id"
)
if k in J
),
len(J) * 2,
)
),
)

def test_take_and_drop(self):
for n in range(1, 10):
for d in range(1, n):
for t in range(1, n - d):
self.assertEqual(
list(sampler(ids(n), 12345, drop=d, take=t)),
list(sampler(ids(n), 12345))[d : d + t],
)
self.assertEqual(
list(
sampler(
ids(n / 2), 12345, drop=d, take=t, with_replacement=True
)
),
list(sampler(ids(n / 2), 12345, with_replacement=True, take=n))[
d : d + t
],
)

def test_ordered_by_ticket_number(self):
for n in range(1, 10):
sample = list(sampler(ids(n), 12345))
self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0]))
sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2))
self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0]))

def test_replacement(self):
for n in range(1, 10):
sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2))
self.assertEqual(len(sample), n * 2)
self.assertTrue(
any(generation for (_, _, generation) in sample if generation > 1)
)
self.assertTrue(max(generation for (_, _, generation) in sample) <= n * 2)
for _id, tickets in itertools.groupby(sample, key=lambda ticket: ticket[1]):
generations = [generation for (_, _, generation) in tickets]
self.assertEqual(generations, sorted(generations))
self.assertEqual(generations, list(set(generations)))


if __name__ == "__main__":
unittest.main()
Loading