@@ -38,6 +38,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3838import itertools
3939import os
4040import pathlib
41+ import queue
4142import re
4243import typing
4344from typing import Any , Callable , Iterator , List , Mapping , Protocol , Sequence , SupportsIndex , Tuple , TypeVar , Union
@@ -221,6 +222,7 @@ def __init__(
221222 PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
222223 ],
223224 reader_options : dict [str , str ] | None = None ,
225+ reader_pool_size : int = 1 ,
224226 ):
225227 """Creates a new ArrayRecordDataSource object.
226228
@@ -242,6 +244,8 @@ def __init__(
242244 initialization faster.
243245 reader_options: string of comma-separated options to be passed when
244246 creating a reader.
247+ reader_pool_size: Number of readers to pre-allocate in the pool for each
248+ shard. Default is 1.
245249 """
246250 if isinstance (paths , (str , pathlib .Path , FileInstruction )):
247251 paths = [paths ]
@@ -270,8 +274,20 @@ def __init__(
270274 )
271275 self ._read_instructions = _get_read_instructions (paths )
272276 self ._paths = [ri .filename for ri in self ._read_instructions ]
273- # We open readers lazily when we need to read from them.
274- self ._readers = [None ] * len (self ._read_instructions )
277+ self ._reader_pool_size = max (reader_pool_size , 1 )
278+ # We maintain a pool of readers for each shard to ensure thread safety
279+ # while allowing concurrent reads.
280+ self ._reader_pools = [
281+ queue .LifoQueue (maxsize = self ._reader_pool_size )
282+ for _ in self ._read_instructions
283+ ]
284+ # We pre-populate the pools sequentially to avoid I/O storms at startup.
285+ for i , ri in enumerate (self ._read_instructions ):
286+ for _ in range (self ._reader_pool_size ):
287+ reader = _create_reader (ri .filename , self ._reader_options_string )
288+ _check_group_size (ri .filename , reader )
289+ self ._reader_pools [i ].put (reader )
290+
275291 self ._num_records = sum (
276292 map (lambda x : x .num_records , self ._read_instructions )
277293 )
@@ -286,10 +302,11 @@ def __enter__(self):
286302
def __exit__(self, exc_type, exc_value, traceback):
  """Closes every reader that is currently idle in the per-shard pools.

  Each pool is drained with non-blocking gets. Checking `empty()` and then
  calling a blocking `get()` is racy: another thread may take the last reader
  between the two calls, which would leave this method blocked forever.

  Args:
    exc_type: Exception type raised inside the `with` body, if any (unused).
    exc_value: Exception instance raised inside the `with` body (unused).
    traceback: Traceback of that exception (unused).
  """
  logging.debug("__exit__ for ArrayRecordDataSource is called.")
  for pool in self._reader_pools:
    while True:
      try:
        reader = pool.get_nowait()
      except queue.Empty:
        # Nothing idle is left in this pool; move on to the next shard.
        break
      # Compare against None explicitly: a reader object may be falsy (for
      # example if it defines `__len__` and holds zero records) and must
      # still be closed.
      if reader is not None:
        reader.close()
293310
294311 def __len__ (self ) -> int :
295312 return self ._num_records
@@ -329,21 +346,33 @@ def _split_keys_per_reader(
329346 positions_and_indices [reader_idx ] = [(position , idx )]
330347 return positions_and_indices
331348
332- def _ensure_reader_exists (self , reader_idx : int ) -> None :
333- """Threadsafe method to create corresponding reader if it doesn't exist."""
334- if self ._readers [reader_idx ] is not None :
335- return
336- filename = self ._read_instructions [reader_idx ].filename
337- reader = _create_reader (filename , self ._reader_options_string )
338- _check_group_size (filename , reader )
339- self ._readers [reader_idx ] = reader
def _get_reader(self, reader_idx: int) -> Any:
  """Borrows a reader for a shard, opening a new one if none is idle.

  Args:
    reader_idx: Index of the shard (read instruction) to read from.

  Returns:
    A reader taken from the shard's pool, or a newly created reader when the
    pool has nothing available.
  """
  shard_pool = self._reader_pools[reader_idx]
  try:
    return shard_pool.get_nowait()
  except queue.Empty:
    # Every pooled reader is in use; fall through and open an extra one.
    pass
  shard_file = self._read_instructions[reader_idx].filename
  fresh_reader = _create_reader(shard_file, self._reader_options_string)
  _check_group_size(shard_file, fresh_reader)
  return fresh_reader
358+
def _release_reader(self, reader_idx: int, reader: Any) -> None:
  """Hands a reader back to its shard's pool, closing it if the pool is full.

  Args:
    reader_idx: Index of the shard (read instruction) the reader belongs to.
    reader: The reader previously obtained from `_get_reader`.
  """
  try:
    self._reader_pools[reader_idx].put_nowait(reader)
  except queue.Full:
    # The pool already holds its maximum number of readers, so this surplus
    # one is closed instead of kept. Test against None rather than
    # truthiness: a reader that evaluates as falsy (for example one that
    # defines `__len__` and has zero records) would otherwise never be closed.
    if reader is not None:
      reader.close()
340366
def __getitem__(self, record_key: SupportsIndex) -> bytes:
  """Reads the single record identified by `record_key`.

  A reader for the owning shard is borrowed for the duration of the read and
  is always released again, including when the read raises.

  Args:
    record_key: Global index of the record to fetch.

  Returns:
    The record read from the shard that contains `record_key`.
  """
  shard_idx, offset = self._reader_idx_and_position(record_key)
  borrowed = self._get_reader(shard_idx)
  try:
    # Readers exposing `read` take a list of positions; others are indexed.
    if hasattr(borrowed, "read"):
      record = borrowed.read([offset])[0]
    else:
      record = borrowed[offset]
  finally:
    self._release_reader(shard_idx, borrowed)
  return record
347376
348377 def __getitems__ (
349378 self , record_keys : Sequence [SupportsIndex ]
@@ -352,14 +381,16 @@ def read_records(
352381 reader_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
353382 ) -> Sequence [Tuple [Any , int ]]:
354383 """Reads records using the given reader keeping track of the indices."""
355- # Initialize readers lazily when we need to read from them.
356- self ._ensure_reader_exists (reader_idx )
357- positions , indices = list (zip (* reader_positions_and_indices ))
358- if hasattr (self ._readers [reader_idx ], "read" ):
359- records = self ._readers [reader_idx ].read (positions ) # pytype: disable=attribute-error
360- else :
361- records = [self ._readers [reader_idx ][p ] for p in positions ]
362- return list (zip (records , indices ))
384+ reader = self ._get_reader (reader_idx )
385+ try :
386+ positions , indices = list (zip (* reader_positions_and_indices ))
387+ if hasattr (reader , "read" ):
388+ records = reader .read (positions ) # pytype: disable=attribute-error
389+ else :
390+ records = [reader [p ] for p in positions ]
391+ return list (zip (records , indices ))
392+ finally :
393+ self ._release_reader (reader_idx , reader )
363394
364395 positions_and_indices = self ._split_keys_per_reader (record_keys )
365396 num_threads = _get_flag_value (_GRAIN_NUM_THREADS_FETCHING_RECORDS )
@@ -390,15 +421,31 @@ def read_records(
def __getstate__(self):
  """Returns the instance state for pickling, without the reader pools.

  The result is a shallow copy of the instance attributes with the
  `_reader_pools` entry left out; the instance itself is not modified.
  """
  logging.debug("__getstate__ for ArrayRecordDataSource is called.")
  return {
      attr_name: attr_value
      for attr_name, attr_value in self.__dict__.items()
      if attr_name != "_reader_pools"
  }
395426
def __setstate__(self, state):
  """Restores pickled attributes and rebuilds the per-shard reader pools.

  Args:
    state: Attribute dictionary to restore; it is merged into the instance
      `__dict__` before the pools are rebuilt.
  """
  logging.debug("__setstate__ for ArrayRecordDataSource is called.")
  self.__dict__.update(state)
  # After pickling, we re-initialize the reader pools.
  pool_size = self._reader_pool_size
  self._reader_pools = [
      queue.LifoQueue(maxsize=pool_size) for _ in self._read_instructions
  ]
  # We pre-populate the pools again in the new process, one shard at a time.
  for shard_pool, instruction in zip(
      self._reader_pools, self._read_instructions
  ):
    for _ in range(pool_size):
      new_reader = _create_reader(
          instruction.filename, self._reader_options_string
      )
      _check_group_size(instruction.filename, new_reader)
      shard_pool.put(new_reader)
441+
def _peek_readers(self) -> List[Any]:
  """Returns, per shard, the reader on top of its pool without removing it.

  Intended for tests. A shard whose pool is currently empty contributes
  `None` to the returned list.
  """

  def top_of(shard_pool):
    # Hold the queue's own lock while looking at its underlying container.
    with shard_pool.mutex:
      idle = shard_pool.queue
      return idle[-1] if idle else None

  return [top_of(shard_pool) for shard_pool in self._reader_pools]
402449
403450 def __repr__ (self ) -> str :
404451 """Storing a hash of paths since paths can be a very long list."""
0 commit comments