Skip to content

Commit e279416

Browse files
ArrayRecord Team authored and copybara-github committed
Fix the determinism issue by introducing a pool of readers/tokenizers usable by intra-shard threads.
Each per-shard pool keeps a fixed number of idle readers when underutilized, but additional readers are created on demand when demand spikes (i.e. when the pool is empty). This keeps the memory footprint reasonably predictable while still allowing the data source to respond to large increases in demand, at the cost of the readers' potentially non-trivial initialization time. PiperOrigin-RevId: 897665704
1 parent c909611 commit e279416

2 files changed

Lines changed: 102 additions & 58 deletions

File tree

python/array_record_data_source.py

Lines changed: 92 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838
import itertools
3939
import os
4040
import pathlib
41+
import queue
4142
import re
4243
import typing
4344
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
@@ -221,6 +222,7 @@ def __init__(
221222
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
222223
],
223224
reader_options: dict[str, str] | None = None,
225+
reader_pool_size: int = 1,
224226
):
225227
"""Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244
initialization faster.
243245
reader_options: string of comma-separated options to be passed when
244246
creating a reader.
247+
reader_pool_size: Number of readers to pre-allocate in the pool for each
248+
shard. Default is 1.
245249
"""
246250
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
247251
paths = [paths]
@@ -270,8 +274,15 @@ def __init__(
270274
)
271275
self._read_instructions = _get_read_instructions(paths)
272276
self._paths = [ri.filename for ri in self._read_instructions]
277+
self._reader_pool_size = max(reader_pool_size, 1)
278+
# We maintain a pool of readers for each shard to ensure thread safety
279+
# while allowing concurrent reads.
280+
self._reader_pools = [
281+
queue.LifoQueue(maxsize=self._reader_pool_size)
282+
for _ in self._read_instructions
283+
]
273284
# We open readers lazily when we need to read from them.
274-
self._readers = [None] * len(self._read_instructions)
285+
275286
self._num_records = sum(
276287
map(lambda x: x.num_records, self._read_instructions)
277288
)
@@ -286,10 +297,11 @@ def __enter__(self):
286297

287298
def __exit__(self, exc_type, exc_value, traceback):
  """Closes every idle reader currently held in the per-shard pools.

  Only readers that are sitting in a pool are closed. A reader that another
  thread has taken out of its pool is not visible here and is therefore not
  closed by this method.

  Args:
    exc_type: Exception type raised inside the `with` body, if any (unused).
    exc_value: Exception instance raised inside the `with` body (unused).
    traceback: Traceback of that exception (unused).
  """
  logging.debug("__exit__ for ArrayRecordDataSource is called.")
  for pool in self._reader_pools:
    # Drain with get_nowait() rather than `while not pool.empty(): pool.get()`.
    # empty() gives no guarantee that a following blocking get() will not
    # block: if another thread takes the last reader in between, get() would
    # wait forever.
    while True:
      try:
        reader = pool.get_nowait()
      except queue.Empty:
        break
      # Compare against None explicitly so that a reader object which happens
      # to be falsy (e.g. defines __len__ and is empty) is still closed.
      if reader is not None:
        reader.close()
293305

294306
def __len__(self) -> int:
295307
return self._num_records
@@ -298,79 +310,100 @@ def __iter__(self) -> Iterator[bytes]:
298310
for index in range(self._num_records):
299311
yield self[index]
300312

301-
def _reader_idx_and_position(
313+
def _pool_idx_and_position(
302314
self, record_key: SupportsIndex
303315
) -> Tuple[int, int]:
304-
"""Computes reader idx and position of given record key."""
316+
"""Computes pool idx and position of given record key."""
305317
record_key = record_key.__index__()
306318
if record_key < 0 or record_key >= self._num_records:
307319
raise ValueError("Record key should be in [0, num_records)")
308-
reader_idx = bisect.bisect_right(self._prefix_sums, record_key)
320+
pool_idx = bisect.bisect_right(self._prefix_sums, record_key)
309321
records_in_previous_instructions = 0
310-
if reader_idx > 0:
311-
records_in_previous_instructions = self._prefix_sums[reader_idx - 1]
322+
if pool_idx > 0:
323+
records_in_previous_instructions = self._prefix_sums[pool_idx - 1]
312324
return (
313-
reader_idx,
325+
pool_idx,
314326
record_key
315327
- records_in_previous_instructions
316-
+ self._read_instructions[reader_idx].start,
328+
+ self._read_instructions[pool_idx].start,
317329
)
318330

319-
def _split_keys_per_reader(
331+
def _split_keys_per_pool(
320332
self, record_keys: Sequence[SupportsIndex]
321333
) -> Mapping[int, Sequence[Tuple[int, int]]]:
322-
"""Splits record_keys among readers."""
334+
"""Splits record_keys among pools."""
323335
positions_and_indices = {}
324336
for idx, record_key in enumerate(record_keys):
325-
reader_idx, position = self._reader_idx_and_position(record_key)
326-
if reader_idx in positions_and_indices:
327-
positions_and_indices[reader_idx].append((position, idx))
337+
pool_idx, position = self._pool_idx_and_position(record_key)
338+
if pool_idx in positions_and_indices:
339+
positions_and_indices[pool_idx].append((position, idx))
328340
else:
329-
positions_and_indices[reader_idx] = [(position, idx)]
341+
positions_and_indices[pool_idx] = [(position, idx)]
330342
return positions_and_indices
331343

332-
def _ensure_reader_exists(self, reader_idx: int) -> None:
333-
"""Threadsafe method to create corresponding reader if it doesn't exist."""
334-
if self._readers[reader_idx] is not None:
335-
return
336-
filename = self._read_instructions[reader_idx].filename
337-
reader = _create_reader(filename, self._reader_options_string)
338-
_check_group_size(filename, reader)
339-
self._readers[reader_idx] = reader
344+
def _get_reader(self, pool_idx: int) -> Any:
  """Takes an idle reader from the shard's pool, or creates a new one.

  Args:
    pool_idx: Index of the shard; selects both the pool in
      `self._reader_pools` and the entry in `self._read_instructions`.

  Returns:
    A reader for the shard's file. The caller is expected to hand it back
    via `_release_reader` when done.
  """
  try:
    return self._reader_pools[pool_idx].get_nowait()
  except queue.Empty:
    # The pool is empty (all readers are busy): fall through and create a
    # new reader on demand instead of blocking the calling thread. Creation
    # happens outside this handler so that a failure while opening the file
    # is not reported as chained to the irrelevant queue.Empty.
    pass

  filename = self._read_instructions[pool_idx].filename
  reader = _create_reader(filename, self._reader_options_string)
  try:
    _check_group_size(filename, reader)
  except BaseException:
    # The reader is not owned by anyone yet; close it so that a failed
    # validation does not leak the open reader.
    reader.close()
    raise
  return reader
355+
356+
def _release_reader(self, pool_idx: int, reader: Any) -> None:
  """Returns a reader to its shard's pool, closing it if the pool is full.

  Args:
    pool_idx: Index of the shard pool in `self._reader_pools`.
    reader: The reader previously obtained for this shard. `None` is
      ignored.
  """
  if reader is None:
    # Never store None in the pool: it would later be handed out as if it
    # were a usable reader.
    return
  try:
    self._reader_pools[pool_idx].put_nowait(reader)
  except queue.Full:
    # The pool is already full of idle readers; close this one to keep the
    # memory footprint bounded. The reader is closed unconditionally (not
    # `if reader:`) so that an object that evaluates as falsy is not leaked.
    reader.close()
340365

341366
def __getitem__(self, record_key: SupportsIndex) -> bytes:
342-
reader_idx, position = self._reader_idx_and_position(record_key)
343-
self._ensure_reader_exists(reader_idx)
344-
if hasattr(self._readers[reader_idx], "read"):
345-
return self._readers[reader_idx].read([position])[0]
346-
return self._readers[reader_idx][position]
367+
pool_idx, position = self._pool_idx_and_position(record_key)
368+
reader = self._get_reader(pool_idx)
369+
try:
370+
if hasattr(reader, "read"):
371+
return reader.read([position])[0]
372+
return reader[position]
373+
finally:
374+
self._release_reader(pool_idx, reader)
347375

348376
def __getitems__(
349377
self, record_keys: Sequence[SupportsIndex]
350378
) -> Sequence[bytes]:
379+
351380
def read_records(
352-
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
381+
pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
353382
) -> Sequence[Tuple[Any, int]]:
354383
"""Reads records using the given reader keeping track of the indices."""
355-
# Initialize readers lazily when we need to read from them.
356-
self._ensure_reader_exists(reader_idx)
357-
positions, indices = list(zip(*reader_positions_and_indices))
358-
if hasattr(self._readers[reader_idx], "read"):
359-
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
360-
else:
361-
records = [self._readers[reader_idx][p] for p in positions]
362-
return list(zip(records, indices))
363-
364-
positions_and_indices = self._split_keys_per_reader(record_keys)
384+
reader = self._get_reader(pool_idx)
385+
try:
386+
positions, indices = list(zip(*reader_positions_and_indices))
387+
if hasattr(reader, "read"):
388+
records = reader.read(positions) # pytype: disable=attribute-error
389+
else:
390+
records = [reader[p] for p in positions]
391+
return list(zip(records, indices))
392+
finally:
393+
self._release_reader(pool_idx, reader)
394+
395+
# Group record keys by pool/shard to maximize reader reuse.
396+
positions_and_indices = self._split_keys_per_pool(record_keys)
365397
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
398+
# Parallelize reads across shards using the available threads.
366399
num_workers = min(len(positions_and_indices), num_threads)
367400
list_of_kwargs_to_read_records = []
368401
for (
369-
reader_idx,
402+
pool_idx,
370403
reader_positions_and_indices,
371404
) in positions_and_indices.items():
372405
list_of_kwargs_to_read_records.append({
373-
"reader_idx": reader_idx,
406+
"pool_idx": pool_idx,
374407
"reader_positions_and_indices": reader_positions_and_indices,
375408
})
376409
records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = (
@@ -390,15 +423,26 @@ def read_records(
390423
def __getstate__(self):
391424
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
392425
state = self.__dict__.copy()
393-
del state["_readers"]
426+
state.pop("_reader_pools", None)
394427
return state
395428

396429
def __setstate__(self, state):
  """Restores the pickled attributes and rebuilds empty reader pools.

  Reader pools are not part of the pickled state, so they are recreated
  here; readers themselves are opened lazily on first use.

  Args:
    state: Attribute dictionary, as produced by `__getstate__`.
  """
  logging.debug("__setstate__ for ArrayRecordDataSource is called.")
  self.__dict__.update(state)
  # A state pickled before reader pools were introduced has no
  # `_reader_pool_size`; fall back to the constructor default of 1 instead of
  # failing with AttributeError.
  self._reader_pool_size = getattr(self, "_reader_pool_size", 1)
  # After unpickling, re-initialize one empty pool per read instruction.
  self._reader_pools = [
      queue.LifoQueue(maxsize=self._reader_pool_size)
      for _ in self._read_instructions
  ]
438+
439+
def _peek_readers(self) -> List[Any]:
440+
"""Returns a list of readers (one per shard) or None (for testing only)."""
441+
readers = []
442+
for pool in self._reader_pools:
443+
with pool.mutex:
444+
readers.append(pool.queue[-1] if pool.queue else None)
445+
return readers
402446

403447
def __repr__(self) -> str:
404448
"""Storing a hash of paths since paths can be a very long list."""

python/array_record_data_source_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_array_record_data_source_single_path(self):
109109
) as ar:
110110
actual_data = [ar[x] for x in indices_to_read]
111111
self.assertEqual(expected_data, actual_data)
112-
self.assertTrue(all(reader is None for reader in ar._readers))
112+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
113113

114114
def test_array_record_data_source_string_read_instructions(self):
115115
indices_to_read = [0, 1, 2, 3, 4]
@@ -132,7 +132,7 @@ def test_array_record_data_source_reverse_order(self):
132132
]) as ar:
133133
actual_data = [ar[x] for x in indices_to_read]
134134
self.assertEqual(expected_data, actual_data)
135-
self.assertTrue(all(reader is None for reader in ar._readers))
135+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
136136

137137
def test_array_record_data_source_random_order(self):
138138
# some random permutation
@@ -144,7 +144,7 @@ def test_array_record_data_source_random_order(self):
144144
]) as ar:
145145
actual_data = [ar[x] for x in indices_to_read]
146146
self.assertEqual(expected_data, actual_data)
147-
self.assertTrue(all(reader is None for reader in ar._readers))
147+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
148148

149149
def test_array_record_data_source_random_order_batched(self):
150150
# some random permutation
@@ -156,7 +156,7 @@ def test_array_record_data_source_random_order_batched(self):
156156
]) as ar:
157157
actual_data = ar.__getitems__(indices_to_read)
158158
self.assertEqual(expected_data, actual_data)
159-
self.assertTrue(all(reader is None for reader in ar._readers))
159+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
160160

161161
def test_array_record_data_source_file_instructions(self):
162162
file_instruction_one = DummyFileInstruction(
@@ -187,9 +187,9 @@ def test_array_record_data_source_file_instructions(self):
187187
actual_data = [ar[x] for x in indices_to_read]
188188

189189
self.assertEqual(expected_data, actual_data)
190-
self.assertTrue(all(reader is None for reader in ar._readers))
190+
self.assertTrue(all(reader is None for reader in ar._peek_readers()))
191191

192-
def test_array_record_source_reader_idx_and_position(self):
192+
def test_array_record_source_pool_idx_and_position(self):
193193
file_instructions = [
194194
# 2 records
195195
DummyFileInstruction(
@@ -221,19 +221,19 @@ def test_array_record_source_reader_idx_and_position(self):
221221
for record_key in range(len(ar)):
222222
self.assertEqual(
223223
expected_indices_and_positions[record_key],
224-
ar._reader_idx_and_position(record_key),
224+
ar._pool_idx_and_position(record_key),
225225
)
226226

227-
def test_array_record_source_reader_idx_and_position_negative_idx(self):
227+
def test_array_record_source_pool_idx_and_position_negative_idx(self):
228228
with array_record_data_source.ArrayRecordDataSource([
229229
self.testdata_dir / "digits.array_record-00000-of-00002",
230230
self.testdata_dir / "digits.array_record-00001-of-00002",
231231
]) as ar:
232232
with self.assertRaises(ValueError):
233-
ar._reader_idx_and_position(-1)
233+
ar._pool_idx_and_position(-1)
234234

235235
with self.assertRaises(ValueError):
236-
ar._reader_idx_and_position(len(ar))
236+
ar._pool_idx_and_position(len(ar))
237237

238238
def test_array_record_source_empty_sequence(self):
239239
with self.assertRaises(ValueError):

0 commit comments

Comments
 (0)