diff --git a/python/array_record_data_source.py b/python/array_record_data_source.py index 117d244..90bab86 100644 --- a/python/array_record_data_source.py +++ b/python/array_record_data_source.py @@ -39,6 +39,7 @@ def __getitem__(self, record_keys: Sequence[int]) -> Sequence[T]: import os import pathlib import re +import threading import typing from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union @@ -272,6 +273,7 @@ def __init__( self._paths = [ri.filename for ri in self._read_instructions] # We open readers lazily when we need to read from them. self._readers = [None] * len(self._read_instructions) + self._lock = threading.Lock() self._num_records = sum( map(lambda x: x.num_records, self._read_instructions) ) @@ -333,10 +335,13 @@ def _ensure_reader_exists(self, reader_idx: int) -> None: """Threadsafe method to create corresponding reader if it doesn't exist.""" if self._readers[reader_idx] is not None: return - filename = self._read_instructions[reader_idx].filename - reader = _create_reader(filename, self._reader_options_string) - _check_group_size(filename, reader) - self._readers[reader_idx] = reader + with self._lock: + if self._readers[reader_idx] is not None: + return + filename = self._read_instructions[reader_idx].filename + reader = _create_reader(filename, self._reader_options_string) + _check_group_size(filename, reader) + self._readers[reader_idx] = reader def __getitem__(self, record_key: SupportsIndex) -> bytes: reader_idx, position = self._reader_idx_and_position(record_key) @@ -391,6 +396,7 @@ def __getstate__(self): logging.debug("__getstate__ for ArrayRecordDataSource is called.") state = self.__dict__.copy() del state["_readers"] + del state["_lock"] return state def __setstate__(self, state): @@ -399,6 +405,7 @@ def __setstate__(self, state): # We open readers lazily when we need to read from them. Thus, we don't # need to re-open the same files as before pickling. self._readers = [None] * len(self._read_instructions) + self._lock = threading.Lock() def __repr__(self) -> str: """Storing a hash of paths since paths can be a very long list.""" diff --git a/python/array_record_data_source_test.py b/python/array_record_data_source_test.py index 8977a27..3b06410 100644 --- a/python/array_record_data_source_test.py +++ b/python/array_record_data_source_test.py @@ -17,6 +17,7 @@ import dataclasses import os import pathlib +import time from unittest import mock from absl import flags