@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838import itertools
3939import os
4040import pathlib
41+ import queue
4142import re
4243import typing
4344from typing import Any , Callable , Iterator , List , Mapping , Protocol , Sequence , SupportsIndex , Tuple , TypeVar , Union
@@ -221,6 +222,7 @@ def __init__(
221222 PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
222223 ],
223224 reader_options : dict [str , str ] | None = None ,
225+ reader_pool_size : int = 1 ,
224226 ):
225227 """Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244 initialization faster.
243245 reader_options: string of comma-separated options to be passed when
244246 creating a reader.
247+ reader_pool_size: Maximum number of idle readers kept per shard. Readers
248+ are opened lazily, not pre-allocated; values below 1 become 1. Default is 1.
245249 """
246250 if isinstance (paths , (str , pathlib .Path , FileInstruction )):
247251 paths = [paths ]
@@ -270,8 +274,15 @@ def __init__(
270274 )
271275 self ._read_instructions = _get_read_instructions (paths )
272276 self ._paths = [ri .filename for ri in self ._read_instructions ]
277+ self ._reader_pool_size = max (reader_pool_size , 1 )
278+ # We maintain a pool of readers for each shard to ensure thread safety
279+ # while allowing concurrent reads.
280+ self ._reader_pools = [
281+ queue .LifoQueue (maxsize = self ._reader_pool_size )
282+ for _ in self ._read_instructions
283+ ]
273284 # We open readers lazily when we need to read from them.
274- self . _readers = [ None ] * len ( self . _read_instructions )
285+
275286 self ._num_records = sum (
276287 map (lambda x : x .num_records , self ._read_instructions )
277288 )
@@ -286,10 +297,11 @@ def __enter__(self):
286297
def __exit__(self, exc_type, exc_value, traceback):
  """Closes every idle reader held in the per-shard reader pools.

  Only readers sitting in a pool at the time of the call are closed; a
  reader that is currently checked out is not in any pool and is therefore
  not seen here.

  Args:
    exc_type: Type of the exception raised in the `with` body, if any. Unused.
    exc_value: The exception instance raised in the `with` body. Unused.
    traceback: Traceback of that exception. Unused.
  """
  logging.debug("__exit__ for ArrayRecordDataSource is called.")
  for pool in self._reader_pools:
    while True:
      # `empty()` followed by a blocking `get()` can hang forever if another
      # thread takes the last reader in between; a non-blocking get cannot.
      try:
        reader = pool.get_nowait()
      except queue.Empty:
        break
      # Compare against None explicitly: a reader that defines `__len__` or
      # `__bool__` may be falsy and would otherwise never be closed.
      if reader is not None:
        reader.close()
293305
def __len__(self) -> int:
  """Returns the total number of records across all read instructions."""
  return self._num_records
@@ -298,79 +310,100 @@ def __iter__(self) -> Iterator[bytes]:
298310 for index in range (self ._num_records ):
299311 yield self [index ]
300312
def _pool_idx_and_position(
    self, record_key: SupportsIndex
) -> Tuple[int, int]:
  """Maps a global record key to its reader pool and in-file position.

  Args:
    record_key: Global record index; any object implementing `__index__`.

  Returns:
    A `(pool_idx, position)` tuple. `pool_idx` indexes both
    `self._read_instructions` and the matching reader pool; `position` is the
    record's position within that file, offset by the instruction's `start`.

  Raises:
    ValueError: If `record_key` is outside `[0, num_records)`.
  """
  record_key = record_key.__index__()
  if not 0 <= record_key < self._num_records:
    raise ValueError("Record key should be in [0, num_records)")
  # `_prefix_sums[i]` is treated as the cumulative record count through
  # instruction i, so bisect_right yields the first instruction whose
  # cumulative count exceeds the key.
  pool_idx = bisect.bisect_right(self._prefix_sums, record_key)
  records_in_previous_instructions = (
      self._prefix_sums[pool_idx - 1] if pool_idx > 0 else 0
  )
  position = (
      record_key
      - records_in_previous_instructions
      + self._read_instructions[pool_idx].start
  )
  return pool_idx, position
318330
def _split_keys_per_pool(
    self, record_keys: Sequence[SupportsIndex]
) -> Mapping[int, Sequence[Tuple[int, int]]]:
  """Groups record keys by the reader pool (shard) that serves them.

  Args:
    record_keys: Global record keys, in the order the caller supplied them.

  Returns:
    A dict mapping each pool index to a list of `(position, idx)` pairs,
    where `position` is the record's position within that shard and `idx` is
    the key's index in `record_keys`. Pairs keep the order of `record_keys`.

  Raises:
    ValueError: Propagated from `_pool_idx_and_position` for an out-of-range
      key.
  """
  positions_and_indices = {}
  for idx, record_key in enumerate(record_keys):
    pool_idx, position = self._pool_idx_and_position(record_key)
    # setdefault creates the per-pool list on first sight of a pool.
    positions_and_indices.setdefault(pool_idx, []).append((position, idx))
  return positions_and_indices
331343
332- def _ensure_reader_exists (self , reader_idx : int ) -> None :
333- """Threadsafe method to create corresponding reader if it doesn't exist."""
334- if self ._readers [reader_idx ] is not None :
335- return
336- filename = self ._read_instructions [reader_idx ].filename
337- reader = _create_reader (filename , self ._reader_options_string )
338- _check_group_size (filename , reader )
339- self ._readers [reader_idx ] = reader
def _get_reader(self, pool_idx: int) -> Any:
  """Checks out a reader for the given shard.

  An idle reader is taken from the shard's pool when one is available.
  Otherwise a new reader is opened on demand so the calling thread never
  blocks waiting for another thread to finish.

  Args:
    pool_idx: Index of the shard, i.e. index into `self._read_instructions`
      and `self._reader_pools`.

  Returns:
    A reader for the shard. The caller is expected to hand it back through
    `_release_reader`.
  """
  try:
    return self._reader_pools[pool_idx].get_nowait()
  except queue.Empty:
    # Fall through and open a new reader outside the handler, so a failure
    # while opening is not reported as chained to `queue.Empty`.
    pass
  filename = self._read_instructions[pool_idx].filename
  reader = _create_reader(filename, self._reader_options_string)
  try:
    _check_group_size(filename, reader)
  except Exception:
    # The reader never reaches a pool, so nothing else would close it.
    reader.close()
    raise
  return reader
355+
def _release_reader(self, pool_idx: int, reader: Any) -> None:
  """Hands a reader back to its shard's pool.

  Args:
    pool_idx: Index of the shard the reader belongs to.
    reader: The reader to return. `None` is ignored so the pool never holds
      an unusable entry.
  """
  if reader is None:
    return
  try:
    self._reader_pools[pool_idx].put_nowait(reader)
  except queue.Full:
    # The pool already holds its maximum number of idle readers; close the
    # surplus one to keep the number of open readers bounded. No truthiness
    # check here: a reader defining `__len__` may be falsy yet still open.
    reader.close()
340365
def __getitem__(self, record_key: SupportsIndex) -> bytes:
  """Reads a single record by its global key.

  Args:
    record_key: Global record index; any object implementing `__index__`.

  Returns:
    The record stored at `record_key`.

  Raises:
    ValueError: Propagated from `_pool_idx_and_position` when the key is out
      of range.
  """
  pool_idx, position = self._pool_idx_and_position(record_key)
  reader = self._get_reader(pool_idx)
  try:
    # Readers exposing a batch `read` method are queried with a one-element
    # list; any other reader is indexed directly.
    if hasattr(reader, "read"):
      return reader.read([position])[0]
    return reader[position]
  finally:
    # Always hand the reader back, even if the read raised.
    self._release_reader(pool_idx, reader)
347375
348376 def __getitems__ (
349377 self , record_keys : Sequence [SupportsIndex ]
350378 ) -> Sequence [bytes ]:
379+
351380 def read_records (
352- reader_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
381+ pool_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
353382 ) -> Sequence [Tuple [Any , int ]]:
354383 """Reads records using the given reader keeping track of the indices."""
355- # Initialize readers lazily when we need to read from them.
356- self ._ensure_reader_exists (reader_idx )
357- positions , indices = list (zip (* reader_positions_and_indices ))
358- if hasattr (self ._readers [reader_idx ], "read" ):
359- records = self ._readers [reader_idx ].read (positions ) # pytype: disable=attribute-error
360- else :
361- records = [self ._readers [reader_idx ][p ] for p in positions ]
362- return list (zip (records , indices ))
363-
364- positions_and_indices = self ._split_keys_per_reader (record_keys )
384+ reader = self ._get_reader (pool_idx )
385+ try :
386+ positions , indices = list (zip (* reader_positions_and_indices ))
387+ if hasattr (reader , "read" ):
388+ records = reader .read (positions ) # pytype: disable=attribute-error
389+ else :
390+ records = [reader [p ] for p in positions ]
391+ return list (zip (records , indices ))
392+ finally :
393+ self ._release_reader (pool_idx , reader )
394+
395+ # Group record keys by pool/shard to maximize reader reuse.
396+ positions_and_indices = self ._split_keys_per_pool (record_keys )
365397 num_threads = _get_flag_value (_GRAIN_NUM_THREADS_FETCHING_RECORDS )
398+ # Parallelize reads across shards using the available threads.
366399 num_workers = min (len (positions_and_indices ), num_threads )
367400 list_of_kwargs_to_read_records = []
368401 for (
369- reader_idx ,
402+ pool_idx ,
370403 reader_positions_and_indices ,
371404 ) in positions_and_indices .items ():
372405 list_of_kwargs_to_read_records .append ({
373- "reader_idx " : reader_idx ,
406+ "pool_idx " : pool_idx ,
374407 "reader_positions_and_indices" : reader_positions_and_indices ,
375408 })
376409 records_with_indices : Sequence [Sequence [Tuple [Any , int ]]] = (
@@ -390,15 +423,26 @@ def read_records(
def __getstate__(self):
  """Returns the picklable state, leaving out the reader pools.

  Returns:
    A shallow copy of the instance `__dict__` without the `_reader_pools`
    entry. The instance itself is not modified.
  """
  logging.debug("__getstate__ for ArrayRecordDataSource is called.")
  state = self.__dict__.copy()
  # `pop` with a default tolerates an instance that has no pools yet.
  state.pop("_reader_pools", None)
  return state
395428
def __setstate__(self, state):
  """Restores pickled state and recreates empty reader pools.

  Args:
    state: Mapping of attribute names to values, as produced by
      `__getstate__`.
  """
  logging.debug("__setstate__ for ArrayRecordDataSource is called.")
  self.__dict__.update(state)
  # State pickled before reader pools were introduced carries no
  # `_reader_pool_size`; fall back to a single reader per shard.
  self._reader_pool_size = max(getattr(self, "_reader_pool_size", 1), 1)
  # Pools start empty: readers are opened lazily when a read needs one.
  self._reader_pools = [
      queue.LifoQueue(maxsize=self._reader_pool_size)
      for _ in self._read_instructions
  ]
438+
def _peek_readers(self) -> List[Any]:
  """Returns the top idle reader of every pool without removing it.

  Intended for tests only.

  Returns:
    A list with one entry per shard: the reader that the next `get` on that
    pool would return, or `None` if the pool currently holds no idle reader.
  """
  readers = []
  for pool in self._reader_pools:
    # Hold the queue's own lock so the underlying list cannot change between
    # the emptiness check and the read.
    with pool.mutex:
      readers.append(pool.queue[-1] if pool.queue else None)
  return readers
402446
403447 def __repr__ (self ) -> str :
404448 """Storing a hash of paths since paths can be a very long list."""
0 commit comments