Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 87 additions & 6 deletions tnco/utils/tn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
from tnco.ordered_frozenset import OrderedFrozenSet
from tnco.typing import Array, Index, TensorName

try:
import jax
except ModuleNotFoundError:
jax = None

__all__ = [
'get_random_contraction_path', 'get_symbol', 'get_einsum_subscripts',
'read_inds', 'fuse', 'decompose_hyper_inds', 'merge_contraction_paths',
Expand Down Expand Up @@ -859,14 +864,15 @@ def decompose_hyper_inds(

def contract(
path: Iterable[Tuple[int, int]],
ts_inds: Iterable[List[Index]],
ts_inds: Iterable[Iterable[Index]],
output_inds: Optional[Iterable[Index]] = None,
arrays: Optional[Iterable[Array]] = None,
dims: Optional[Union[int, Dict[Index, int]]] = None,
*,
backend: Optional[str] = None,
verbose: Optional[int] = False
) -> Tuple[List[List[Index]], FrozenSet[Index], Optional[List[Array]]]:
) -> Union[Tuple[List[Tuple[Index, ...]], FrozenSet[Index]], Tuple[List[Tuple[
Index, ...]], FrozenSet[Index], List[Array]]]:
"""Contracts a tensor network.

Contracts a tensor network following a given path.
Expand Down Expand Up @@ -967,8 +973,67 @@ def contract(
if not output_inds.issubset(mit.flatten(ts_inds)):
raise ValueError("'output_inds' is not consistent with 'ts_inds'.")

# Convert path
path = tuple(map(tuple, map(sorted, path)))

# Re-map indices
inv_inds_map = dict(enumerate(mit.unique_everseen(mit.flatten(ts_inds))))
inds_map = dict((y, x) for x, y in inv_inds_map.items())
ts_inds = list(tuple(map(inds_map.get, xs)) for xs in ts_inds)
output_inds = frozenset(map(inds_map.get, output_inds))
hyper_count = dict(zip(map(inds_map.get, hyper_count),
hyper_count.values()))

# Check cache
if backend == 'jax' and arrays is not None:
if jax is None:
raise ModuleNotFoundError("'jax' not installed or not found.")

# Get core function
core_ = _get_jit_contract_core(path, tuple(ts_inds), output_inds,
tuple(sorted(hyper_count.items())),
verbose)

# Complete contraction
ts_inds, output_inds, arrays = core_(arrays)

else:
# Complete contraction
ts_inds, output_inds, arrays = _contraction_core(
path, ts_inds, output_inds, hyper_count, verbose, arrays)

# Revert to the original indices
ts_inds = list(tuple(map(inv_inds_map.get, map(int, xs))) for xs in ts_inds)
output_inds = frozenset(map(inv_inds_map.get, map(int, output_inds)))

# Return
return (ts_inds, output_inds) if arrays is None else (ts_inds, output_inds,
arrays)


def _contraction_core(
path: Tuple[Tuple[int, int],
...], ts_inds: List[Tuple[int,
...]], output_inds: FrozenSet[int],
hyper_count: Dict[int, int], verbose: int, arrays: Optional[List[Array]]
) -> Tuple[List[Tuple[int, ...]], Tuple[int, ...], Optional[List[Array]]]:
"""Internal contraction core.

Args:
path: Contraction path.
ts_inds: List of indices for each tensor.
output_inds: Output indices.
hyper_count: Hyper-count for each index.
verbose: Verbosity level.
arrays: List of tensor arrays (optional).

Returns:
A tuple containing the updated ``ts_inds``, ``output_inds``, and
``arrays``.
"""

# Contract
for x, y in track(map(sorted, path),
for x, y in track(path,
console=Console(stderr=True),
description="Contracting...",
total=len(path),
Expand Down Expand Up @@ -1019,8 +1084,24 @@ def contract(
ts_inds.append(zs)

# Get the new output indices
output_inds = output_inds.intersection(mit.flatten(ts_inds))
output_inds = tuple(output_inds.intersection(mit.flatten(ts_inds)))

# Return new arrays
return (ts_inds, output_inds) if arrays is None else (ts_inds, output_inds,
arrays)
return ts_inds, output_inds, arrays


if jax:

@fts.lru_cache(maxsize=1024)
def _get_jit_contract_core(path: Tuple[Tuple[int, int], ...],
ts_inds: Tuple[Tuple[int, ...],
...], output_inds: FrozenSet[int],
hyper_count: Tuple[Tuple[int, int],
...], verbose: int):
"""Returns a jit-compiled contraction core."""

def core(arrays):
return _contraction_core(path, list(ts_inds), output_inds,
dict(hyper_count), verbose, arrays)

return jax.jit(core)