diff --git a/README.md b/README.md index 1490524..1716387 100644 --- a/README.md +++ b/README.md @@ -131,3 +131,9 @@ running time of ``sampler`` is proportional to the length of ``id_list`` (to set up the priority queue) plus, if sampling is done with replacement, the value of ``drop+take``, where the constant of proportionality is about 1 microsecond. + +## Testing + +To run the test suite, just run: + + python code/test_consistent_sampler.py diff --git a/code/consistent_sampler.py b/code/consistent_sampler.py index 79f4954..d489793 100644 --- a/code/consistent_sampler.py +++ b/code/consistent_sampler.py @@ -373,35 +373,42 @@ def next_ticket(ticket): ticket.generation+1) -def make_ticket_heap(id_list, seed): - """Make a heap containing one ticket for each id in id_list. +def make_ticket_heap(id_list, max_tickets, seed): + """Make a heap containing the smallest max_tickets tickets generated from + the ids in id_list. Args: id_list (iterable): a list or iterable with a list of distinct hashable ids + max_tickets (int): the maximum number of tickets the heap can contain seed (str): a string or any printable python object. Returns: - a list that is a min-heap created by heapq with one ticket per id - in id_list. Ticket numbers are determined by the id and the seed. - By the heap property, the ticket_number at position i will be - less than or equal to the ticket_numbers at positions 2i+1 - and 2i+2. + a list that is a min-heap created by heapq containing the smallest + max_tickets tickets generated from id_list. Ticket numbers are + determined by the id and the seed for each id in id_list. The + smallest ticket numbers are kept in the heap, and the rest are + discarded. By the heap property, the ticket_number at position i will + be less than or equal to the ticket_numbers at positions 2i+1 and + 2i+2. 
Example: - >>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 'xy()134!g2n') + >>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 3, 'xy()134!g2n') >>> for ticket in heap: ... print(ticket) Ticket(ticket_number='0.24866413894129579898796850445568128508290132707747976039848637531569373309555', id='cat', generation=1) Ticket(ticket_number='0.33886035615681875183111698317327684455682722683976874746986356932751818935066', id='dog', generation=1) - Ticket(ticket_number='0.74685932088827950509145941729789143204056041958068799542050396198792954500593', id='fish', generation=1) Ticket(ticket_number='0.49599842072022713663423753308080171636735689997237236247068925068573448764387', id='goat', generation=1) """ - heap = [] seed_hash = sha256_hex(seed) - for id in id_list: - heapq.heappush(heap, first_ticket(id, seed, seed_hash)) + tickets = (first_ticket(id, seed, seed_hash) for id in id_list) + heap = ( + heapq.nsmallest(max_tickets, tickets) + if max_tickets < float('inf') + else list(tickets) + ) + heapq.heapify(heap) return heap @@ -497,8 +504,11 @@ def sampler(id_list, Args: id_list (iterable): a list or iterable for a finite collection of ids. Each id is typically a string, but may be a tuple - or other hashable object. It is checked that these ids - are distinct. + or other hashable object. When a list is given, it is checked + that these ids are distinct. When an iterator is given, the + caller is responsible for ensuring there are no duplicates (in + order to avoid the sampler having to load all of the ids into + memory at once). 
seed (object): a python object with a string representation with_replacement (bool): True if and only if sampling is with replacement (defaults to False) @@ -568,15 +578,23 @@ def sampler(id_list, or USAGE_EXAMPLES.md """ - assert len(id_list) == len(set(id_list)),\ - "Input id_list to sampler contains duplicate ids: {}"\ - .format(duplicates(id_list)) + if not isinstance(id_list, collections.Iterator): + assert len(id_list) == len(set(id_list)),\ + "Input id_list to sampler contains duplicate ids: {}"\ + .format(duplicates(id_list)) assert type(with_replacement) is bool output = output.lower() assert output in {'id', 'tuple', 'ticket'} assert type(digits) is int - heap = make_ticket_heap(id_list, seed) + # Generate the maximum number of tickets we want to populate the initial + # heap. When id_list is large, populating the initial ticket heap with a + # ticket for every id rapidly grows memory usage. When take is finite, we + # only need drop + take tickets in the initial heap to safely accommodate + # all possible samplings with replacement. + max_tickets = drop + take + heap = make_ticket_heap(id_list, max_tickets, seed) + count = 0 while len(heap) > 0: ticket = draw_without_replacement(heap) diff --git a/code/test_consistent_sampler.py b/code/test_consistent_sampler.py new file mode 100644 index 0000000..1b83a97 --- /dev/null +++ b/code/test_consistent_sampler.py @@ -0,0 +1,148 @@ +import unittest +import random +import itertools +from consistent_sampler import sampler + + +def ids(n): + id_list = [i for i in range(n)] + random.shuffle(id_list) + return id_list + + +class TestConsistentSampler(unittest.TestCase): + """ + Tests based on the paper describing consistent sampling: https://arxiv.org/abs/1808.10016 + """ + + def test_consistent_wrt_seed(self): + """ + Ensure the sample results are always the same for the same seed, and + always different for different seeds. 
+ """ + for n in range(1, 10): + self.assertEqual( + list(sampler(ids(n), 12345)), + list(sampler(ids(n), 12345)), + ) + self.assertNotEqual( + list(sampler(ids(n), 12345)), + list(sampler(ids(n), 12346)), + ) + + self.assertEqual( + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + ) + self.assertNotEqual( + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + list(sampler(ids(n), 12346, take=n, with_replacement=True)), + ) + + def test_consistent_wrt_sample_size(self): + """ + Ensure that when drawing a small sample and a larger one from the same + list of ids, the small sample is a subset of the larger sample. + + From the paper: + For any pool I and any seed u, we have that for any sample size s + and s' with s' ≥ s: + S(I, u, s) ⊆ S(I, u, s') + so that a larger sample is just an extension of a smaller sample. + """ + for i in range(1, 10): + for j in range(1, i): + self.assertEqual( + list(sampler(ids(10), 12345))[:j], + list(sampler(ids(10), 12345))[:i][:j], + ) + + for i in range(1, 20): + for j in range(1, i): + self.assertEqual( + list(sampler(ids(10), 12345, take=j, with_replacement=True)), + list(sampler(ids(10), 12345, take=i, with_replacement=True))[:j], + ) + + def test_consistent_wrt_population(self): + """ + If we draw a sample from a pool and another sample from a subset of + that pool, the sample from the subset should equal exactly the items + from the sample from the larger pool that are members of the subset. + + From the paper: + For any two nonempty sets J and K with J ⊆ K, we have + S(J, u) = S(K, u) ∩ J + where S ∩ J denotes the subsequence of sequence S obtained by + retaining only elements in J.
+ """ + n = 10 + for i in range(1, n): + K = random.sample(ids(n), random.randint(1, i)) + J = random.sample(K, random.randint(1, len(K))) + self.assertEqual( + list(sampler(J, 12345, output="id")), + [k for k in sampler(K, 12345, output="id") if k in J], + ) + self.assertEqual( + list( + sampler( + J, 12345, take=len(J) * 2, with_replacement=True, output="id" + ) + ), + list( + itertools.islice( + ( + k + for k in sampler( + K, 12345, with_replacement=True, output="id" + ) + if k in J + ), + len(J) * 2, + ) + ), + ) + + def test_take_and_drop(self): + for n in range(1, 10): + for d in range(1, n): + for t in range(1, n - d): + self.assertEqual( + list(sampler(ids(n), 12345, drop=d, take=t)), + list(sampler(ids(n), 12345))[d : d + t], + ) + self.assertEqual( + list( + sampler( + ids(n / 2), 12345, drop=d, take=t, with_replacement=True + ) + ), + list(sampler(ids(n / 2), 12345, with_replacement=True, take=n))[ + d : d + t + ], + ) + + def test_ordered_by_ticket_number(self): + for n in range(1, 10): + sample = list(sampler(ids(n), 12345)) + self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0])) + sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2)) + self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0])) + + def test_replacement(self): + for n in range(1, 10): + sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2)) + self.assertEqual(len(sample), n * 2) + self.assertTrue( + any(generation for (_, _, generation) in sample if generation > 1) + ) + self.assertTrue(max(generation for (_, _, generation) in sample) <= n * 2) + for _id, tickets in itertools.groupby(sample, key=lambda ticket: ticket[1]): + generations = [generation for (_, _, generation) in tickets] + self.assertEqual(generations, sorted(generations)) + self.assertEqual(generations, list(set(generations))) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git 
a/pkg/consistent_sampler/consistent_sampler.py b/pkg/consistent_sampler/consistent_sampler.py index 79f4954..d5cfbd8 100644 --- a/pkg/consistent_sampler/consistent_sampler.py +++ b/pkg/consistent_sampler/consistent_sampler.py @@ -146,10 +146,7 @@ import heapq -Ticket = collections.namedtuple("Ticket", - ['ticket_number', - 'id', - 'generation']) +Ticket = collections.namedtuple("Ticket", ["ticket_number", "id", "generation"]) """ A Ticket is a record referring to one object. @@ -192,10 +189,9 @@ def trim(x, mantissa_display_length=9): '0.9991234' """ - x0 = x+'0' - first_non_9_position = \ - min([i for i in range(2, len(x0)) if x0[i] < '9']) - return x[:first_non_9_position + mantissa_display_length] + x0 = x + "0" + first_non_9_position = min([i for i in range(2, len(x0)) if x0[i] < "9"]) + return x[: first_non_9_position + mantissa_display_length] def duplicates(L): @@ -224,7 +220,7 @@ def duplicates(L): def sha256_hex(hash_input): - """ Return 64-character hex representation of SHA256 of input. + """Return 64-character hex representation of SHA256 of input. Args: hash_input (obj): a python object having a string representation. @@ -238,7 +234,7 @@ def sha256_hex(hash_input): 'ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad' """ - return hashlib.sha256(str(hash_input).encode('utf-8')).hexdigest() + return hashlib.sha256(str(hash_input).encode("utf-8")).hexdigest() def sha256_uniform(hash_input): @@ -266,14 +262,14 @@ def sha256_uniform(hash_input): """ x_hex = sha256_hex(hash_input) - x_int = ("{:064d}".format(int(x_hex, 16))) + x_int = "{:064d}".format(int(x_hex, 16)) x_list = list(x_int) x_list.reverse() return "0." + "".join(x_list) def first_fraction(id, seed, seed_hash=None): - """ Return initial pseudo-random fraction for given id and seed. + """Return initial pseudo-random fraction for given id and seed. 
Args: id (obj): a hashable python object with a string representation @@ -298,7 +294,7 @@ def first_fraction(id, seed, seed_hash=None): def next_fraction(x): - """ Return pseudorandom real y in (x, 1) (so y>x). + """Return pseudorandom real y in (x, 1) (so y>x). Args: x (str): An input string of the form "0.ddd...dddd" @@ -318,16 +314,15 @@ def next_fraction(x): '0.642853261655004694691182528114375607701032283189922170593838029306715548381901' """ - assert x[:2] == '0.' - x0 = x+'0' # in case x mantissa is all 9s - first_non_9_position = \ - min([i for i in range(2, len(x0)) if x0[i] < '9']) - y = '0.' + assert x[:2] == "0." + x0 = x + "0" # in case x mantissa is all 9s + first_non_9_position = min([i for i in range(2, len(x0)) if x0[i] < "9"]) + y = "0." i = 0 while y <= x0: i = i + 1 y = x0[:first_non_9_position] - y = y + sha256_uniform(x + ':' + str(i))[2:] + y = y + sha256_uniform(x + ":" + str(i))[2:] return y @@ -368,40 +363,45 @@ def next_ticket(ticket): Ticket(ticket_number='0.8232357229934205790595761924514048157652891124687533667363938813600770093316', id='AB-130', generation=2) """ - return Ticket(next_fraction(ticket.ticket_number), - ticket.id, - ticket.generation+1) + return Ticket(next_fraction(ticket.ticket_number), ticket.id, ticket.generation + 1) -def make_ticket_heap(id_list, seed): - """Make a heap containing one ticket for each id in id_list. +def make_ticket_heap(id_list, max_tickets, seed): + """Make a heap containing the smallest max_tickets tickets generated from + the ids in id_list. Args: - id_list (iterable): a list or iterable with a list of distinct + id_list (iterable): a list or iterable with a list of distinct hashable ids + max_tickets (int): the maximum number of tickets the heap can contain seed (str): a string or any printable python object. Returns: - a list that is a min-heap created by heapq with one ticket per id - in id_list. Ticket numbers are determined by the id and the seed. 
- By the heap property, the ticket_number at position i will be - less than or equal to the ticket_numbers at positions 2i+1 - and 2i+2. + a list that is a min-heap created by heapq containing the smallest + max_tickets tickets generated from id_list. Ticket numbers are + determined by the id and the seed for each id in id_list. The + smallest ticket numbers are kept in the heap, and the rest are + discarded. By the heap property, the ticket_number at position i will + be less than or equal to the ticket_numbers at positions 2i+1 and + 2i+2. Example: - >>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 'xy()134!g2n') + >>> heap = make_ticket_heap(['dog', 'cat', 'fish', 'goat'], 3, 'xy()134!g2n') >>> for ticket in heap: ... print(ticket) Ticket(ticket_number='0.24866413894129579898796850445568128508290132707747976039848637531569373309555', id='cat', generation=1) Ticket(ticket_number='0.33886035615681875183111698317327684455682722683976874746986356932751818935066', id='dog', generation=1) - Ticket(ticket_number='0.74685932088827950509145941729789143204056041958068799542050396198792954500593', id='fish', generation=1) Ticket(ticket_number='0.49599842072022713663423753308080171636735689997237236247068925068573448764387', id='goat', generation=1) """ - heap = [] seed_hash = sha256_hex(seed) - for id in id_list: - heapq.heappush(heap, first_ticket(id, seed, seed_hash)) + tickets = (first_ticket(id, seed, seed_hash) for id in id_list) + heap = ( + heapq.nsmallest(max_tickets, tickets) + if max_tickets < float("inf") + else list(tickets) + ) + heapq.heapify(heap) return heap @@ -481,14 +481,15 @@ def draw_with_replacement(heap): return ticket -def sampler(id_list, - seed, - with_replacement=False, - drop=0, - take=float('inf'), - output='tuple', - digits=9, - ): +def sampler( + id_list, + seed, + with_replacement=False, + drop=0, + take=float("inf"), + output="tuple", + digits=9, +): """Return generator for a sample of the given list of ids. 
The sample is determined in a pseudo-random order controlled by @@ -497,8 +498,11 @@ def sampler(id_list, Args: id_list (iterable): a list or iterable for a finite collection of ids. Each id is typically a string, but may be a tuple - or other hashable object. It is checked that these ids - are distinct. + or other hashable object. When a list is given, it is checked + that these ids are distinct. When an iterator is given, the + caller is responsible for ensuring there are no duplicates (in + order to avoid the sampler having to load all of the ids into + memory at once). seed (object): a python object with a string representation with_replacement (bool): True if and only if sampling is with replacement (defaults to False) @@ -509,15 +513,15 @@ def sampler(id_list, If drop is 0, then take is an upper bound on the sample size. (defaults to infinity) output (str): one of {'id', 'tuple', 'ticket'} - Specifies whether each invocation of the returned generator + Specifies whether each invocation of the returned generator yields: an id, such as 'AB-130' a tuple (triple), such as ('0.235789114', 'AB-130', 1) or a Ticket, such as - Ticket(ticket_number='0.235789114', - id='AB-130', + Ticket(ticket_number='0.235789114', + id='AB-130', generation=1) digits (int): the number of significant digits to return in ticket numbers. (More precisely, this is the number of @@ -552,7 +556,7 @@ def sampler(id_list, Raises AssertionError if there are duplicate ids in id_list Examples: - >>> list(sampler(['A#2', 'B#7', 'C#1', 'D#4'], + >>> list(sampler(['A#2', 'B#7', 'C#1', 'D#4'], ... with_replacement=True, take=8, seed=314159, ... 
output='id')) ['D#4', 'C#1', 'C#1', 'B#7', 'A#2', 'C#1', 'D#4', 'B#7'] @@ -568,15 +572,25 @@ def sampler(id_list, or USAGE_EXAMPLES.md """ - assert len(id_list) == len(set(id_list)),\ - "Input id_list to sampler contains duplicate ids: {}"\ - .format(duplicates(id_list)) + if not isinstance(id_list, collections.Iterator): + assert len(id_list) == len( + set(id_list) + ), "Input id_list to sampler contains duplicate ids: {}".format( + duplicates(id_list) + ) assert type(with_replacement) is bool output = output.lower() - assert output in {'id', 'tuple', 'ticket'} + assert output in {"id", "tuple", "ticket"} assert type(digits) is int - - heap = make_ticket_heap(id_list, seed) + + # Generate the maximum number of tickets we want to populate the initial + # heap. When id_list is large, populating the initial ticket heap with a + # ticket for every id rapidly grows memory usage. When take is finite, we + # only need drop + take tickets in the initial heap to safely accommodate + # all possible samplings with replacement. 
+ max_tickets = drop + take + heap = make_ticket_heap(id_list, max_tickets, seed) + count = 0 while len(heap) > 0: ticket = draw_without_replacement(heap) @@ -586,18 +600,21 @@ def sampler(id_list, if drop < count <= drop + take: ticket_list = list(ticket) ticket_list[0] = trim(ticket_list[0], digits) - if output == 'id': + if output == "id": yield ticket.id - elif output == 'tuple': + elif output == "tuple": yield tuple(ticket_list) else: - yield Ticket(ticket_number=ticket_list[0], - id=ticket_list[1], - generation=ticket_list[2]) - elif count > drop+take: + yield Ticket( + ticket_number=ticket_list[0], + id=ticket_list[1], + generation=ticket_list[2], + ) + elif count > drop + take: return -if __name__ == '__main__': +if __name__ == "__main__": import doctest + doctest.testmod() diff --git a/pkg/consistent_sampler/test_consistent_sampler.py b/pkg/consistent_sampler/test_consistent_sampler.py new file mode 100644 index 0000000..1b83a97 --- /dev/null +++ b/pkg/consistent_sampler/test_consistent_sampler.py @@ -0,0 +1,148 @@ +import unittest +import random +import itertools +from consistent_sampler import sampler + + +def ids(n): + id_list = [i for i in range(n)] + random.shuffle(id_list) + return id_list + + +class TestConsistentSampler(unittest.TestCase): + """ + Tests based on the paper describing consistent sampling: https://arxiv.org/abs/1808.10016 + """ + + def test_consistent_wrt_seed(self): + """ + Ensure the sample results are always the same for the same seed, and + always different for different seeds. 
+ """ + for n in range(1, 10): + self.assertEqual( + list(sampler(ids(n), 12345)), + list(sampler(ids(n), 12345)), + ) + self.assertNotEqual( + list(sampler(ids(n), 12345)), + list(sampler(ids(n), 12346)), + ) + + self.assertEqual( + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + ) + self.assertNotEqual( + list(sampler(ids(n), 12345, take=n, with_replacement=True)), + list(sampler(ids(n), 12346, take=n, with_replacement=True)), + ) + + def test_consistent_wrt_sample_size(self): + """ + Ensure that when drawing a small sample and a larger one from the same + list of ids, the small sample is a subset of the larger sample. + + From the paper: + For any pool I and any seed u, we have that for any sample size s + and s' with s' ≥ s: + S(I, u, s) ⊆ S(I, u, s') + so that a larger sample is just an extension of a smaller sample. + """ + for i in range(1, 10): + for j in range(1, i): + self.assertEqual( + list(sampler(ids(10), 12345))[:j], + list(sampler(ids(10), 12345))[:i][:j], + ) + + for i in range(1, 20): + for j in range(1, i): + self.assertEqual( + list(sampler(ids(10), 12345, take=j, with_replacement=True)), + list(sampler(ids(10), 12345, take=i, with_replacement=True))[:j], + ) + + def test_consistent_wrt_population(self): + """ + If we draw a sample from a pool and another sample from a subset of + that pool, the sample from the subset should equal exactly the items + from the sample from the larger pool that are members of the subset. + + From the paper: + For any two nonempty sets J and K with J ⊆ K, we have + S(J, u) = S(K, u) ∩ J + where S ∩ J denotes the subsequence of sequence S obtained by + retaining only elements in J.
+ """ + n = 10 + for i in range(1, n): + K = random.sample(ids(n), random.randint(1, i)) + J = random.sample(K, random.randint(1, len(K))) + self.assertEqual( + list(sampler(J, 12345, output="id")), + [k for k in sampler(K, 12345, output="id") if k in J], + ) + self.assertEqual( + list( + sampler( + J, 12345, take=len(J) * 2, with_replacement=True, output="id" + ) + ), + list( + itertools.islice( + ( + k + for k in sampler( + K, 12345, with_replacement=True, output="id" + ) + if k in J + ), + len(J) * 2, + ) + ), + ) + + def test_take_and_drop(self): + for n in range(1, 10): + for d in range(1, n): + for t in range(1, n - d): + self.assertEqual( + list(sampler(ids(n), 12345, drop=d, take=t)), + list(sampler(ids(n), 12345))[d : d + t], + ) + self.assertEqual( + list( + sampler( + ids(n / 2), 12345, drop=d, take=t, with_replacement=True + ) + ), + list(sampler(ids(n / 2), 12345, with_replacement=True, take=n))[ + d : d + t + ], + ) + + def test_ordered_by_ticket_number(self): + for n in range(1, 10): + sample = list(sampler(ids(n), 12345)) + self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0])) + sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2)) + self.assertEqual(sample, sorted(sample, key=lambda ticket: ticket[0])) + + def test_replacement(self): + for n in range(1, 10): + sample = list(sampler(ids(n), 12345, with_replacement=True, take=n * 2)) + self.assertEqual(len(sample), n * 2) + self.assertTrue( + any(generation for (_, _, generation) in sample if generation > 1) + ) + self.assertTrue(max(generation for (_, _, generation) in sample) <= n * 2) + for _id, tickets in itertools.groupby(sample, key=lambda ticket: ticket[1]): + generations = [generation for (_, _, generation) in tickets] + self.assertEqual(generations, sorted(generations)) + self.assertEqual(generations, list(set(generations))) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file