Skip to content

Commit fad4e7f

Browse files
ArrayRecord Team authored and copybara-github committed
Implement BoundedReaderPool and thread-local tokenizer cache in ArrayRecordDataSource for deterministic and resource-safe reading.
PiperOrigin-RevId: 902916801
1 parent c909611 commit fad4e7f

2 files changed

Lines changed: 251 additions & 58 deletions

File tree

python/array_record_data_source.py

Lines changed: 185 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3939
import os
4040
import pathlib
4141
import re
42+
import threading
4243
import typing
4344
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
4445

@@ -96,6 +97,8 @@ def _run_in_parallel(
9697
"""
9798
if num_workers < 1:
9899
raise ValueError("num_workers must be >=1 for parallelism.")
100+
if num_workers == 1 or len(list_of_kwargs_to_function) == 1:
101+
return [function(**kwargs) for kwargs in list_of_kwargs_to_function]
99102
thread_futures = []
100103
with futures.ThreadPoolExecutor(num_workers) as executor:
101104
for kwargs in list_of_kwargs_to_function:
@@ -212,6 +215,62 @@ def _check_group_size(
212215
)
213216

214217

218+
class BoundedReaderPool:
  """A pool of readers for a single shard, capped at `max_size` live readers.

  Readers are created lazily by `get()`. `_created_count` tracks readers that
  are alive, whether idle in the pool or checked out by a caller; once it
  reaches `max_size`, `get()` blocks until a reader is returned with `put()`
  or capacity is released.

  Attributes:
    has_read_method: None until the first reader has been created, then
      whether that reader exposes a `read` attribute.
  """

  def __init__(self, filename: str, options_string: str, max_size: int):
    """Initializes an empty pool; no reader is opened here.

    Args:
      filename: Passed to `_create_reader` and `_check_group_size`.
      options_string: Passed to `_create_reader`.
      max_size: Maximum number of live readers for this shard.
    """
    self._filename = filename
    self._options_string = options_string
    self._max_size = max_size
    # Idle readers, used as a stack (most recently returned is reused first).
    self._readers = []
    # Live readers: idle ones plus those currently checked out.
    self._created_count = 0
    self._condition = threading.Condition()
    self._group_size_checked = False
    self.has_read_method = None

  def get(self) -> Any:
    """Gets a reader from the pool, blocking if the cap is reached.

    Returns:
      An idle reader if one is available, otherwise a newly created one.

    Raises:
      Whatever `_create_reader` or `_check_group_size` raises. The reserved
      slot is released first, so a failed creation does not shrink the pool.
    """
    with self._condition:
      # Wait if no idle readers and we reached the cap.
      while not self._readers and self._created_count >= self._max_size:
        self._condition.wait()

      if self._readers:
        return self._readers.pop()

      # Reserve the slot under the lock; the reader itself is created outside
      # the lock so other threads are not blocked while it opens.
      self._created_count += 1

    try:
      reader = _create_reader(self._filename, self._options_string)
      if self.has_read_method is None:
        self.has_read_method = hasattr(reader, "read")
      if not self._group_size_checked:
        _check_group_size(self._filename, reader)
        self._group_size_checked = True
    except BaseException:
      # Give the slot back and wake one waiter, otherwise every failed
      # creation would permanently reduce the pool's capacity.
      with self._condition:
        self._created_count -= 1
        self._condition.notify()
      raise
    return reader

  def put(self, reader: Any) -> None:
    """Returns a reader to the pool and wakes one waiting `get()` call."""
    with self._condition:
      self._readers.append(reader)
      self._condition.notify()

  def close_all(self) -> None:
    """Closes all idle readers and releases the capacity they occupied.

    Readers that are currently checked out are not touched and stay counted.
    """
    with self._condition:
      idle = self._readers
      self._readers = []
      # Closed readers are no longer alive. Without this adjustment a later
      # get() would find no idle reader and a full count, and wait forever.
      # Clamped at zero because put() accepts readers this pool did not count.
      self._created_count = max(0, self._created_count - len(idle))
      self._condition.notify_all()
      for reader in idle:
        if reader:
          reader.close()

  def peek_readers(self) -> List[Any]:
    """Returns a snapshot copy of the readers currently idle in the pool."""
    with self._condition:
      return list(self._readers)
272+
273+
215274
class ArrayRecordDataSource:
216275
"""Datasource for ArrayRecord files."""
217276

@@ -221,6 +280,7 @@ def __init__(
221280
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
222281
],
223282
reader_options: dict[str, str] | None = None,
283+
reader_pool_size: int = 1,
224284
):
225285
"""Creates a new ArrayRecordDataSource object.
226286
@@ -242,6 +302,8 @@ def __init__(
242302
initialization faster.
243303
reader_options: string of comma-separated options to be passed when
244304
creating a reader.
305+
reader_pool_size: Number of readers to pre-allocate in the pool for each
306+
shard. Default is 1.
245307
"""
246308
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
247309
paths = [paths]
@@ -265,31 +327,49 @@ def __init__(
265327
if reader_options is None:
266328
self._reader_options_string = ""
267329
else:
330+
reader_options = dict(reader_options)
331+
if "reader_pool_size" in reader_options:
332+
reader_pool_size = int(reader_options.pop("reader_pool_size"))
268333
self._reader_options_string = ",".join(
269334
[f"{k}:{v}" for k, v in reader_options.items()]
270335
)
271336
self._read_instructions = _get_read_instructions(paths)
272337
self._paths = [ri.filename for ri in self._read_instructions]
338+
self._reader_pool_size = max(reader_pool_size, 1)
339+
# We maintain a pool of readers for each shard to ensure thread safety
340+
# while allowing concurrent reads.
341+
self._shard_pools = [
342+
BoundedReaderPool(
343+
ri.filename, self._reader_options_string, self._reader_pool_size
344+
)
345+
for ri in self._read_instructions
346+
]
273347
# We open readers lazily when we need to read from them.
274-
self._readers = [None] * len(self._read_instructions)
348+
275349
self._num_records = sum(
276350
map(lambda x: x.num_records, self._read_instructions)
277351
)
278352
records_per_instruction = map(
279353
lambda x: x.num_records, self._read_instructions
280354
)
281355
self._prefix_sums = list(itertools.accumulate(records_per_instruction))
356+
self._thread_local = threading.local()
282357

283358
def __enter__(self):
284359
logging.debug("__enter__ for ArrayRecordDataSource is called.")
285360
return self
286361

287362
def __exit__(self, exc_type, exc_value, traceback):
288363
logging.debug("__exit__ for ArrayRecordDataSource is called.")
289-
for reader in self._readers:
290-
if reader:
291-
reader.close()
292-
self._readers = [None] * len(self._read_instructions)
364+
# Clear thread-local cache for the current thread.
365+
if hasattr(self._thread_local, "reader"):
366+
if self._thread_local.reader is not None:
367+
self._thread_local.reader.close()
368+
self._thread_local.reader = None
369+
self._thread_local.pool_idx = -1
370+
371+
for pool in self._shard_pools:
372+
pool.close_all()
293373

294374
def __len__(self) -> int:
295375
return self._num_records
@@ -298,79 +378,111 @@ def __iter__(self) -> Iterator[bytes]:
298378
for index in range(self._num_records):
299379
yield self[index]
300380

301-
def _reader_idx_and_position(
381+
def _pool_idx_and_position(
302382
self, record_key: SupportsIndex
303383
) -> Tuple[int, int]:
304-
"""Computes reader idx and position of given record key."""
384+
"""Computes pool idx and position of given record key."""
305385
record_key = record_key.__index__()
306386
if record_key < 0 or record_key >= self._num_records:
307387
raise ValueError("Record key should be in [0, num_records)")
308-
reader_idx = bisect.bisect_right(self._prefix_sums, record_key)
388+
pool_idx = bisect.bisect_right(self._prefix_sums, record_key)
309389
records_in_previous_instructions = 0
310-
if reader_idx > 0:
311-
records_in_previous_instructions = self._prefix_sums[reader_idx - 1]
390+
if pool_idx > 0:
391+
records_in_previous_instructions = self._prefix_sums[pool_idx - 1]
312392
return (
313-
reader_idx,
393+
pool_idx,
314394
record_key
315395
- records_in_previous_instructions
316-
+ self._read_instructions[reader_idx].start,
396+
+ self._read_instructions[pool_idx].start,
317397
)
318398

319-
def _split_keys_per_reader(
399+
def _split_keys_per_pool(
320400
self, record_keys: Sequence[SupportsIndex]
321401
) -> Mapping[int, Sequence[Tuple[int, int]]]:
322-
"""Splits record_keys among readers."""
402+
"""Splits record_keys among pools."""
323403
positions_and_indices = {}
324404
for idx, record_key in enumerate(record_keys):
325-
reader_idx, position = self._reader_idx_and_position(record_key)
326-
if reader_idx in positions_and_indices:
327-
positions_and_indices[reader_idx].append((position, idx))
405+
pool_idx, position = self._pool_idx_and_position(record_key)
406+
if pool_idx in positions_and_indices:
407+
positions_and_indices[pool_idx].append((position, idx))
328408
else:
329-
positions_and_indices[reader_idx] = [(position, idx)]
409+
positions_and_indices[pool_idx] = [(position, idx)]
330410
return positions_and_indices
331411

332-
def _ensure_reader_exists(self, reader_idx: int) -> None:
333-
"""Threadsafe method to create corresponding reader if it doesn't exist."""
334-
if self._readers[reader_idx] is not None:
335-
return
336-
filename = self._read_instructions[reader_idx].filename
337-
reader = _create_reader(filename, self._reader_options_string)
338-
_check_group_size(filename, reader)
339-
self._readers[reader_idx] = reader
412+
def _get_reader(self, pool_idx: int) -> Any:
413+
"""Gets a reader from the single-slot thread-local cache or the pool."""
414+
if not hasattr(self._thread_local, "pool_idx"):
415+
self._thread_local.pool_idx = -1
416+
self._thread_local.reader = None
417+
418+
if self._thread_local.pool_idx == pool_idx:
419+
return self._thread_local.reader
420+
421+
# Return previous reader to pool if we have one
422+
if self._thread_local.reader is not None:
423+
self._shard_pools[self._thread_local.pool_idx].put(
424+
self._thread_local.reader
425+
)
426+
427+
# Get new reader and cache it
428+
reader = self._shard_pools[pool_idx].get()
429+
self._thread_local.pool_idx = pool_idx
430+
self._thread_local.reader = reader
431+
return reader
432+
433+
def _release_reader(self, pool_idx: int, reader: Any) -> None:
434+
"""No-op: we keep it cached in _get_reader until the next request."""
435+
pass
436+
437+
def _read_record(
438+
self, reader: Any, pool: BoundedReaderPool, position: int
439+
) -> bytes:
440+
"""Helper to read a record using the best available method."""
441+
if pool.has_read_method:
442+
return reader.read([position])[0]
443+
if hasattr(reader, "read_record"):
444+
return reader.read_record(position)
445+
return reader[position]
340446

341447
def __getitem__(self, record_key: SupportsIndex) -> bytes:
342-
reader_idx, position = self._reader_idx_and_position(record_key)
343-
self._ensure_reader_exists(reader_idx)
344-
if hasattr(self._readers[reader_idx], "read"):
345-
return self._readers[reader_idx].read([position])[0]
346-
return self._readers[reader_idx][position]
448+
pool_idx, position = self._pool_idx_and_position(record_key)
449+
reader = self._get_reader(pool_idx)
450+
try:
451+
return self._read_record(reader, self._shard_pools[pool_idx], position)
452+
finally:
453+
self._release_reader(pool_idx, reader)
347454

348455
def __getitems__(
349456
self, record_keys: Sequence[SupportsIndex]
350457
) -> Sequence[bytes]:
458+
351459
def read_records(
352-
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
460+
pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
353461
) -> Sequence[Tuple[Any, int]]:
354462
"""Reads records using the given reader keeping track of the indices."""
355-
# Initialize readers lazily when we need to read from them.
356-
self._ensure_reader_exists(reader_idx)
357-
positions, indices = list(zip(*reader_positions_and_indices))
358-
if hasattr(self._readers[reader_idx], "read"):
359-
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
360-
else:
361-
records = [self._readers[reader_idx][p] for p in positions]
362-
return list(zip(records, indices))
363-
364-
positions_and_indices = self._split_keys_per_reader(record_keys)
463+
reader = self._get_reader(pool_idx)
464+
pool = self._shard_pools[pool_idx]
465+
try:
466+
records = []
467+
for position, _ in reader_positions_and_indices:
468+
records.append(self._read_record(reader, pool, position))
469+
indices = [idx for _, idx in reader_positions_and_indices]
470+
return list(zip(records, indices))
471+
finally:
472+
self._release_reader(pool_idx, reader)
473+
474+
# Group record keys by pool/shard to maximize reader reuse.
475+
positions_and_indices = self._split_keys_per_pool(record_keys)
365476
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
477+
# Parallelize reads across shards using the available threads.
366478
num_workers = min(len(positions_and_indices), num_threads)
367479
list_of_kwargs_to_read_records = []
368480
for (
369-
reader_idx,
481+
pool_idx,
370482
reader_positions_and_indices,
371483
) in positions_and_indices.items():
372484
list_of_kwargs_to_read_records.append({
373-
"reader_idx": reader_idx,
485+
"pool_idx": pool_idx,
374486
"reader_positions_and_indices": reader_positions_and_indices,
375487
})
376488
records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = (
@@ -390,15 +502,40 @@ def read_records(
390502
def __getstate__(self):
391503
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
392504
state = self.__dict__.copy()
393-
del state["_readers"]
505+
state.pop("_shard_pools", None)
506+
state.pop("_reader_pools", None)
507+
state.pop("_thread_local", None)
394508
return state
395509

396510
def __setstate__(self, state):
  """Restores pickled state and rebuilds the pools and thread-local cache.

  Args:
    state: Instance dict as produced by `__getstate__`.
  """
  logging.debug("__setstate__ for ArrayRecordDataSource is called.")
  self.__dict__.update(state)
  # Falls back to a pool size of 1 when the restored state does not carry
  # `_reader_pool_size`.
  pool_size = getattr(self, "_reader_pool_size", 1)
  options_string = self._reader_options_string
  # We open readers lazily when we need to read from them, so only empty
  # pools are created here.
  self._shard_pools = [
      BoundedReaderPool(instruction.filename, options_string, pool_size)
      for instruction in self._read_instructions
  ]
  self._thread_local = threading.local()
523+
524+
def _peek_readers(self) -> List[Any]:
525+
"""Returns a list of readers (one per shard) or None (for testing only)."""
526+
readers = []
527+
for i, pool in enumerate(self._shard_pools):
528+
reader = None
529+
if (
530+
hasattr(self._thread_local, "pool_idx")
531+
and self._thread_local.pool_idx == i
532+
):
533+
reader = self._thread_local.reader
534+
if reader is None:
535+
pooled_readers = pool.peek_readers()
536+
reader = pooled_readers[-1] if pooled_readers else None
537+
readers.append(reader)
538+
return readers
402539

403540
def __repr__(self) -> str:
404541
"""Storing a hash of paths since paths can be a very long list."""

0 commit comments

Comments
 (0)