Skip to content

Commit a580a3a

Browse files
ArrayRecord Teamcopybara-github
authored andcommitted
Fix the determinism issue by introducing a pool of readers/tokenizers usable by intra-shard threads.
This pool has a fixed size when underutilized, but will scale if the demand for them spikes (i.e. the pool is empty). This keeps a reasonable predictable memory footprint, while allowing it to still respond to large increases in demand, although the readers can have non-trivial initialization time. PiperOrigin-RevId: 897665704
1 parent c909611 commit a580a3a

3 files changed

Lines changed: 72 additions & 34 deletions

File tree

python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
44
load("@pypi//:requirements.bzl", "requirement")
55

6+
67
package(default_visibility = ["//visibility:public"])
78

89
licenses(["notice"])

python/array_record_data_source.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838
import itertools
3939
import os
4040
import pathlib
41+
import queue
4142
import re
4243
import typing
4344
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
@@ -221,6 +222,7 @@ def __init__(
221222
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
222223
],
223224
reader_options: dict[str, str] | None = None,
225+
reader_pool_size: int = 1,
224226
):
225227
"""Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244
initialization faster.
243245
reader_options: string of comma-separated options to be passed when
244246
creating a reader.
247+
reader_pool_size: Number of readers to pre-allocate in the pool for each
248+
shard. Default is 1.
245249
"""
246250
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
247251
paths = [paths]
@@ -270,8 +274,15 @@ def __init__(
270274
)
271275
self._read_instructions = _get_read_instructions(paths)
272276
self._paths = [ri.filename for ri in self._read_instructions]
277+
self._reader_pool_size = max(reader_pool_size, 1)
278+
# We maintain a pool of readers for each shard to ensure thread safety
279+
# while allowing concurrent reads.
280+
self._reader_pools = [
281+
queue.LifoQueue(maxsize=self._reader_pool_size)
282+
for _ in self._read_instructions
283+
]
273284
# We open readers lazily when we need to read from them.
274-
self._readers = [None] * len(self._read_instructions)
285+
275286
self._num_records = sum(
276287
map(lambda x: x.num_records, self._read_instructions)
277288
)
@@ -286,10 +297,11 @@ def __enter__(self):
286297

287298
def __exit__(self, exc_type, exc_value, traceback):
288299
logging.debug("__exit__ for ArrayRecordDataSource is called.")
289-
for reader in self._readers:
290-
if reader:
291-
reader.close()
292-
self._readers = [None] * len(self._read_instructions)
300+
for pool in self._reader_pools:
301+
while not pool.empty():
302+
reader = pool.get()
303+
if reader:
304+
reader.close()
293305

294306
def __len__(self) -> int:
295307
return self._num_records
@@ -329,21 +341,33 @@ def _split_keys_per_reader(
329341
positions_and_indices[reader_idx] = [(position, idx)]
330342
return positions_and_indices
331343

332-
def _ensure_reader_exists(self, reader_idx: int) -> None:
333-
"""Threadsafe method to create corresponding reader if it doesn't exist."""
334-
if self._readers[reader_idx] is not None:
335-
return
336-
filename = self._read_instructions[reader_idx].filename
337-
reader = _create_reader(filename, self._reader_options_string)
338-
_check_group_size(filename, reader)
339-
self._readers[reader_idx] = reader
344+
def _get_reader(self, reader_idx: int) -> Any:
345+
"""Gets a reader from the pool or creates a new one."""
346+
try:
347+
return self._reader_pools[reader_idx].get_nowait()
348+
except queue.Empty:
349+
filename = self._read_instructions[reader_idx].filename
350+
reader = _create_reader(filename, self._reader_options_string)
351+
_check_group_size(filename, reader)
352+
return reader
353+
354+
def _release_reader(self, reader_idx: int, reader: Any) -> None:
355+
"""Returns a reader to the pool."""
356+
try:
357+
self._reader_pools[reader_idx].put_nowait(reader)
358+
except queue.Full:
359+
if reader:
360+
reader.close()
340361

341362
def __getitem__(self, record_key: SupportsIndex) -> bytes:
342363
reader_idx, position = self._reader_idx_and_position(record_key)
343-
self._ensure_reader_exists(reader_idx)
344-
if hasattr(self._readers[reader_idx], "read"):
345-
return self._readers[reader_idx].read([position])[0]
346-
return self._readers[reader_idx][position]
364+
reader = self._get_reader(reader_idx)
365+
try:
366+
if hasattr(reader, "read"):
367+
return reader.read([position])[0]
368+
return reader[position]
369+
finally:
370+
self._release_reader(reader_idx, reader)
347371

348372
def __getitems__(
349373
self, record_keys: Sequence[SupportsIndex]
@@ -352,14 +376,16 @@ def read_records(
352376
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
353377
) -> Sequence[Tuple[Any, int]]:
354378
"""Reads records using the given reader keeping track of the indices."""
355-
# Initialize readers lazily when we need to read from them.
356-
self._ensure_reader_exists(reader_idx)
357-
positions, indices = list(zip(*reader_positions_and_indices))
358-
if hasattr(self._readers[reader_idx], "read"):
359-
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
360-
else:
361-
records = [self._readers[reader_idx][p] for p in positions]
362-
return list(zip(records, indices))
379+
reader = self._get_reader(reader_idx)
380+
try:
381+
positions, indices = list(zip(*reader_positions_and_indices))
382+
if hasattr(reader, "read"):
383+
records = reader.read(positions) # pytype: disable=attribute-error
384+
else:
385+
records = [reader[p] for p in positions]
386+
return list(zip(records, indices))
387+
finally:
388+
self._release_reader(reader_idx, reader)
363389

364390
positions_and_indices = self._split_keys_per_reader(record_keys)
365391
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
@@ -390,15 +416,26 @@ def read_records(
390416
def __getstate__(self):
391417
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
392418
state = self.__dict__.copy()
393-
del state["_readers"]
419+
state.pop("_reader_pools", None)
394420
return state
395421

396422
def __setstate__(self, state):
397423
logging.debug("__setstate__ for ArrayRecordDataSource is called.")
398424
self.__dict__.update(state)
399-
# We open readers lazily when we need to read from them. Thus, we don't
400-
# need to re-open the same files as before pickling.
401-
self._readers = [None] * len(self._read_instructions)
425+
# After pickling, we re-initialize the reader pools.
426+
self._reader_pools = [
427+
queue.LifoQueue(maxsize=self._reader_pool_size)
428+
for _ in self._read_instructions
429+
]
430+
# We open readers lazily when we need to read from them.
431+
432+
def _peek_readers(self) -> List[Any]:
433+
"""Returns a list of readers (one per shard) or None (for testing)."""
434+
readers = []
435+
for pool in self._reader_pools:
436+
with pool.mutex:
437+
readers.append(pool.queue[-1] if pool.queue else None)
438+
return readers
402439

403440
def __repr__(self) -> str:
404441
"""Storing a hash of paths since paths can be a very long list."""

python/array_record_data_source_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_array_record_data_source_single_path(self):
109109
) as ar:
110110
actual_data = [ar[x] for x in indices_to_read]
111111
self.assertEqual(expected_data, actual_data)
112-
self.assertTrue(all(reader is None for reader in ar._readers))
112+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
113113

114114
def test_array_record_data_source_string_read_instructions(self):
115115
indices_to_read = [0, 1, 2, 3, 4]
@@ -132,7 +132,7 @@ def test_array_record_data_source_reverse_order(self):
132132
]) as ar:
133133
actual_data = [ar[x] for x in indices_to_read]
134134
self.assertEqual(expected_data, actual_data)
135-
self.assertTrue(all(reader is None for reader in ar._readers))
135+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
136136

137137
def test_array_record_data_source_random_order(self):
138138
# some random permutation
@@ -144,7 +144,7 @@ def test_array_record_data_source_random_order(self):
144144
]) as ar:
145145
actual_data = [ar[x] for x in indices_to_read]
146146
self.assertEqual(expected_data, actual_data)
147-
self.assertTrue(all(reader is None for reader in ar._readers))
147+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
148148

149149
def test_array_record_data_source_random_order_batched(self):
150150
# some random permutation
@@ -156,7 +156,7 @@ def test_array_record_data_source_random_order_batched(self):
156156
]) as ar:
157157
actual_data = ar.__getitems__(indices_to_read)
158158
self.assertEqual(expected_data, actual_data)
159-
self.assertTrue(all(reader is None for reader in ar._readers))
159+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
160160

161161
def test_array_record_data_source_file_instructions(self):
162162
file_instruction_one = DummyFileInstruction(
@@ -187,7 +187,7 @@ def test_array_record_data_source_file_instructions(self):
187187
actual_data = [ar[x] for x in indices_to_read]
188188

189189
self.assertEqual(expected_data, actual_data)
190-
self.assertTrue(all(reader is None for reader in ar._readers))
190+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
191191

192192
def test_array_record_source_reader_idx_and_position(self):
193193
file_instructions = [

0 commit comments

Comments
 (0)