From 0c384caf1f43c8230c8795f4d0c77913fbf1c26f Mon Sep 17 00:00:00 2001 From: Salvatore Mandra Date: Mon, 16 Feb 2026 11:19:21 -0800 Subject: [PATCH] Add JAX JIT support and caching for TN contraction Enhances the `contract` function to support JAX's Just-In-Time (JIT) compilation when the 'jax' backend is used. Key changes: - Refactored the core contraction logic into `_contraction_core` to enable JIT compatibility. - Implemented a caching mechanism (`__JAX_CACHE__`) for JIT-compiled contraction cores to improve performance on repeated structures. - Added index re-mapping to integers within `contract` to ensure compatibility with JAX JIT requirements. - Updated type hints and return logic to handle cases with and without explicit tensor arrays. --- tnco/utils/tn.py | 93 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 6 deletions(-) diff --git a/tnco/utils/tn.py b/tnco/utils/tn.py index 04c3919..489cae4 100644 --- a/tnco/utils/tn.py +++ b/tnco/utils/tn.py @@ -32,6 +32,11 @@ from tnco.ordered_frozenset import OrderedFrozenSet from tnco.typing import Array, Index, TensorName +try: + import jax +except ModuleNotFoundError: + jax = None + __all__ = [ 'get_random_contraction_path', 'get_symbol', 'get_einsum_subscripts', 'read_inds', 'fuse', 'decompose_hyper_inds', 'merge_contraction_paths', @@ -859,14 +864,15 @@ def decompose_hyper_inds( def contract( path: Iterable[Tuple[int, int]], - ts_inds: Iterable[List[Index]], + ts_inds: Iterable[Iterable[Index]], output_inds: Optional[Iterable[Index]] = None, arrays: Optional[Iterable[Array]] = None, dims: Optional[Union[int, Dict[Index, int]]] = None, *, backend: Optional[str] = None, verbose: Optional[int] = False -) -> Tuple[List[List[Index]], FrozenSet[Index], Optional[List[Array]]]: +) -> Union[Tuple[List[Tuple[Index, ...]], FrozenSet[Index]], Tuple[List[Tuple[ + Index, ...]], FrozenSet[Index], List[Array]]]: """Contracts a tensor network. Contracts a tensor network following a given path. @@ -967,8 +973,67 @@ def contract( if not output_inds.issubset(mit.flatten(ts_inds)): raise ValueError("'output_inds' is not consistent with 'ts_inds'.") + # Convert path + path = tuple(map(tuple, map(sorted, path))) + + # Re-map indices + inv_inds_map = dict(enumerate(mit.unique_everseen(mit.flatten(ts_inds)))) + inds_map = dict((y, x) for x, y in inv_inds_map.items()) + ts_inds = list(tuple(map(inds_map.get, xs)) for xs in ts_inds) + output_inds = frozenset(map(inds_map.get, output_inds)) + hyper_count = dict(zip(map(inds_map.get, hyper_count), + hyper_count.values())) + + # Check cache + if backend == 'jax' and arrays is not None: + if jax is None: + raise ModuleNotFoundError("'jax' not installed or not found.") + + # Get core function + core_ = _get_jit_contract_core(path, tuple(ts_inds), output_inds, + tuple(sorted(hyper_count.items())), + verbose) + + # Complete contraction + ts_inds, output_inds, arrays = core_(arrays) + + else: + # Complete contraction + ts_inds, output_inds, arrays = _contraction_core( + path, ts_inds, output_inds, hyper_count, verbose, arrays) + + # Revert to the original indices + ts_inds = list(tuple(map(inv_inds_map.get, map(int, xs))) for xs in ts_inds) + output_inds = frozenset(map(inv_inds_map.get, map(int, output_inds))) + + # Return + return (ts_inds, output_inds) if arrays is None else (ts_inds, output_inds, + arrays) + + +def _contraction_core( + path: Tuple[Tuple[int, int], + ...], ts_inds: List[Tuple[int, + ...]], output_inds: FrozenSet[int], + hyper_count: Dict[int, int], verbose: int, arrays: Optional[List[Array]] +) -> Tuple[List[Tuple[int, ...]], Tuple[int, ...], Optional[List[Array]]]: + """Internal contraction core. + + Args: + path: Contraction path. + ts_inds: List of indices for each tensor. + output_inds: Output indices. + hyper_count: Hyper-count for each index. + verbose: Verbosity level. + arrays: List of tensor arrays (optional). + + Returns: + A tuple containing the updated ``ts_inds``, ``output_inds``, and + ``arrays``. + """ + # Contract - for x, y in track(map(sorted, path), + for x, y in track(path, console=Console(stderr=True), description="Contracting...", total=len(path), @@ -1019,8 +1084,24 @@ def contract( ts_inds.append(zs) # Get the new output indices - output_inds = output_inds.intersection(mit.flatten(ts_inds)) + output_inds = tuple(output_inds.intersection(mit.flatten(ts_inds))) # Return new arrays - return (ts_inds, output_inds) if arrays is None else (ts_inds, output_inds, - arrays) + return ts_inds, output_inds, arrays + + +if jax: + + @fts.lru_cache(maxsize=1024) + def _get_jit_contract_core(path: Tuple[Tuple[int, int], ...], + ts_inds: Tuple[Tuple[int, ...], + ...], output_inds: FrozenSet[int], + hyper_count: Tuple[Tuple[int, int], + ...], verbose: int): + """Returns a jit-compiled contraction core.""" + + def core(arrays): + return _contraction_core(path, list(ts_inds), output_inds, + dict(hyper_count), verbose, arrays) + + return jax.jit(core)