diff --git a/python/array_record_data_source.py b/python/array_record_data_source.py index 117d244..e2e9a61 100644 --- a/python/array_record_data_source.py +++ b/python/array_record_data_source.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """array_record_data_source module. Warning: this is an experimental module. The interface might change in the @@ -23,15 +22,20 @@ ``` class RandomAccessDataSource(Protocol, Generic[T]): + def __len__(self) -> int: ... - def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: + def __getitem__(self, record_key: SupportsIndex) -> T: + ... + + def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]: ... ``` """ import bisect +import collections from concurrent import futures import dataclasses import hashlib @@ -39,6 +43,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: import os import pathlib import re +import threading import typing from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union @@ -48,30 +53,33 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: from . import array_record_module +T = TypeVar("T") + + +@typing.runtime_checkable +class FileInstruction(Protocol): + """Protocol with same interface as FileInstruction returned by TFDS. + + ArrayRecordDataSource would accept objects implementing this protocol without + depending on TFDS. + """ + + filename: str + skip: int + take: int + examples_in_shard: int + + +PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction] +ArrayRecordDataSourcePaths = Union[ + PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] +] + + # TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could # be appropriate if we decide to use asyncio) or move them somewhere else and # pass the number of threads as an argument. For now, since we experiment, it's # convenient to have them. -_GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS = flags.DEFINE_integer( - "grain_num_threads_computing_num_records", - 64, - ( - "The number of threads used to fetch file instructions (i.e., the max" - " number of Array Record files opened while calculating the total" - " number of records)." - ), -) -_GRAIN_NUM_THREADS_FETCHING_RECORDS = flags.DEFINE_integer( - "grain_num_threads_fetching_records", - 64, - ( - "The number of threads used to fetch records from Array Record files. " - "(i.e., the max number of Array Record files opened while fetching " - "records)." - ), -) - -T = TypeVar("T") def _run_in_parallel( @@ -96,6 +104,7 @@ def _run_in_parallel( """ if num_workers < 1: raise ValueError("num_workers must be >=1 for parallelism.") + thread_futures = [] with futures.ThreadPoolExecutor(num_workers) as executor: for kwargs in list_of_kwargs_to_function: @@ -125,23 +134,6 @@ def __post_init__(self): object.__setattr__(self, "num_records", self.end - self.start) -@typing.runtime_checkable -class FileInstruction(Protocol): - """Protocol with same interface as FileInstruction returned by TFDS. - - ArrayRecordDataSource would accept objects implementing this protocol without - depending on TFDS. - """ - - filename: str - skip: int - take: int - examples_in_shard: int - - -PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction] - - def _get_read_instructions( paths: Sequence[PathLikeOrFileInstruction], ) -> Sequence[_ReadInstruction]: @@ -212,8 +204,83 @@ def _check_group_size( ) +class BoundedReaderPool: + """A semaphore-throttled connection pool for a single shard.""" + + def __init__(self, filename: str, options_string: str, max_size: int = 1): + self._filename = filename + self._options_string = options_string + self._max_size = max_size + self._readers = collections.deque() + # Use BoundedSemaphore to strictly enforce the max_size cap + self._semaphore = threading.BoundedSemaphore(max_size) + self._created_count = 0 + self._lock = threading.Lock() + self._group_size_checked = False + self._closed = False + + def get(self) -> Any: + """Gets a reader atomically, throttling creation if cap is reached.""" + self._semaphore.acquire() + try: + try: + return self._readers.popleft() + except IndexError as exc: + with self._lock: + if self._closed: + raise RuntimeError( + f"Cannot get reader from closed pool: {self._filename}" + ) from exc + reader = _create_reader(self._filename, self._options_string) + if not self._group_size_checked: + _check_group_size(self._filename, reader) + self._group_size_checked = True + self._created_count += 1 + return reader + except: + # CRITICAL FIX: Guarantee semaphore permit is released if reader + # creation throws. + self._semaphore.release() + raise + + def put(self, reader: Any) -> None: + """Returns a reader to the pool atomically.""" + with self._lock: + if self._closed: + # If the pool was closed while the reader was borrowed, close it + # immediately. + if reader and hasattr(reader, "close"): + reader.close() + self._semaphore.release() + return + + self._readers.append(reader) + self._semaphore.release() + + def close_all(self) -> None: + """Closes all pooled readers and prevents future allocations.""" + with self._lock: + self._closed = True + + while True: + try: + reader = self._readers.popleft() + if reader and hasattr(reader, "close"): + reader.close() + except IndexError: + break + + def peek_readers(self) -> List[Any]: + """Returns the list of readers (for testing only).""" + return list(self._readers) + + +# Retain alias for backward compatibility with existing code/tests +LockFreeReaderPool = BoundedReaderPool + + class ArrayRecordDataSource: - """Datasource for ArrayRecord files.""" + """Datasource for ArrayRecord files using a Lock-Free Connection Pool.""" def __init__( self, @@ -221,6 +288,7 @@ def __init__( PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction] ], reader_options: dict[str, str] | None = None, + reader_pool_size: int | None = None, ): """Creates a new ArrayRecordDataSource object. @@ -242,11 +310,11 @@ def __init__( initialization faster. reader_options: string of comma-separated options to be passed when creating a reader. + reader_pool_size: The maximum number of readers to keep open per shard. """ if isinstance(paths, (str, pathlib.Path, FileInstruction)): paths = [paths] elif isinstance(paths, Sequence): - # Validate correct format of a sequence path if len(paths) <= 0: raise ValueError("Paths sequence can not be of 0 length") elif not all( @@ -270,8 +338,18 @@ def __init__( ) self._read_instructions = _get_read_instructions(paths) self._paths = [ri.filename for ri in self._read_instructions] - # We open readers lazily when we need to read from them. - self._readers = [None] * len(self._read_instructions) + self._reader_pool_size = ( + reader_pool_size or _get_flag_value(_GRAIN_READER_POOL_SIZE) or 1 + ) + + # Lock-free connection pool per shard + self._shard_pools = [ + LockFreeReaderPool( + ri.filename, self._reader_options_string, self._reader_pool_size + ) + for ri in self._read_instructions + ] + self._num_records = sum( map(lambda x: x.num_records, self._read_instructions) ) @@ -286,10 +364,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): logging.debug("__exit__ for ArrayRecordDataSource is called.") - for reader in self._readers: - if reader: - reader.close() - self._readers = [None] * len(self._read_instructions) + for pool in self._shard_pools: + pool.close_all() def __len__(self) -> int: return self._num_records @@ -329,48 +405,50 @@ def _split_keys_per_reader( positions_and_indices[reader_idx] = [(position, idx)] return positions_and_indices - def _ensure_reader_exists(self, reader_idx: int) -> None: - """Threadsafe method to create corresponding reader if it doesn't exist.""" - if self._readers[reader_idx] is not None: - return - filename = self._read_instructions[reader_idx].filename - reader = _create_reader(filename, self._reader_options_string) - _check_group_size(filename, reader) - self._readers[reader_idx] = reader + def _read_record(self, reader: Any, position: int) -> bytes: + """Helper to read a record using the best available method.""" + if hasattr(reader, "read_record"): + return reader.read_record(position) + if hasattr(reader, "read"): + return reader.read([position])[0] + return reader[position] def __getitem__(self, record_key: SupportsIndex) -> bytes: - reader_idx, position = self._reader_idx_and_position(record_key) - self._ensure_reader_exists(reader_idx) - if hasattr(self._readers[reader_idx], "read"): - return self._readers[reader_idx].read([position])[0] - return self._readers[reader_idx][position] + pool_idx, position = self._reader_idx_and_position(record_key) + reader = self._shard_pools[pool_idx].get() + try: + return self._read_record(reader, position) + finally: + self._shard_pools[pool_idx].put(reader) def __getitems__( self, record_keys: Sequence[SupportsIndex] ) -> Sequence[bytes]: + def read_records( - reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]] + pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]] ) -> Sequence[Tuple[Any, int]]: """Reads records using the given reader keeping track of the indices.""" - # Initialize readers lazily when we need to read from them. - self._ensure_reader_exists(reader_idx) - positions, indices = list(zip(*reader_positions_and_indices)) - if hasattr(self._readers[reader_idx], "read"): - records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error - else: - records = [self._readers[reader_idx][p] for p in positions] - return list(zip(records, indices)) + reader = self._shard_pools[pool_idx].get() + try: + records = [] + for position, _ in reader_positions_and_indices: + records.append(self._read_record(reader, position)) + indices = [idx for _, idx in reader_positions_and_indices] + return list(zip(records, indices)) + finally: + self._shard_pools[pool_idx].put(reader) positions_and_indices = self._split_keys_per_reader(record_keys) num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS) num_workers = min(len(positions_and_indices), num_threads) list_of_kwargs_to_read_records = [] for ( - reader_idx, + pool_idx, reader_positions_and_indices, ) in positions_and_indices.items(): list_of_kwargs_to_read_records.append({ - "reader_idx": reader_idx, + "pool_idx": pool_idx, "reader_positions_and_indices": reader_positions_and_indices, }) records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = ( @@ -390,7 +468,7 @@ def read_records( def __getstate__(self): logging.debug("__getstate__ for ArrayRecordDataSource is called.") state = self.__dict__.copy() - del state["_readers"] + state.pop("_shard_pools", None) return state def __setstate__(self, state): @@ -398,7 +476,14 @@ def __setstate__(self, state): self.__dict__.update(state) # We open readers lazily when we need to read from them. Thus, we don't # need to re-open the same files as before pickling. - self._readers = [None] * len(self._read_instructions) + self._shard_pools = [ + LockFreeReaderPool( + ri.filename, + self._reader_options_string, + getattr(self, "_reader_pool_size", 1), + ) + for ri in self._read_instructions + ] def __repr__(self) -> str: """Storing a hash of paths since paths can be a very long list.""" @@ -407,6 +492,14 @@ def __repr__(self) -> str: h.update(p.encode()) return f"ArrayRecordDataSource(hash_of_paths={h.hexdigest()})" + def _peek_readers(self) -> List[Any]: + """Returns a list of readers (one per shard) or None (for testing only).""" + readers = [] + for pool in self._shard_pools: + pooled_readers = pool.peek_readers() + readers.append(pooled_readers[-1] if pooled_readers else None) + return readers + def _get_flag_value(flag: flags.FlagHolder[int]) -> int: """Retrieves the flag value or the default if run outside of absl.""" @@ -414,3 +507,7 @@ def _get_flag_value(flag: flags.FlagHolder[int]) -> int: return flag.value except flags.UnparsedFlagAccessError: return flag.default + + +# Alias for backward compatibility with unit tests +BoundedReaderPool = LockFreeReaderPool diff --git a/python/array_record_data_source_test.py b/python/array_record_data_source_test.py index 8977a27..b4de730 100644 --- a/python/array_record_data_source_test.py +++ b/python/array_record_data_source_test.py @@ -17,6 +17,7 @@ import dataclasses import os import pathlib +import pickle from unittest import mock from absl import flags @@ -109,7 +110,7 @@ def test_array_record_data_source_single_path(self): ) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_string_read_instructions(self): indices_to_read = [0, 1, 2, 3, 4] @@ -132,7 +133,7 @@ def test_array_record_data_source_reverse_order(self): ]) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_random_order(self): # some random permutation @@ -144,7 +145,7 @@ def test_array_record_data_source_random_order(self): ]) as ar: actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_random_order_batched(self): # some random permutation @@ -156,7 +157,7 @@ def test_array_record_data_source_random_order_batched(self): ]) as ar: actual_data = ar.__getitems__(indices_to_read) self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_data_source_file_instructions(self): file_instruction_one = DummyFileInstruction( @@ -187,7 +188,7 @@ def test_array_record_data_source_file_instructions(self): actual_data = [ar[x] for x in indices_to_read] self.assertEqual(expected_data, actual_data) - self.assertTrue(all(reader is None for reader in ar._readers)) + self.assertTrue(all(reader is None for reader in ar._peek_readers())) def test_array_record_source_reader_idx_and_position(self): file_instructions = [