Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 170 additions & 73 deletions python/array_record_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""array_record_data_source module.

Warning: this is an experimental module. The interface might change in the
Expand All @@ -23,22 +22,28 @@

```
class RandomAccessDataSource(Protocol, Generic[T]):

def __len__(self) -> int:
...

def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
def __getitem__(self, record_key: SupportsIndex) -> T:
...

def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]:
...
```
"""

import bisect
import collections
from concurrent import futures
import dataclasses
import hashlib
import itertools
import os
import pathlib
import re
import threading
import typing
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union

Expand All @@ -48,30 +53,33 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:

from . import array_record_module

T = TypeVar("T")


@typing.runtime_checkable
class FileInstruction(Protocol):
"""Protocol with same interface as FileInstruction returned by TFDS.

ArrayRecordDataSource would accept objects implementing this protocol without
depending on TFDS.
"""

filename: str
skip: int
take: int
examples_in_shard: int


PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction]
ArrayRecordDataSourcePaths = Union[
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
]


# TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could
# be appropriate if we decide to use asyncio) or move them somewhere else and
# pass the number of threads as an argument. For now, since we experiment, it's
# convenient to have them.
_GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS = flags.DEFINE_integer(
"grain_num_threads_computing_num_records",
64,
(
"The number of threads used to fetch file instructions (i.e., the max"
" number of Array Record files opened while calculating the total"
" number of records)."
),
)
_GRAIN_NUM_THREADS_FETCHING_RECORDS = flags.DEFINE_integer(
"grain_num_threads_fetching_records",
64,
(
"The number of threads used to fetch records from Array Record files. "
"(i.e., the max number of Array Record files opened while fetching "
"records)."
),
)

T = TypeVar("T")


def _run_in_parallel(
Expand All @@ -96,6 +104,7 @@ def _run_in_parallel(
"""
if num_workers < 1:
raise ValueError("num_workers must be >=1 for parallelism.")

thread_futures = []
with futures.ThreadPoolExecutor(num_workers) as executor:
for kwargs in list_of_kwargs_to_function:
Expand Down Expand Up @@ -125,23 +134,6 @@ def __post_init__(self):
object.__setattr__(self, "num_records", self.end - self.start)


@typing.runtime_checkable
class FileInstruction(Protocol):
"""Protocol with same interface as FileInstruction returned by TFDS.

ArrayRecordDataSource would accept objects implementing this protocol without
depending on TFDS.
"""

filename: str
skip: int
take: int
examples_in_shard: int


PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction]


def _get_read_instructions(
paths: Sequence[PathLikeOrFileInstruction],
) -> Sequence[_ReadInstruction]:
Expand Down Expand Up @@ -212,15 +204,91 @@ def _check_group_size(
)


class BoundedReaderPool:
"""A semaphore-throttled connection pool for a single shard."""

def __init__(self, filename: str, options_string: str, max_size: int = 1):
self._filename = filename
self._options_string = options_string
self._max_size = max_size
self._readers = collections.deque()
# Use BoundedSemaphore to strictly enforce the max_size cap
self._semaphore = threading.BoundedSemaphore(max_size)
self._created_count = 0
self._lock = threading.Lock()
self._group_size_checked = False
self._closed = False

def get(self) -> Any:
"""Gets a reader atomically, throttling creation if cap is reached."""
self._semaphore.acquire()
try:
try:
return self._readers.popleft()
except IndexError as exc:
with self._lock:
if self._closed:
raise RuntimeError(
f"Cannot get reader from closed pool: {self._filename}"
) from exc
reader = _create_reader(self._filename, self._options_string)
if not self._group_size_checked:
_check_group_size(self._filename, reader)
self._group_size_checked = True
self._created_count += 1
return reader
except:
# CRITICAL FIX: Guarantee semaphore permit is released if reader
# creation throws.
self._semaphore.release()
raise

def put(self, reader: Any) -> None:
"""Returns a reader to the pool atomically."""
with self._lock:
if self._closed:
# If the pool was closed while the reader was borrowed, close it
# immediately.
if reader and hasattr(reader, "close"):
reader.close()
self._semaphore.release()
return

self._readers.append(reader)
self._semaphore.release()

def close_all(self) -> None:
"""Closes all pooled readers and prevents future allocations."""
with self._lock:
self._closed = True

while True:
try:
reader = self._readers.popleft()
if reader and hasattr(reader, "close"):
reader.close()
except IndexError:
break

def peek_readers(self) -> List[Any]:
"""Returns the list of readers (for testing only)."""
return list(self._readers)


# Retain alias for backward compatibility with existing code/tests
LockFreeReaderPool = BoundedReaderPool


class ArrayRecordDataSource:
"""Datasource for ArrayRecord files."""
"""Datasource for ArrayRecord files using a Lock-Free Connection Pool."""

def __init__(
self,
paths: Union[
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
],
reader_options: dict[str, str] | None = None,
reader_pool_size: int | None = None,
):
"""Creates a new ArrayRecordDataSource object.

Expand All @@ -242,11 +310,11 @@ def __init__(
initialization faster.
reader_options: string of comma-separated options to be passed when
creating a reader.
reader_pool_size: The maximum number of readers to keep open per shard.
"""
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
paths = [paths]
elif isinstance(paths, Sequence):
# Validate correct format of a sequence path
if len(paths) <= 0:
raise ValueError("Paths sequence can not be of 0 length")
elif not all(
Expand All @@ -270,8 +338,18 @@ def __init__(
)
self._read_instructions = _get_read_instructions(paths)
self._paths = [ri.filename for ri in self._read_instructions]
# We open readers lazily when we need to read from them.
self._readers = [None] * len(self._read_instructions)
self._reader_pool_size = (
reader_pool_size or _get_flag_value(_GRAIN_READER_POOL_SIZE) or 1
)

# Lock-free connection pool per shard
self._shard_pools = [
LockFreeReaderPool(
ri.filename, self._reader_options_string, self._reader_pool_size
)
for ri in self._read_instructions
]

self._num_records = sum(
map(lambda x: x.num_records, self._read_instructions)
)
Expand All @@ -286,10 +364,8 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
logging.debug("__exit__ for ArrayRecordDataSource is called.")
for reader in self._readers:
if reader:
reader.close()
self._readers = [None] * len(self._read_instructions)
for pool in self._shard_pools:
pool.close_all()

def __len__(self) -> int:
return self._num_records
Expand Down Expand Up @@ -329,48 +405,50 @@ def _split_keys_per_reader(
positions_and_indices[reader_idx] = [(position, idx)]
return positions_and_indices

def _ensure_reader_exists(self, reader_idx: int) -> None:
"""Threadsafe method to create corresponding reader if it doesn't exist."""
if self._readers[reader_idx] is not None:
return
filename = self._read_instructions[reader_idx].filename
reader = _create_reader(filename, self._reader_options_string)
_check_group_size(filename, reader)
self._readers[reader_idx] = reader
def _read_record(self, reader: Any, position: int) -> bytes:
"""Helper to read a record using the best available method."""
if hasattr(reader, "read_record"):
return reader.read_record(position)
if hasattr(reader, "read"):
return reader.read([position])[0]
return reader[position]

def __getitem__(self, record_key: SupportsIndex) -> bytes:
reader_idx, position = self._reader_idx_and_position(record_key)
self._ensure_reader_exists(reader_idx)
if hasattr(self._readers[reader_idx], "read"):
return self._readers[reader_idx].read([position])[0]
return self._readers[reader_idx][position]
pool_idx, position = self._reader_idx_and_position(record_key)
reader = self._shard_pools[pool_idx].get()
try:
return self._read_record(reader, position)
finally:
self._shard_pools[pool_idx].put(reader)

def __getitems__(
self, record_keys: Sequence[SupportsIndex]
) -> Sequence[bytes]:

def read_records(
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
) -> Sequence[Tuple[Any, int]]:
"""Reads records using the given reader keeping track of the indices."""
# Initialize readers lazily when we need to read from them.
self._ensure_reader_exists(reader_idx)
positions, indices = list(zip(*reader_positions_and_indices))
if hasattr(self._readers[reader_idx], "read"):
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
else:
records = [self._readers[reader_idx][p] for p in positions]
return list(zip(records, indices))
reader = self._shard_pools[pool_idx].get()
try:
records = []
for position, _ in reader_positions_and_indices:
records.append(self._read_record(reader, position))
indices = [idx for _, idx in reader_positions_and_indices]
return list(zip(records, indices))
finally:
self._shard_pools[pool_idx].put(reader)

positions_and_indices = self._split_keys_per_reader(record_keys)
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
num_workers = min(len(positions_and_indices), num_threads)
list_of_kwargs_to_read_records = []
for (
reader_idx,
pool_idx,
reader_positions_and_indices,
) in positions_and_indices.items():
list_of_kwargs_to_read_records.append({
"reader_idx": reader_idx,
"pool_idx": pool_idx,
"reader_positions_and_indices": reader_positions_and_indices,
})
records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = (
Expand All @@ -390,15 +468,22 @@ def read_records(
def __getstate__(self):
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
state = self.__dict__.copy()
del state["_readers"]
state.pop("_shard_pools", None)
return state

def __setstate__(self, state):
logging.debug("__setstate__ for ArrayRecordDataSource is called.")
self.__dict__.update(state)
# We open readers lazily when we need to read from them. Thus, we don't
# need to re-open the same files as before pickling.
self._readers = [None] * len(self._read_instructions)
self._shard_pools = [
LockFreeReaderPool(
ri.filename,
self._reader_options_string,
getattr(self, "_reader_pool_size", 1),
)
for ri in self._read_instructions
]

def __repr__(self) -> str:
"""Storing a hash of paths since paths can be a very long list."""
Expand All @@ -407,10 +492,22 @@ def __repr__(self) -> str:
h.update(p.encode())
return f"ArrayRecordDataSource(hash_of_paths={h.hexdigest()})"

def _peek_readers(self) -> List[Any]:
"""Returns a list of readers (one per shard) or None (for testing only)."""
readers = []
for pool in self._shard_pools:
pooled_readers = pool.peek_readers()
readers.append(pooled_readers[-1] if pooled_readers else None)
return readers


def _get_flag_value(flag: flags.FlagHolder[int]) -> int:
"""Retrieves the flag value or the default if run outside of absl."""
try:
return flag.value
except flags.UnparsedFlagAccessError:
return flag.default


# Alias for backward compatibility with unit tests
BoundedReaderPool = LockFreeReaderPool
Loading
Loading