@@ -39,6 +39,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
3939import os
4040import pathlib
4141import re
42+ import threading
4243import typing
4344from typing import Any , Callable , Iterator , List , Mapping , Protocol , Sequence , SupportsIndex , Tuple , TypeVar , Union
4445
@@ -96,6 +97,8 @@ def _run_in_parallel(
9697 """
9798 if num_workers < 1 :
9899 raise ValueError ("num_workers must be >=1 for parallelism." )
100+ if num_workers == 1 or len (list_of_kwargs_to_function ) == 1 :
101+ return [function (** kwargs ) for kwargs in list_of_kwargs_to_function ]
99102 thread_futures = []
100103 with futures .ThreadPoolExecutor (num_workers ) as executor :
101104 for kwargs in list_of_kwargs_to_function :
@@ -212,6 +215,62 @@ def _check_group_size(
212215 )
213216
214217
218+ class BoundedReaderPool :
219+ """A pool of readers for a single shard with a dynamic upper bound."""
220+
221+ def __init__ (self , filename : str , options_string : str , max_size : int ):
222+ self ._filename = filename
223+ self ._options_string = options_string
224+ self ._max_size = max_size
225+ self ._readers = []
226+ self ._created_count = 0
227+ self ._condition = threading .Condition ()
228+ self ._group_size_checked = False
229+ self .has_read_method = None
230+
231+ def get (self ) -> Any :
232+ """Gets a reader from the pool, blocking if the cap is reached."""
233+ create_new = False
234+ with self ._condition :
235+ # Wait if no idle readers and we reached the cap
236+ while not self ._readers and self ._created_count >= self ._max_size :
237+ self ._condition .wait ()
238+
239+ if self ._readers :
240+ return self ._readers .pop ()
241+
242+ self ._created_count += 1
243+ create_new = True
244+
245+ if create_new :
246+ reader = _create_reader (self ._filename , self ._options_string )
247+ if self .has_read_method is None :
248+ self .has_read_method = hasattr (reader , "read" )
249+ if not self ._group_size_checked :
250+ _check_group_size (self ._filename , reader )
251+ self ._group_size_checked = True
252+ return reader
253+
254+ def put (self , reader : Any ) -> None :
255+ """Returns a reader to the pool."""
256+ with self ._condition :
257+ self ._readers .append (reader )
258+ self ._condition .notify ()
259+
260+ def close_all (self ) -> None :
261+ """Closes all pooled readers."""
262+ with self ._condition :
263+ for reader in self ._readers :
264+ if reader :
265+ reader .close ()
266+ self ._readers .clear ()
267+
268+ def peek_readers (self ) -> List [Any ]:
269+ """Returns the list of readers currently in the pool."""
270+ with self ._condition :
271+ return list (self ._readers )
272+
273+
215274class ArrayRecordDataSource :
216275 """Datasource for ArrayRecord files."""
217276
@@ -221,6 +280,7 @@ def __init__(
221280 PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
222281 ],
223282 reader_options : dict [str , str ] | None = None ,
283+ reader_pool_size : int = 1 ,
224284 ):
225285 """Creates a new ArrayRecordDataSource object.
226286
@@ -242,6 +302,8 @@ def __init__(
242302 initialization faster.
243303 reader_options: string of comma-separated options to be passed when
244304 creating a reader.
305+ reader_pool_size: Number of readers to pre-allocate in the pool for each
306+ shard. Default is 1.
245307 """
246308 if isinstance (paths , (str , pathlib .Path , FileInstruction )):
247309 paths = [paths ]
@@ -265,31 +327,49 @@ def __init__(
265327 if reader_options is None :
266328 self ._reader_options_string = ""
267329 else :
330+ reader_options = dict (reader_options )
331+ if "reader_pool_size" in reader_options :
332+ reader_pool_size = int (reader_options .pop ("reader_pool_size" ))
268333 self ._reader_options_string = "," .join (
269334 [f"{ k } :{ v } " for k , v in reader_options .items ()]
270335 )
271336 self ._read_instructions = _get_read_instructions (paths )
272337 self ._paths = [ri .filename for ri in self ._read_instructions ]
338+ self ._reader_pool_size = max (reader_pool_size , 1 )
339+ # We maintain a pool of readers for each shard to ensure thread safety
340+ # while allowing concurrent reads.
341+ self ._shard_pools = [
342+ BoundedReaderPool (
343+ ri .filename , self ._reader_options_string , self ._reader_pool_size
344+ )
345+ for ri in self ._read_instructions
346+ ]
273347 # We open readers lazily when we need to read from them.
274- self . _readers = [ None ] * len ( self . _read_instructions )
348+
275349 self ._num_records = sum (
276350 map (lambda x : x .num_records , self ._read_instructions )
277351 )
278352 records_per_instruction = map (
279353 lambda x : x .num_records , self ._read_instructions
280354 )
281355 self ._prefix_sums = list (itertools .accumulate (records_per_instruction ))
356+ self ._thread_local = threading .local ()
282357
283358 def __enter__ (self ):
284359 logging .debug ("__enter__ for ArrayRecordDataSource is called." )
285360 return self
286361
287362 def __exit__ (self , exc_type , exc_value , traceback ):
288363 logging .debug ("__exit__ for ArrayRecordDataSource is called." )
289- for reader in self ._readers :
290- if reader :
291- reader .close ()
292- self ._readers = [None ] * len (self ._read_instructions )
364+ # Clear thread-local cache for the current thread.
365+ if hasattr (self ._thread_local , "reader" ):
366+ if self ._thread_local .reader is not None :
367+ self ._thread_local .reader .close ()
368+ self ._thread_local .reader = None
369+ self ._thread_local .pool_idx = - 1
370+
371+ for pool in self ._shard_pools :
372+ pool .close_all ()
293373
294374 def __len__ (self ) -> int :
295375 return self ._num_records
@@ -298,79 +378,111 @@ def __iter__(self) -> Iterator[bytes]:
298378 for index in range (self ._num_records ):
299379 yield self [index ]
300380
301- def _reader_idx_and_position (
381+ def _pool_idx_and_position (
302382 self , record_key : SupportsIndex
303383 ) -> Tuple [int , int ]:
304- """Computes reader idx and position of given record key."""
384+ """Computes pool idx and position of given record key."""
305385 record_key = record_key .__index__ ()
306386 if record_key < 0 or record_key >= self ._num_records :
307387 raise ValueError ("Record key should be in [0, num_records)" )
308- reader_idx = bisect .bisect_right (self ._prefix_sums , record_key )
388+ pool_idx = bisect .bisect_right (self ._prefix_sums , record_key )
309389 records_in_previous_instructions = 0
310- if reader_idx > 0 :
311- records_in_previous_instructions = self ._prefix_sums [reader_idx - 1 ]
390+ if pool_idx > 0 :
391+ records_in_previous_instructions = self ._prefix_sums [pool_idx - 1 ]
312392 return (
313- reader_idx ,
393+ pool_idx ,
314394 record_key
315395 - records_in_previous_instructions
316- + self ._read_instructions [reader_idx ].start ,
396+ + self ._read_instructions [pool_idx ].start ,
317397 )
318398
319- def _split_keys_per_reader (
399+ def _split_keys_per_pool (
320400 self , record_keys : Sequence [SupportsIndex ]
321401 ) -> Mapping [int , Sequence [Tuple [int , int ]]]:
322- """Splits record_keys among readers ."""
402+ """Splits record_keys among pools ."""
323403 positions_and_indices = {}
324404 for idx , record_key in enumerate (record_keys ):
325- reader_idx , position = self ._reader_idx_and_position (record_key )
326- if reader_idx in positions_and_indices :
327- positions_and_indices [reader_idx ].append ((position , idx ))
405+ pool_idx , position = self ._pool_idx_and_position (record_key )
406+ if pool_idx in positions_and_indices :
407+ positions_and_indices [pool_idx ].append ((position , idx ))
328408 else :
329- positions_and_indices [reader_idx ] = [(position , idx )]
409+ positions_and_indices [pool_idx ] = [(position , idx )]
330410 return positions_and_indices
331411
332- def _ensure_reader_exists (self , reader_idx : int ) -> None :
333- """Threadsafe method to create corresponding reader if it doesn't exist."""
334- if self ._readers [reader_idx ] is not None :
335- return
336- filename = self ._read_instructions [reader_idx ].filename
337- reader = _create_reader (filename , self ._reader_options_string )
338- _check_group_size (filename , reader )
339- self ._readers [reader_idx ] = reader
412+ def _get_reader (self , pool_idx : int ) -> Any :
413+ """Gets a reader from the single-slot thread-local cache or the pool."""
414+ if not hasattr (self ._thread_local , "pool_idx" ):
415+ self ._thread_local .pool_idx = - 1
416+ self ._thread_local .reader = None
417+
418+ if self ._thread_local .pool_idx == pool_idx :
419+ return self ._thread_local .reader
420+
421+ # Return previous reader to pool if we have one
422+ if self ._thread_local .reader is not None :
423+ self ._shard_pools [self ._thread_local .pool_idx ].put (
424+ self ._thread_local .reader
425+ )
426+
427+ # Get new reader and cache it
428+ reader = self ._shard_pools [pool_idx ].get ()
429+ self ._thread_local .pool_idx = pool_idx
430+ self ._thread_local .reader = reader
431+ return reader
432+
433+ def _release_reader (self , pool_idx : int , reader : Any ) -> None :
434+ """No-op: we keep it cached in _get_reader until the next request."""
435+ pass
436+
437+ def _read_record (
438+ self , reader : Any , pool : BoundedReaderPool , position : int
439+ ) -> bytes :
440+ """Helper to read a record using the best available method."""
441+ if pool .has_read_method :
442+ return reader .read ([position ])[0 ]
443+ if hasattr (reader , "read_record" ):
444+ return reader .read_record (position )
445+ return reader [position ]
340446
341447 def __getitem__ (self , record_key : SupportsIndex ) -> bytes :
342- reader_idx , position = self ._reader_idx_and_position (record_key )
343- self ._ensure_reader_exists (reader_idx )
344- if hasattr (self ._readers [reader_idx ], "read" ):
345- return self ._readers [reader_idx ].read ([position ])[0 ]
346- return self ._readers [reader_idx ][position ]
448+ pool_idx , position = self ._pool_idx_and_position (record_key )
449+ reader = self ._get_reader (pool_idx )
450+ try :
451+ return self ._read_record (reader , self ._shard_pools [pool_idx ], position )
452+ finally :
453+ self ._release_reader (pool_idx , reader )
347454
348455 def __getitems__ (
349456 self , record_keys : Sequence [SupportsIndex ]
350457 ) -> Sequence [bytes ]:
458+
351459 def read_records (
352- reader_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
460+ pool_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
353461 ) -> Sequence [Tuple [Any , int ]]:
354462 """Reads records using the given reader keeping track of the indices."""
355- # Initialize readers lazily when we need to read from them.
356- self ._ensure_reader_exists (reader_idx )
357- positions , indices = list (zip (* reader_positions_and_indices ))
358- if hasattr (self ._readers [reader_idx ], "read" ):
359- records = self ._readers [reader_idx ].read (positions ) # pytype: disable=attribute-error
360- else :
361- records = [self ._readers [reader_idx ][p ] for p in positions ]
362- return list (zip (records , indices ))
363-
364- positions_and_indices = self ._split_keys_per_reader (record_keys )
463+ reader = self ._get_reader (pool_idx )
464+ pool = self ._shard_pools [pool_idx ]
465+ try :
466+ records = []
467+ for position , _ in reader_positions_and_indices :
468+ records .append (self ._read_record (reader , pool , position ))
469+ indices = [idx for _ , idx in reader_positions_and_indices ]
470+ return list (zip (records , indices ))
471+ finally :
472+ self ._release_reader (pool_idx , reader )
473+
474+ # Group record keys by pool/shard to maximize reader reuse.
475+ positions_and_indices = self ._split_keys_per_pool (record_keys )
365476 num_threads = _get_flag_value (_GRAIN_NUM_THREADS_FETCHING_RECORDS )
477+ # Parallelize reads across shards using the available threads.
366478 num_workers = min (len (positions_and_indices ), num_threads )
367479 list_of_kwargs_to_read_records = []
368480 for (
369- reader_idx ,
481+ pool_idx ,
370482 reader_positions_and_indices ,
371483 ) in positions_and_indices .items ():
372484 list_of_kwargs_to_read_records .append ({
373- "reader_idx " : reader_idx ,
485+ "pool_idx " : pool_idx ,
374486 "reader_positions_and_indices" : reader_positions_and_indices ,
375487 })
376488 records_with_indices : Sequence [Sequence [Tuple [Any , int ]]] = (
@@ -390,15 +502,40 @@ def read_records(
390502 def __getstate__ (self ):
391503 logging .debug ("__getstate__ for ArrayRecordDataSource is called." )
392504 state = self .__dict__ .copy ()
393- del state ["_readers" ]
505+ state .pop ("_shard_pools" , None )
506+ state .pop ("_reader_pools" , None )
507+ state .pop ("_thread_local" , None )
394508 return state
395509
396510 def __setstate__ (self , state ):
397511 logging .debug ("__setstate__ for ArrayRecordDataSource is called." )
398512 self .__dict__ .update (state )
399- # We open readers lazily when we need to read from them. Thus, we don't
400- # need to re-open the same files as before pickling.
401- self ._readers = [None ] * len (self ._read_instructions )
513+ self ._shard_pools = [
514+ BoundedReaderPool (
515+ ri .filename ,
516+ self ._reader_options_string ,
517+ getattr (self , "_reader_pool_size" , 1 ),
518+ )
519+ for ri in self ._read_instructions
520+ ]
521+ self ._thread_local = threading .local ()
522+ # We open readers lazily when we need to read from them.
523+
524+ def _peek_readers (self ) -> List [Any ]:
525+ """Returns a list of readers (one per shard) or None (for testing only)."""
526+ readers = []
527+ for i , pool in enumerate (self ._shard_pools ):
528+ reader = None
529+ if (
530+ hasattr (self ._thread_local , "pool_idx" )
531+ and self ._thread_local .pool_idx == i
532+ ):
533+ reader = self ._thread_local .reader
534+ if reader is None :
535+ pooled_readers = pool .peek_readers ()
536+ reader = pooled_readers [- 1 ] if pooled_readers else None
537+ readers .append (reader )
538+ return readers
402539
403540 def __repr__ (self ) -> str :
404541 """Storing a hash of paths since paths can be a very long list."""
0 commit comments