@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838import itertools
3939import os
4040import pathlib
41+ import queue
4142import re
4243import typing
4344from typing import Any , Callable , Iterator , List , Mapping , Protocol , Sequence , SupportsIndex , Tuple , TypeVar , Union
@@ -221,6 +222,7 @@ def __init__(
221222 PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
222223 ],
223224 reader_options : dict [str , str ] | None = None ,
225+ reader_pool_size : int = 1 ,
224226 ):
225227 """Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244 initialization faster.
243245 reader_options: string of comma-separated options to be passed when
244246 creating a reader.
247+ reader_pool_size: Maximum number of idle readers kept per shard. Readers
248+ are opened lazily. Values below 1 are treated as 1. Default is 1.
245249 """
246250 if isinstance (paths , (str , pathlib .Path , FileInstruction )):
247251 paths = [paths ]
@@ -270,8 +274,15 @@ def __init__(
270274 )
271275 self ._read_instructions = _get_read_instructions (paths )
272276 self ._paths = [ri .filename for ri in self ._read_instructions ]
277+ self ._reader_pool_size = max (reader_pool_size , 1 )
278+ # We maintain a pool of readers for each shard to ensure thread safety
279+ # while allowing concurrent reads.
280+ self ._reader_pools = [
281+ queue .LifoQueue (maxsize = self ._reader_pool_size )
282+ for _ in self ._read_instructions
283+ ]
273284 # We open readers lazily when we need to read from them.
274- self . _readers = [ None ] * len ( self . _read_instructions )
285+
275286 self ._num_records = sum (
276287 map (lambda x : x .num_records , self ._read_instructions )
277288 )
@@ -286,10 +297,11 @@ def __enter__(self):
286297
287298 def __exit__ (self , exc_type , exc_value , traceback ):
288299 logging .debug ("__exit__ for ArrayRecordDataSource is called." )
289- for reader in self ._readers :
290- if reader :
291- reader .close ()
292- self ._readers = [None ] * len (self ._read_instructions )
300+ for pool in self ._reader_pools :
301+ while not pool .empty ():
302+ reader = pool .get ()
303+ if reader :
304+ reader .close ()
293305
294306 def __len__ (self ) -> int :
295307 return self ._num_records
@@ -329,21 +341,33 @@ def _split_keys_per_reader(
329341 positions_and_indices [reader_idx ] = [(position , idx )]
330342 return positions_and_indices
331343
332- def _ensure_reader_exists (self , reader_idx : int ) -> None :
333- """Threadsafe method to create corresponding reader if it doesn't exist."""
334- if self ._readers [reader_idx ] is not None :
335- return
336- filename = self ._read_instructions [reader_idx ].filename
337- reader = _create_reader (filename , self ._reader_options_string )
338- _check_group_size (filename , reader )
339- self ._readers [reader_idx ] = reader
344+ def _get_reader (self , reader_idx : int ) -> Any :
345+ """Gets a reader from the pool or creates a new one."""
346+ try :
347+ return self ._reader_pools [reader_idx ].get_nowait ()
348+ except queue .Empty :
349+ filename = self ._read_instructions [reader_idx ].filename
350+ reader = _create_reader (filename , self ._reader_options_string )
351+ _check_group_size (filename , reader )
352+ return reader
353+
354+ def _release_reader (self , reader_idx : int , reader : Any ) -> None :
355+ """Returns a reader to the pool, closing it if the pool is full."""
356+ try :
357+ self ._reader_pools [reader_idx ].put_nowait (reader )
358+ except queue .Full :
359+ if reader :
360+ reader .close ()
340361
341362 def __getitem__ (self , record_key : SupportsIndex ) -> bytes :
342363 reader_idx , position = self ._reader_idx_and_position (record_key )
343- self ._ensure_reader_exists (reader_idx )
344- if hasattr (self ._readers [reader_idx ], "read" ):
345- return self ._readers [reader_idx ].read ([position ])[0 ]
346- return self ._readers [reader_idx ][position ]
364+ reader = self ._get_reader (reader_idx )
365+ try :
366+ if hasattr (reader , "read" ):
367+ return reader .read ([position ])[0 ]
368+ return reader [position ]
369+ finally :
370+ self ._release_reader (reader_idx , reader )
347371
348372 def __getitems__ (
349373 self , record_keys : Sequence [SupportsIndex ]
@@ -352,14 +376,16 @@ def read_records(
352376 reader_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
353377 ) -> Sequence [Tuple [Any , int ]]:
354378 """Reads records using the given reader keeping track of the indices."""
355- # Initialize readers lazily when we need to read from them.
356- self ._ensure_reader_exists (reader_idx )
357- positions , indices = list (zip (* reader_positions_and_indices ))
358- if hasattr (self ._readers [reader_idx ], "read" ):
359- records = self ._readers [reader_idx ].read (positions ) # pytype: disable=attribute-error
360- else :
361- records = [self ._readers [reader_idx ][p ] for p in positions ]
362- return list (zip (records , indices ))
379+ reader = self ._get_reader (reader_idx )
380+ try :
381+ positions , indices = list (zip (* reader_positions_and_indices ))
382+ if hasattr (reader , "read" ):
383+ records = reader .read (positions ) # pytype: disable=attribute-error
384+ else :
385+ records = [reader [p ] for p in positions ]
386+ return list (zip (records , indices ))
387+ finally :
388+ self ._release_reader (reader_idx , reader )
363389
364390 positions_and_indices = self ._split_keys_per_reader (record_keys )
365391 num_threads = _get_flag_value (_GRAIN_NUM_THREADS_FETCHING_RECORDS )
@@ -390,15 +416,26 @@ def read_records(
390416 def __getstate__ (self ):
391417 logging .debug ("__getstate__ for ArrayRecordDataSource is called." )
392418 state = self .__dict__ .copy ()
393- del state [ "_readers" ]
419+ state . pop ( "_reader_pools" , None )
394420 return state
395421
396422 def __setstate__ (self , state ):
397423 logging .debug ("__setstate__ for ArrayRecordDataSource is called." )
398424 self .__dict__ .update (state )
399- # We open readers lazily when we need to read from them. Thus, we don't
400- # need to re-open the same files as before pickling.
401- self ._readers = [None ] * len (self ._read_instructions )
425+ # After unpickling, we re-initialize the reader pools.
426+ self ._reader_pools = [
427+ queue .LifoQueue (maxsize = self ._reader_pool_size )
428+ for _ in self ._read_instructions
429+ ]
430+ # We open readers lazily when we need to read from them.
431+
432+ def _peek_readers (self ) -> List [Any ]:
433+ """Returns each shard's top pooled reader or None (for testing)."""
434+ readers = []
435+ for pool in self ._reader_pools :
436+ with pool .mutex :
437+ readers .append (pool .queue [- 1 ] if pool .queue else None )
438+ return readers
402439
403440 def __repr__ (self ) -> str :
404441 """Storing a hash of paths since paths can be a very long list."""
0 commit comments