From 6bc3b1473431d14ea17f8846e841212f61a22820 Mon Sep 17 00:00:00 2001 From: Weida Hong Date: Wed, 8 Apr 2026 17:16:20 +0000 Subject: [PATCH] Thread-safety for building global environment Signed-off-by: Weida Hong --- test/test_threading.py | 24 +++++++++++++++++++++--- torchax/__init__.py | 42 +++++++++++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/test/test_threading.py b/test/test_threading.py index 21670b1..7f59acb 100644 --- a/test/test_threading.py +++ b/test/test_threading.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures import threading import unittest @@ -19,12 +20,11 @@ class TestThreading(unittest.TestCase): - def test_access_config_thread(reraise): + def test_access_config_thread(self): torchax.default_env() def task(): - with reraise: - print(torchax.default_env().param) + print(torchax.default_env().param) threads = [] for _ in range(5): @@ -35,6 +35,24 @@ def task(): for thread in threads: thread.join() + def test_thread_safe_init(self): + # Force a reset to simulate pristine state + torchax._env = None + + def task(): + return torchax.default_env() + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(task) for _ in range(32)] + results = [f.result() for f in futures] + + # All threads should return the same environment object + assert len(results) > 0 + lead = results[0] + for r in results: + self.assertIsNotNone(r) + self.assertIs(r, lead) + if __name__ == "__main__": unittest.main() diff --git a/torchax/__init__.py b/torchax/__init__.py index d06c63e..481c4e3 100644 --- a/torchax/__init__.py +++ b/torchax/__init__.py @@ -15,6 +15,7 @@ import contextlib import dataclasses import os +import threading from contextlib import contextmanager from typing import Any @@ -40,6 +41,7 @@ "default_env", "extract_jax", "enable_globally", + "disable_globally", "save_checkpoint", "load_checkpoint", ] @@ -55,15 +57,31 @@ ) # torchax:oss-end -env = None +_env: tensor.Environment | None = None +_env_lock = threading.Lock() -def default_env(): - global env +def default_env() -> tensor.Environment: + """Returns the default environment. - if env is None: - env = tensor.Environment() - return env + The (global) environment is constructed lazily on the first call, + with default configuration. Construct it manually for advanced + configuration. + """ + global _env + + if _env is None: + # The first thread that enters this block will create the environment. + # Other threads will wait for the lock to be released and then return + # the environment. + with _env_lock: + if _env is not None: + return _env + + _env = tensor.Environment() + + assert _env is not None + return _env def extract_jax(mod: torch.nn.Module, env=None, *, dedup_parameters=True): @@ -94,13 +112,15 @@ def jax_func(states, args, kwargs=None): return states, jax_func -def enable_globally(): - env = default_env().enable_torch_modes() - return env +def enable_globally() -> None: + """Enables torchax globally.""" + + default_env().enable_torch_modes() + +def disable_globally() -> None: + """Disables torchax globally.""" -def disable_globally(): - global env default_env().disable_torch_modes()