1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14-
1514"""array_record_data_source module.
1615
1716Warning: this is an experimental module. The interface might change in the
2322
2423```
2524class RandomAccessDataSource(Protocol, Generic[T]):
25+
2626 def __len__(self) -> int:
2727 ...
2828
29- def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
29+ def __getitem__(self, record_key: SupportsIndex) -> T:
30+ ...
31+
32+ def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]:
3033 ...
3134```
3235"""
3336
3437import bisect
38+ import collections
3539from concurrent import futures
3640import dataclasses
3741import hashlib
3842import itertools
3943import os
4044import pathlib
4145import re
46+ import threading
4247import typing
4348from typing import Any , Callable , Iterator , List , Mapping , Protocol , Sequence , SupportsIndex , Tuple , TypeVar , Union
4449
@@ -48,30 +53,33 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
4853
4954from . import array_record_module
5055
56+ T = TypeVar ("T" )
57+
58+
59+ @typing .runtime_checkable
60+ class FileInstruction (Protocol ):
61+ """Protocol with same interface as FileInstruction returned by TFDS.
62+
63+ ArrayRecordDataSource would accept objects implementing this protocol without
64+ depending on TFDS.
65+ """
66+
67+ filename : str
68+ skip : int
69+ take : int
70+ examples_in_shard : int
71+
72+
73+ PathLikeOrFileInstruction = Union [epath .PathLike , FileInstruction ]
74+ ArrayRecordDataSourcePaths = Union [
75+ PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
76+ ]
77+
78+
5179# TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could
5280# be appropriate if we decide to use asyncio) or move them somewhere else and
5381# pass the number of threads as an argument. For now, since we experiment, it's
5482# convenient to have them.
55- _GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS = flags .DEFINE_integer (
56- "grain_num_threads_computing_num_records" ,
57- 64 ,
58- (
59- "The number of threads used to fetch file instructions (i.e., the max"
60- " number of Array Record files opened while calculating the total"
61- " number of records)."
62- ),
63- )
64- _GRAIN_NUM_THREADS_FETCHING_RECORDS = flags .DEFINE_integer (
65- "grain_num_threads_fetching_records" ,
66- 64 ,
67- (
68- "The number of threads used to fetch records from Array Record files. "
69- "(i.e., the max number of Array Record files opened while fetching "
70- "records)."
71- ),
72- )
73-
74- T = TypeVar ("T" )
7583
7684
7785def _run_in_parallel (
@@ -96,6 +104,7 @@ def _run_in_parallel(
96104 """
97105 if num_workers < 1 :
98106 raise ValueError ("num_workers must be >=1 for parallelism." )
107+
99108 thread_futures = []
100109 with futures .ThreadPoolExecutor (num_workers ) as executor :
101110 for kwargs in list_of_kwargs_to_function :
@@ -125,23 +134,6 @@ def __post_init__(self):
125134 object .__setattr__ (self , "num_records" , self .end - self .start )
126135
127136
128- @typing .runtime_checkable
129- class FileInstruction (Protocol ):
130- """Protocol with same interface as FileInstruction returned by TFDS.
131-
132- ArrayRecordDataSource would accept objects implementing this protocol without
133- depending on TFDS.
134- """
135-
136- filename : str
137- skip : int
138- take : int
139- examples_in_shard : int
140-
141-
142- PathLikeOrFileInstruction = Union [epath .PathLike , FileInstruction ]
143-
144-
145137def _get_read_instructions (
146138 paths : Sequence [PathLikeOrFileInstruction ],
147139) -> Sequence [_ReadInstruction ]:
@@ -212,15 +204,91 @@ def _check_group_size(
212204 )
213205
214206
207+ class BoundedReaderPool :
208+ """A semaphore-throttled connection pool for a single shard."""
209+
210+ def __init__ (self , filename : str , options_string : str , max_size : int = 1 ):
211+ self ._filename = filename
212+ self ._options_string = options_string
213+ self ._max_size = max_size
214+ self ._readers = collections .deque ()
215+ # Use BoundedSemaphore to strictly enforce the max_size cap
216+ self ._semaphore = threading .BoundedSemaphore (max_size )
217+ self ._created_count = 0
218+ self ._lock = threading .Lock ()
219+ self ._group_size_checked = False
220+ self ._closed = False
221+
222+ def get (self ) -> Any :
223+ """Gets a reader atomically, throttling creation if cap is reached."""
224+ self ._semaphore .acquire ()
225+ try :
226+ try :
227+ return self ._readers .popleft ()
228+ except IndexError as exc :
229+ with self ._lock :
230+ if self ._closed :
231+ raise RuntimeError (
232+ f"Cannot get reader from closed pool: { self ._filename } "
233+ ) from exc
234+ reader = _create_reader (self ._filename , self ._options_string )
235+ if not self ._group_size_checked :
236+ _check_group_size (self ._filename , reader )
237+ self ._group_size_checked = True
238+ self ._created_count += 1
239+ return reader
240+ except :
241+ # CRITICAL FIX: Guarantee semaphore permit is released if reader
242+ # creation throws.
243+ self ._semaphore .release ()
244+ raise
245+
246+ def put (self , reader : Any ) -> None :
247+ """Returns a reader to the pool atomically."""
248+ with self ._lock :
249+ if self ._closed :
250+ # If the pool was closed while the reader was borrowed, close it
251+ # immediately.
252+ if reader and hasattr (reader , "close" ):
253+ reader .close ()
254+ self ._semaphore .release ()
255+ return
256+
257+ self ._readers .append (reader )
258+ self ._semaphore .release ()
259+
260+ def close_all (self ) -> None :
261+ """Closes all pooled readers and prevents future allocations."""
262+ with self ._lock :
263+ self ._closed = True
264+
265+ while True :
266+ try :
267+ reader = self ._readers .popleft ()
268+ if reader and hasattr (reader , "close" ):
269+ reader .close ()
270+ except IndexError :
271+ break
272+
273+ def peek_readers (self ) -> List [Any ]:
274+ """Returns the list of readers (for testing only)."""
275+ return list (self ._readers )
276+
277+
278+ # Retain alias for backward compatibility with existing code/tests
279+ LockFreeReaderPool = BoundedReaderPool
280+
281+
215282class ArrayRecordDataSource :
216- """Datasource for ArrayRecord files."""
283+ """Datasource for ArrayRecord files using a Lock-Free Connection Pool ."""
217284
218285 def __init__ (
219286 self ,
220287 paths : Union [
221288 PathLikeOrFileInstruction , Sequence [PathLikeOrFileInstruction ]
222289 ],
223290 reader_options : dict [str , str ] | None = None ,
291+ reader_pool_size : int | None = None ,
224292 ):
225293 """Creates a new ArrayRecordDataSource object.
226294
@@ -242,11 +310,11 @@ def __init__(
242310 initialization faster.
243311 reader_options: string of comma-separated options to be passed when
244312 creating a reader.
313+ reader_pool_size: The maximum number of readers to keep open per shard.
245314 """
246315 if isinstance (paths , (str , pathlib .Path , FileInstruction )):
247316 paths = [paths ]
248317 elif isinstance (paths , Sequence ):
249- # Validate correct format of a sequence path
250318 if len (paths ) <= 0 :
251319 raise ValueError ("Paths sequence can not be of 0 length" )
252320 elif not all (
@@ -270,8 +338,18 @@ def __init__(
270338 )
271339 self ._read_instructions = _get_read_instructions (paths )
272340 self ._paths = [ri .filename for ri in self ._read_instructions ]
273- # We open readers lazily when we need to read from them.
274- self ._readers = [None ] * len (self ._read_instructions )
341+ self ._reader_pool_size = (
342+ reader_pool_size or _get_flag_value (_GRAIN_READER_POOL_SIZE ) or 1
343+ )
344+
345+ # Lock-free connection pool per shard
346+ self ._shard_pools = [
347+ LockFreeReaderPool (
348+ ri .filename , self ._reader_options_string , self ._reader_pool_size
349+ )
350+ for ri in self ._read_instructions
351+ ]
352+
275353 self ._num_records = sum (
276354 map (lambda x : x .num_records , self ._read_instructions )
277355 )
@@ -286,10 +364,8 @@ def __enter__(self):
286364
287365 def __exit__ (self , exc_type , exc_value , traceback ):
288366 logging .debug ("__exit__ for ArrayRecordDataSource is called." )
289- for reader in self ._readers :
290- if reader :
291- reader .close ()
292- self ._readers = [None ] * len (self ._read_instructions )
367+ for pool in self ._shard_pools :
368+ pool .close_all ()
293369
294370 def __len__ (self ) -> int :
295371 return self ._num_records
@@ -329,48 +405,50 @@ def _split_keys_per_reader(
329405 positions_and_indices [reader_idx ] = [(position , idx )]
330406 return positions_and_indices
331407
332- def _ensure_reader_exists (self , reader_idx : int ) -> None :
333- """Threadsafe method to create corresponding reader if it doesn't exist."""
334- if self ._readers [reader_idx ] is not None :
335- return
336- filename = self ._read_instructions [reader_idx ].filename
337- reader = _create_reader (filename , self ._reader_options_string )
338- _check_group_size (filename , reader )
339- self ._readers [reader_idx ] = reader
408+ def _read_record (self , reader : Any , position : int ) -> bytes :
409+ """Helper to read a record using the best available method."""
410+ if hasattr (reader , "read_record" ):
411+ return reader .read_record (position )
412+ if hasattr (reader , "read" ):
413+ return reader .read ([position ])[0 ]
414+ return reader [position ]
340415
341416 def __getitem__ (self , record_key : SupportsIndex ) -> bytes :
342- reader_idx , position = self ._reader_idx_and_position (record_key )
343- self ._ensure_reader_exists (reader_idx )
344- if hasattr (self ._readers [reader_idx ], "read" ):
345- return self ._readers [reader_idx ].read ([position ])[0 ]
346- return self ._readers [reader_idx ][position ]
417+ pool_idx , position = self ._reader_idx_and_position (record_key )
418+ reader = self ._shard_pools [pool_idx ].get ()
419+ try :
420+ return self ._read_record (reader , position )
421+ finally :
422+ self ._shard_pools [pool_idx ].put (reader )
347423
348424 def __getitems__ (
349425 self , record_keys : Sequence [SupportsIndex ]
350426 ) -> Sequence [bytes ]:
427+
351428 def read_records (
352- reader_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
429+ pool_idx : int , reader_positions_and_indices : Sequence [Tuple [int , int ]]
353430 ) -> Sequence [Tuple [Any , int ]]:
354431 """Reads records using the given reader keeping track of the indices."""
355- # Initialize readers lazily when we need to read from them.
356- self ._ensure_reader_exists (reader_idx )
357- positions , indices = list (zip (* reader_positions_and_indices ))
358- if hasattr (self ._readers [reader_idx ], "read" ):
359- records = self ._readers [reader_idx ].read (positions ) # pytype: disable=attribute-error
360- else :
361- records = [self ._readers [reader_idx ][p ] for p in positions ]
362- return list (zip (records , indices ))
432+ reader = self ._shard_pools [pool_idx ].get ()
433+ try :
434+ records = []
435+ for position , _ in reader_positions_and_indices :
436+ records .append (self ._read_record (reader , position ))
437+ indices = [idx for _ , idx in reader_positions_and_indices ]
438+ return list (zip (records , indices ))
439+ finally :
440+ self ._shard_pools [pool_idx ].put (reader )
363441
364442 positions_and_indices = self ._split_keys_per_reader (record_keys )
365443 num_threads = _get_flag_value (_GRAIN_NUM_THREADS_FETCHING_RECORDS )
366444 num_workers = min (len (positions_and_indices ), num_threads )
367445 list_of_kwargs_to_read_records = []
368446 for (
369- reader_idx ,
447+ pool_idx ,
370448 reader_positions_and_indices ,
371449 ) in positions_and_indices .items ():
372450 list_of_kwargs_to_read_records .append ({
373- "reader_idx " : reader_idx ,
451+ "pool_idx " : pool_idx ,
374452 "reader_positions_and_indices" : reader_positions_and_indices ,
375453 })
376454 records_with_indices : Sequence [Sequence [Tuple [Any , int ]]] = (
@@ -390,15 +468,22 @@ def read_records(
390468 def __getstate__ (self ):
391469 logging .debug ("__getstate__ for ArrayRecordDataSource is called." )
392470 state = self .__dict__ .copy ()
393- del state [ "_readers" ]
471+ state . pop ( "_shard_pools" , None )
394472 return state
395473
396474 def __setstate__ (self , state ):
397475 logging .debug ("__setstate__ for ArrayRecordDataSource is called." )
398476 self .__dict__ .update (state )
399477 # We open readers lazily when we need to read from them. Thus, we don't
400478 # need to re-open the same files as before pickling.
401- self ._readers = [None ] * len (self ._read_instructions )
479+ self ._shard_pools = [
480+ LockFreeReaderPool (
481+ ri .filename ,
482+ self ._reader_options_string ,
483+ getattr (self , "_reader_pool_size" , 1 ),
484+ )
485+ for ri in self ._read_instructions
486+ ]
402487
403488 def __repr__ (self ) -> str :
404489 """Storing a hash of paths since paths can be a very long list."""
@@ -407,10 +492,22 @@ def __repr__(self) -> str:
407492 h .update (p .encode ())
408493 return f"ArrayRecordDataSource(hash_of_paths={ h .hexdigest ()} )"
409494
495+ def _peek_readers (self ) -> List [Any ]:
496+ """Returns a list of readers (one per shard) or None (for testing only)."""
497+ readers = []
498+ for pool in self ._shard_pools :
499+ pooled_readers = pool .peek_readers ()
500+ readers .append (pooled_readers [- 1 ] if pooled_readers else None )
501+ return readers
502+
410503
411504def _get_flag_value (flag : flags .FlagHolder [int ]) -> int :
412505 """Retrieves the flag value or the default if run outside of absl."""
413506 try :
414507 return flag .value
415508 except flags .UnparsedFlagAccessError :
416509 return flag .default
510+
511+
512+ # Alias for backward compatibility with unit tests
513+ BoundedReaderPool = LockFreeReaderPool
0 commit comments