Skip to content

Commit 201b2a0

Browse files
ArrayRecord Team and copybara-github
authored and committed
Fix concurrent read determinism and file descriptor leaks in ArrayRecordDataSource.
PiperOrigin-RevId: 897665704
1 parent c909611 commit 201b2a0

2 files changed

Lines changed: 176 additions & 78 deletions

File tree

python/array_record_data_source.py

Lines changed: 170 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
"""array_record_data_source module.
1615
1716
Warning: this is an experimental module. The interface might change in the
@@ -23,22 +22,28 @@
2322
2423
```
2524
class RandomAccessDataSource(Protocol, Generic[T]):
25+
2626
def __len__(self) -> int:
2727
...
2828
29-
def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
29+
def __getitem__(self, record_key: SupportsIndex) -> T:
30+
...
31+
32+
def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]:
3033
...
3134
```
3235
"""
3336

3437
import bisect
38+
import collections
3539
from concurrent import futures
3640
import dataclasses
3741
import hashlib
3842
import itertools
3943
import os
4044
import pathlib
4145
import re
46+
import threading
4247
import typing
4348
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
4449

@@ -48,30 +53,33 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]:
4853

4954
from . import array_record_module
5055

56+
T = TypeVar("T")
57+
58+
59+
@typing.runtime_checkable
60+
class FileInstruction(Protocol):
61+
"""Protocol with same interface as FileInstruction returned by TFDS.
62+
63+
ArrayRecordDataSource would accept objects implementing this protocol without
64+
depending on TFDS.
65+
"""
66+
67+
filename: str
68+
skip: int
69+
take: int
70+
examples_in_shard: int
71+
72+
73+
PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction]
74+
ArrayRecordDataSourcePaths = Union[
75+
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
76+
]
77+
78+
5179
# TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could
5280
# be appropriate if we decide to use asyncio) or move them somewhere else and
5381
# pass the number of threads as an argument. For now, since we experiment, it's
5482
# convenient to have them.
55-
_GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS = flags.DEFINE_integer(
56-
"grain_num_threads_computing_num_records",
57-
64,
58-
(
59-
"The number of threads used to fetch file instructions (i.e., the max"
60-
" number of Array Record files opened while calculating the total"
61-
" number of records)."
62-
),
63-
)
64-
_GRAIN_NUM_THREADS_FETCHING_RECORDS = flags.DEFINE_integer(
65-
"grain_num_threads_fetching_records",
66-
64,
67-
(
68-
"The number of threads used to fetch records from Array Record files. "
69-
"(i.e., the max number of Array Record files opened while fetching "
70-
"records)."
71-
),
72-
)
73-
74-
T = TypeVar("T")
7583

7684

7785
def _run_in_parallel(
@@ -96,6 +104,7 @@ def _run_in_parallel(
96104
"""
97105
if num_workers < 1:
98106
raise ValueError("num_workers must be >=1 for parallelism.")
107+
99108
thread_futures = []
100109
with futures.ThreadPoolExecutor(num_workers) as executor:
101110
for kwargs in list_of_kwargs_to_function:
@@ -125,23 +134,6 @@ def __post_init__(self):
125134
object.__setattr__(self, "num_records", self.end - self.start)
126135

127136

128-
@typing.runtime_checkable
129-
class FileInstruction(Protocol):
130-
"""Protocol with same interface as FileInstruction returned by TFDS.
131-
132-
ArrayRecordDataSource would accept objects implementing this protocol without
133-
depending on TFDS.
134-
"""
135-
136-
filename: str
137-
skip: int
138-
take: int
139-
examples_in_shard: int
140-
141-
142-
PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction]
143-
144-
145137
def _get_read_instructions(
146138
paths: Sequence[PathLikeOrFileInstruction],
147139
) -> Sequence[_ReadInstruction]:
@@ -212,15 +204,91 @@ def _check_group_size(
212204
)
213205

214206

207+
class BoundedReaderPool:
208+
"""A semaphore-throttled connection pool for a single shard."""
209+
210+
def __init__(self, filename: str, options_string: str, max_size: int = 1):
211+
self._filename = filename
212+
self._options_string = options_string
213+
self._max_size = max_size
214+
self._readers = collections.deque()
215+
# Use BoundedSemaphore to strictly enforce the max_size cap
216+
self._semaphore = threading.BoundedSemaphore(max_size)
217+
self._created_count = 0
218+
self._lock = threading.Lock()
219+
self._group_size_checked = False
220+
self._closed = False
221+
222+
def get(self) -> Any:
223+
"""Gets a reader atomically, throttling creation if cap is reached."""
224+
self._semaphore.acquire()
225+
try:
226+
try:
227+
return self._readers.popleft()
228+
except IndexError as exc:
229+
with self._lock:
230+
if self._closed:
231+
raise RuntimeError(
232+
f"Cannot get reader from closed pool: {self._filename}"
233+
) from exc
234+
reader = _create_reader(self._filename, self._options_string)
235+
if not self._group_size_checked:
236+
_check_group_size(self._filename, reader)
237+
self._group_size_checked = True
238+
self._created_count += 1
239+
return reader
240+
except:
241+
# CRITICAL FIX: Guarantee semaphore permit is released if reader
242+
# creation throws.
243+
self._semaphore.release()
244+
raise
245+
246+
def put(self, reader: Any) -> None:
247+
"""Returns a reader to the pool atomically."""
248+
with self._lock:
249+
if self._closed:
250+
# If the pool was closed while the reader was borrowed, close it
251+
# immediately.
252+
if reader and hasattr(reader, "close"):
253+
reader.close()
254+
self._semaphore.release()
255+
return
256+
257+
self._readers.append(reader)
258+
self._semaphore.release()
259+
260+
def close_all(self) -> None:
261+
"""Closes all pooled readers and prevents future allocations."""
262+
with self._lock:
263+
self._closed = True
264+
265+
while True:
266+
try:
267+
reader = self._readers.popleft()
268+
if reader and hasattr(reader, "close"):
269+
reader.close()
270+
except IndexError:
271+
break
272+
273+
def peek_readers(self) -> List[Any]:
274+
"""Returns the list of readers (for testing only)."""
275+
return list(self._readers)
276+
277+
278+
# Retain alias for backward compatibility with existing code/tests
279+
LockFreeReaderPool = BoundedReaderPool
280+
281+
215282
class ArrayRecordDataSource:
216-
"""Datasource for ArrayRecord files."""
283+
"""Datasource for ArrayRecord files using a Lock-Free Connection Pool."""
217284

218285
def __init__(
219286
self,
220287
paths: Union[
221288
PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
222289
],
223290
reader_options: dict[str, str] | None = None,
291+
reader_pool_size: int | None = None,
224292
):
225293
"""Creates a new ArrayRecordDataSource object.
226294
@@ -242,11 +310,11 @@ def __init__(
242310
initialization faster.
243311
reader_options: string of comma-separated options to be passed when
244312
creating a reader.
313+
reader_pool_size: The maximum number of readers to keep open per shard.
245314
"""
246315
if isinstance(paths, (str, pathlib.Path, FileInstruction)):
247316
paths = [paths]
248317
elif isinstance(paths, Sequence):
249-
# Validate correct format of a sequence path
250318
if len(paths) <= 0:
251319
raise ValueError("Paths sequence can not be of 0 length")
252320
elif not all(
@@ -270,8 +338,18 @@ def __init__(
270338
)
271339
self._read_instructions = _get_read_instructions(paths)
272340
self._paths = [ri.filename for ri in self._read_instructions]
273-
# We open readers lazily when we need to read from them.
274-
self._readers = [None] * len(self._read_instructions)
341+
self._reader_pool_size = (
342+
reader_pool_size or _get_flag_value(_GRAIN_READER_POOL_SIZE) or 1
343+
)
344+
345+
# Lock-free connection pool per shard
346+
self._shard_pools = [
347+
LockFreeReaderPool(
348+
ri.filename, self._reader_options_string, self._reader_pool_size
349+
)
350+
for ri in self._read_instructions
351+
]
352+
275353
self._num_records = sum(
276354
map(lambda x: x.num_records, self._read_instructions)
277355
)
@@ -286,10 +364,8 @@ def __enter__(self):
286364

287365
def __exit__(self, exc_type, exc_value, traceback):
288366
logging.debug("__exit__ for ArrayRecordDataSource is called.")
289-
for reader in self._readers:
290-
if reader:
291-
reader.close()
292-
self._readers = [None] * len(self._read_instructions)
367+
for pool in self._shard_pools:
368+
pool.close_all()
293369

294370
def __len__(self) -> int:
295371
return self._num_records
@@ -329,48 +405,50 @@ def _split_keys_per_reader(
329405
positions_and_indices[reader_idx] = [(position, idx)]
330406
return positions_and_indices
331407

332-
def _ensure_reader_exists(self, reader_idx: int) -> None:
333-
"""Threadsafe method to create corresponding reader if it doesn't exist."""
334-
if self._readers[reader_idx] is not None:
335-
return
336-
filename = self._read_instructions[reader_idx].filename
337-
reader = _create_reader(filename, self._reader_options_string)
338-
_check_group_size(filename, reader)
339-
self._readers[reader_idx] = reader
408+
def _read_record(self, reader: Any, position: int) -> bytes:
409+
"""Helper to read a record using the best available method."""
410+
if hasattr(reader, "read_record"):
411+
return reader.read_record(position)
412+
if hasattr(reader, "read"):
413+
return reader.read([position])[0]
414+
return reader[position]
340415

341416
def __getitem__(self, record_key: SupportsIndex) -> bytes:
342-
reader_idx, position = self._reader_idx_and_position(record_key)
343-
self._ensure_reader_exists(reader_idx)
344-
if hasattr(self._readers[reader_idx], "read"):
345-
return self._readers[reader_idx].read([position])[0]
346-
return self._readers[reader_idx][position]
417+
pool_idx, position = self._reader_idx_and_position(record_key)
418+
reader = self._shard_pools[pool_idx].get()
419+
try:
420+
return self._read_record(reader, position)
421+
finally:
422+
self._shard_pools[pool_idx].put(reader)
347423

348424
def __getitems__(
349425
self, record_keys: Sequence[SupportsIndex]
350426
) -> Sequence[bytes]:
427+
351428
def read_records(
352-
reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
429+
pool_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
353430
) -> Sequence[Tuple[Any, int]]:
354431
"""Reads records using the given reader keeping track of the indices."""
355-
# Initialize readers lazily when we need to read from them.
356-
self._ensure_reader_exists(reader_idx)
357-
positions, indices = list(zip(*reader_positions_and_indices))
358-
if hasattr(self._readers[reader_idx], "read"):
359-
records = self._readers[reader_idx].read(positions) # pytype: disable=attribute-error
360-
else:
361-
records = [self._readers[reader_idx][p] for p in positions]
362-
return list(zip(records, indices))
432+
reader = self._shard_pools[pool_idx].get()
433+
try:
434+
records = []
435+
for position, _ in reader_positions_and_indices:
436+
records.append(self._read_record(reader, position))
437+
indices = [idx for _, idx in reader_positions_and_indices]
438+
return list(zip(records, indices))
439+
finally:
440+
self._shard_pools[pool_idx].put(reader)
363441

364442
positions_and_indices = self._split_keys_per_reader(record_keys)
365443
num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
366444
num_workers = min(len(positions_and_indices), num_threads)
367445
list_of_kwargs_to_read_records = []
368446
for (
369-
reader_idx,
447+
pool_idx,
370448
reader_positions_and_indices,
371449
) in positions_and_indices.items():
372450
list_of_kwargs_to_read_records.append({
373-
"reader_idx": reader_idx,
451+
"pool_idx": pool_idx,
374452
"reader_positions_and_indices": reader_positions_and_indices,
375453
})
376454
records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = (
@@ -390,15 +468,22 @@ def read_records(
390468
def __getstate__(self):
391469
logging.debug("__getstate__ for ArrayRecordDataSource is called.")
392470
state = self.__dict__.copy()
393-
del state["_readers"]
471+
state.pop("_shard_pools", None)
394472
return state
395473

396474
def __setstate__(self, state):
397475
logging.debug("__setstate__ for ArrayRecordDataSource is called.")
398476
self.__dict__.update(state)
399477
# We open readers lazily when we need to read from them. Thus, we don't
400478
# need to re-open the same files as before pickling.
401-
self._readers = [None] * len(self._read_instructions)
479+
self._shard_pools = [
480+
LockFreeReaderPool(
481+
ri.filename,
482+
self._reader_options_string,
483+
getattr(self, "_reader_pool_size", 1),
484+
)
485+
for ri in self._read_instructions
486+
]
402487

403488
def __repr__(self) -> str:
404489
"""Storing a hash of paths since paths can be a very long list."""
@@ -407,10 +492,22 @@ def __repr__(self) -> str:
407492
h.update(p.encode())
408493
return f"ArrayRecordDataSource(hash_of_paths={h.hexdigest()})"
409494

495+
def _peek_readers(self) -> List[Any]:
496+
"""Returns a list of readers (one per shard) or None (for testing only)."""
497+
readers = []
498+
for pool in self._shard_pools:
499+
pooled_readers = pool.peek_readers()
500+
readers.append(pooled_readers[-1] if pooled_readers else None)
501+
return readers
502+
410503

411504
def _get_flag_value(flag: flags.FlagHolder[int]) -> int:
  """Returns `flag.value`, or `flag.default` when flags were not parsed yet.

  Accessing `flag.value` raises `flags.UnparsedFlagAccessError` when the code
  runs outside of absl; in that case the flag's declared default is used.
  """
  try:
    parsed_value = flag.value
  except flags.UnparsedFlagAccessError:
    return flag.default
  return parsed_value
510+
511+
512+
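# Alias for backward compatibility with unit tests. NOTE: LockFreeReaderPool is
# already bound to BoundedReaderPool right after the class definition, so this
# assignment rebinds the name to itself and is effectively a no-op.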
513+
BoundedReaderPool = LockFreeReaderPool

0 commit comments

Comments
 (0)