Skip to content

Commit 6af6853

Browse files
ArrayRecord Teamcopybara-github
authored andcommitted
Fix the determinism issue by introducing a pool of readers/tokenizers usable by intra-shard threads.
This pool has a fixed size when underutilized, but will scale if the demand for them spikes (i.e. the pool is empty). This keeps a reasonable predictable memory footprint, while allowing it to still respond to large increases in demand, although the readers can have non-trivial initialization time. PiperOrigin-RevId: 897665704
1 parent c909611 commit 6af6853

3 files changed

Lines changed: 83 additions & 35 deletions

File tree

python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
44
load("@pypi//:requirements.bzl", "requirement")
55

6+
67
package(default_visibility = ["//visibility:public"])
78

89
licenses(["notice"])

python/array_record_data_source.py

Lines changed: 77 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838
import itertools
3939
import os
4040
import pathlib
41+
import queue
4142
import re
4243
import typing
4344
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
@@ -221,6 +222,7 @@ def __init__(
221222
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
222223
],
223224
reader_options: dict[str, str] | None = None,
225+
reader_pool_size: int = 1,
224226
):
225227
"""Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244
initialization faster.
243245
reader_options: string of comma-separated options to be passed when
244246
creating a reader.
247+
reader_pool_size: Number of readers to pre-allocate in the pool for each
248+
shard. Default is 1.
245249
"""
246250
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
247251
paths = [paths]
@@ -270,8 +274,20 @@ def __init__(
270274
)
271275
self._read_instructions = _get_read_instructions(paths)
272276
self._paths = [ri.filename for ri in self._read_instructions]
273-
# We open readers lazily when we need to read from them.
274-
self._readers = [None] * len(self._read_instructions)
277+
self._reader_pool_size = max(reader_pool_size, 1)
278+
# We maintain a pool of readers for each shard to ensure thread safety
279+
# while allowing concurrent reads.
280+
self._reader_pools = [
281+
queue.LifoQueue(maxsize=self._reader_pool_size)
282+
for _ in self._read_instructions
283+
]
284+
# We pre-populate the pools sequentially to avoid I/O storms at startup.
285+
for i, ri in enumerate(self._read_instructions):
286+
for _ in range(self._reader_pool_size):
287+
reader = _create_reader(ri.filename, self._reader_options_string)
288+
_check_group_size(ri.filename, reader)
289+
self._reader_pools[i].put(reader)
290+
275291
self._num_records = sum(
276292
map(lambda x: x.num_records, self._read_instructions)
277293
)
@@ -286,10 +302,11 @@ def __enter__(self):
286302

287303
def __exit__(self, exc_type, exc_value, traceback):
288304
logging.debug("__exit__ for ArrayRecordDataSource is called.")
289-
for reader in self._readers:
290-
if reader:
291-
reader.close()
292-
self._readers = [None] * len(self._read_instructions)
305+
for pool in self._reader_pools:
306+
while not pool.empty():
307+
reader = pool.get()
308+
if reader:
309+
reader.close()
293310

294311
def __len__(self) -> int:
295312
return self._num_records
@@ -329,21 +346,33 @@ def _split_keys_per_reader(
329346
positions_and_indices[reader_idx] = [(position, idx)]
330347
return positions_and_indices
331348

332-
def _ensure_reader_exists(self, reader_idx: int) -> None:
333-
"""Threadsafe method to create corresponding reader if it doesn't exist."""
334-
if self._readers[reader_idx] is not None:
335-
return
336-
filename = self._read_instructions[reader_idx].filename
337-
reader = _create_reader(filename, self._reader_options_string)
338-
_check_group_size(filename, reader)
339-
self._readers[reader_idx] = reader
349+
def _get_reader(self, reader_idx: int) -> Any:
350+
"""Gets a reader from the pool or creates a new one."""
351+
try:
352+
return self._reader_pools[reader_idx].get_nowait()
353+
except queue.Empty:
354+
filename = self._read_instructions[reader_idx].filename
355+
reader = _create_reader(filename, self._reader_options_string)
356+
_check_group_size(filename, reader)
357+
return reader
358+
359+
def _release_reader(self, reader_idx: int, reader: Any) -> None:
360+
"""Returns a reader to the pool."""
361+
try:
362+
self._reader_pools[reader_idx].put_nowait(reader)
363+
except queue.Full:
364+
if reader:
365+
reader.close()
340366

341367
def __getitem__(self, record_key: SupportsIndex) -> bytes:
342368
reader_idx, position = self._reader_idx_and_position(record_key)
343-
self._ensure_reader_exists(reader_idx)
344-
if hasattr(self._readers[reader_idx], "read"):
345-
return self._readers[reader_idx].read([position])[0]
346-
return self._readers[reader_idx][position]
369+
reader = self._get_reader(reader_idx)
370+
try:
371+
if hasattr(reader, "read"):
372+
return reader.read([position])[0]
373+
return reader[position]
374+
finally:
375+
self._release_reader(reader_idx, reader)
347376

348377
def __getitems__(
349378
self, record_keys: Sequence[SupportsIndex]
@@ -352,14 +381,16 @@ def read_records(
352381
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
353382
) -> Sequence[Tuple[Any, int]]:
354383
"""Reads records using the given reader keeping track of the indices."""
355-
# Initialize readers lazily when we need to read from them.
356-
self._ensure_reader_exists(reader_idx)
357-
positions, indices = list(zip(*reader_positions_and_indices))
358-
if hasattr(self._readers[reader_idx], "read"):
359-
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
360-
else:
361-
records = [self._readers[reader_idx][p] for p in positions]
362-
return list(zip(records, indices))
384+
reader = self._get_reader(reader_idx)
385+
try:
386+
positions, indices = list(zip(*reader_positions_and_indices))
387+
if hasattr(reader, "read"):
388+
records = reader.read(positions) # pytype: disable=attribute-error
389+
else:
390+
records = [reader[p] for p in positions]
391+
return list(zip(records, indices))
392+
finally:
393+
self._release_reader(reader_idx, reader)
363394

364395
positions_and_indices = self._split_keys_per_reader(record_keys)
365396
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
@@ -390,15 +421,31 @@ def read_records(
390421
def __getstate__(self):
391422
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
392423
state = self.__dict__.copy()
393-
del state["_readers"]
424+
state.pop("_reader_pools", None)
394425
return state
395426

396427
def __setstate__(self, state):
397428
logging.debug("__setstate__ for ArrayRecordDataSource is called.")
398429
self.__dict__.update(state)
399-
# We open readers lazily when we need to read from them. Thus, we don't
400-
# need to re-open the same files as before pickling.
401-
self._readers = [None] * len(self._read_instructions)
430+
# After pickling, we re-initialize the reader pools.
431+
self._reader_pools = [
432+
queue.LifoQueue(maxsize=self._reader_pool_size)
433+
for _ in self._read_instructions
434+
]
435+
# We pre-populate the pools again in the new process.
436+
for i, ri in enumerate(self._read_instructions):
437+
for _ in range(self._reader_pool_size):
438+
reader = _create_reader(ri.filename, self._reader_options_string)
439+
_check_group_size(ri.filename, reader)
440+
self._reader_pools[i].put(reader)
441+
442+
def _peek_readers(self) -> List[Any]:
443+
"""Returns a list of readers (one per shard) or None (for testing)."""
444+
readers = []
445+
for pool in self._reader_pools:
446+
with pool.mutex:
447+
readers.append(pool.queue[-1] if pool.queue else None)
448+
return readers
402449

403450
def __repr__(self) -> str:
404451
"""Storing a hash of paths since paths can be a very long list."""

python/array_record_data_source_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_array_record_data_source_single_path(self):
109109
) as ar:
110110
actual_data = [ar[x] for x in indices_to_read]
111111
self.assertEqual(expected_data, actual_data)
112-
self.assertTrue(all(reader is None for reader in ar._readers))
112+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
113113

114114
def test_array_record_data_source_string_read_instructions(self):
115115
indices_to_read = [0, 1, 2, 3, 4]
@@ -132,7 +132,7 @@ def test_array_record_data_source_reverse_order(self):
132132
]) as ar:
133133
actual_data = [ar[x] for x in indices_to_read]
134134
self.assertEqual(expected_data, actual_data)
135-
self.assertTrue(all(reader is None for reader in ar._readers))
135+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
136136

137137
def test_array_record_data_source_random_order(self):
138138
# some random permutation
@@ -144,7 +144,7 @@ def test_array_record_data_source_random_order(self):
144144
]) as ar:
145145
actual_data = [ar[x] for x in indices_to_read]
146146
self.assertEqual(expected_data, actual_data)
147-
self.assertTrue(all(reader is None for reader in ar._readers))
147+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
148148

149149
def test_array_record_data_source_random_order_batched(self):
150150
# some random permutation
@@ -156,7 +156,7 @@ def test_array_record_data_source_random_order_batched(self):
156156
]) as ar:
157157
actual_data = ar.__getitems__(indices_to_read)
158158
self.assertEqual(expected_data, actual_data)
159-
self.assertTrue(all(reader is None for reader in ar._readers))
159+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
160160

161161
def test_array_record_data_source_file_instructions(self):
162162
file_instruction_one = DummyFileInstruction(
@@ -187,7 +187,7 @@ def test_array_record_data_source_file_instructions(self):
187187
actual_data = [ar[x] for x in indices_to_read]
188188

189189
self.assertEqual(expected_data, actual_data)
190-
self.assertTrue(all(reader is None for reader in ar._readers))
190+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
191191

192192
def test_array_record_source_reader_idx_and_position(self):
193193
file_instructions = [

0 commit comments

Comments
 (0)